# DeepMusic v1.0
# Jean-Pierre Briot
# 08/04/2019
# Representation

import numpy as np
from music21 import stream, note, interval, duration, corpus, converter

import config

# See also global variables within config

# Global variables
corpus_list = [None]						# list of music streams
corpus_size = None						# length
music_duration_quarter_length_v = []		# duration (number of quarters) for each music (vector)
time_signatures_dict = {'4/4': 0, '3/4': 0}		# time signatures count - Only 4/4 and 3/4 considered here
smallest_note_duration_quarter_length = None	# smallest note duration (in quarter_length) initial setting
number_parts = None					# number of parts (assumed to be the same for all music of the corpus)
max_number_quarters = None				# maximum number of quarters considered (notes not considered above that)
interval_semitones_v = None				# number of semitones between lowest and highest note pitches for each part (vector)
number_special_elements_one_hot = 2		# number of additional special elements in one hot encoding - usually 2: hold + rest
hold_index = 0							# one hot index of hold
rest_index = 1							# one hot index of rest
nothing_before_index = -1					# initial value of current index to mark the case of the first note (when creating a score)

# Constants
lowest_note_pitch = note.Note('C9').pitch		# lowest note pitch initial setting - initialized as high as possible so that any note lowers it
highest_note_pitch = note.Note('C0').pitch		# highest note pitch initial setting - initialized as low as possible so that any note raises it
smallest_note_duration_quarter_length_initial = 100	# smallest note duration (in quarter_length) initial setting - as large as possible
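
# Illustrative sketch (an assumption, not part of the original pipeline): shows how a pitch maps to a one hot index
# given the layout above - hold at index 0, rest at index 1, pitches starting at index number_special_elements_one_hot
# Ex: example_pitch_to_one_hot_index('C4', 'E4')
#	-> 6			(C4..E4 = 4 semitones, plus the 2 special elements)
def example_pitch_to_one_hot_index(lowest_pitch_name, pitch_name):
	lowest_pitch = note.Note(lowest_pitch_name).pitch
	pitch = note.Note(pitch_name).pitch
	return(interval.Interval(lowest_pitch, pitch).semitones + number_special_elements_one_hot)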

# load corpus (music list)
# Ex: load_corpus(['bach/bwv344', 'bach/bwv345'])
#	-> [stream_bwv344, stream_bwv345]
def load_corpus(list_music_names):
	global corpus_list
	global corpus_size
	corpus_list = list(map(corpus.parse, list_music_names))
	corpus_size = len(corpus_list)
	if config.deep_music_analysis_verbose:
		print(str(len(corpus_list)) + ' music streams loaded.')
	return(corpus_list)

# load n chorales
# max: 371
# Ex: load_n_chorales(10)
#	-> [stream_bach/bwv269, stream_bach/bwv347 ... stream_bach/bwv38.6]
def load_n_chorales(n):
	list_music_names = []
	for chorale_name in corpus.chorales.Iterator(1, n, returnType='filename'):
		list_music_names.append(chorale_name)
	# remove the singular chorale, ex: #10 has 9 voices (parts)
	if len(list_music_names) >= 11:
		del list_music_names[10]
	return(load_corpus(list_music_names))

# transpose corpus into all keys
# Ex: transpose_corpus_into_all_keys(corpus_list)
#	-> corpus_list_TR_12
def transpose_corpus_into_all_keys(corpus_list):
	transposed_corpus_list = corpus_list[:]		# shallow copy so that extend does not modify the original corpus_list
	for i in range(1, 12):			# 12
		transposed_corpus_list.extend(transpose_corpus(corpus_list, i))
	return(transposed_corpus_list)
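
# Note (illustrative): the result keeps the original corpus plus its 11 transpositions,
# so a corpus of N music streams yields 12 * N streams (ex: 10 chorales -> 120 streams)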

# transpose corpus in a given key (number of semitones)
# Ex: transpose_corpus(corpus_list, 5)
#	-> corpus_list_TR_5
def transpose_corpus(corpus_list, semitones):
#	if config.deep_music_analysis_verbose:
	print('Transpose corpus of ' + str(len(corpus_list)) + ' music streams by ' + str(semitones) + ' semitones')
	transposed_corpus_list = []
	for music in corpus_list:
		transposed_corpus_list.append(music.transpose(semitones))
	return(transposed_corpus_list)

# analyze corpus (to analyze total_duration, smallest note duration, lowest and highest notes, for each part of each music)
# Ex: analyze_corpus(corpus_list)
#	-> None
def analyze_corpus(music_list):
	global number_parts
	global music_duration_quarter_length_v
	smallest_note_duration_quarter_length = smallest_note_duration_quarter_length_initial
	size_corpus = len(music_list)
	music_duration_quarter_length_v = [0] * size_corpus
	number_parts = len(music_list[0].getElementsByClass('Part'))
	config.lowest_note_pitch_v = [lowest_note_pitch] * number_parts
	config.highest_note_pitch_v = [highest_note_pitch] * number_parts
	interval_semitones_v = [None] * number_parts
	config.one_hot_size_v = [None] * number_parts
	for m, music in enumerate(music_list):
		music_duration_quarter_length = None
		part_duration_quarter_length_v = [None] * number_parts
		part_duration_quarter_length = 0
		parts_list = music.getElementsByClass('Part')
		if len(parts_list) != number_parts:
			raise ValueError('Music pieces within the corpus have different numbers of parts')
		else:
			measures_list = parts_list[0].getElementsByClass('Measure')		# all measures of the first part
			time_signature_object = measures_list[0].timeSignature
				# we assume that there is only one time signature
				# and that it is specified in the 1st measure (of the first part)
				# TODO: more generic/safe
			time_signature = str(time_signature_object.numerator) + '/' + str(time_signature_object.denominator)
			time_signatures_dict[time_signature] = time_signatures_dict[time_signature] + 1	# count the number of each time signature
			if config.deep_music_analysis_verbose:
				print(str(music) + ' is in ' +  str(time_signature) + ', has ' + str(number_parts) + ' parts and ' + str(len(measures_list)) + ' measures.')
			for p, part in enumerate(parts_list):
				part_duration_quarter_length = 0
				measures_list = part.getElementsByClass('Measure')
				if len(measures_list) == 0:
					raise ValueError(str(part) + ' has no measures')
				else:
					for measure in measures_list:
						notes_list = measure.getElementsByClass('Note')
						for note in notes_list:
							if interval.Interval(config.lowest_note_pitch_v[p], note.pitch).semitones < 0:
								config.lowest_note_pitch_v[p] = note.pitch
							if interval.Interval(config.highest_note_pitch_v[p], note.pitch).semitones > 0:
								config.highest_note_pitch_v[p] = note.pitch
							if note.quarterLength < smallest_note_duration_quarter_length:
								smallest_note_duration_quarter_length = note.quarterLength
							part_duration_quarter_length = part_duration_quarter_length + note.quarterLength
						rests_list = measure.getElementsByClass('Rest')
						for rest in rests_list:
							if rest.quarterLength < smallest_note_duration_quarter_length:
								smallest_note_duration_quarter_length = rest.quarterLength
							part_duration_quarter_length = part_duration_quarter_length + rest.quarterLength
							# we assume that rests are explicit - we do not need to consider the offsets
							# TODO: more generic/safe
				part_duration_quarter_length_v[p] = part_duration_quarter_length
			music_duration_quarter_length = part_duration_quarter_length_v[0]
			for p in range(1, number_parts):
				if part_duration_quarter_length_v[p] != music_duration_quarter_length:
					raise ValueError('Parts of ' + str(music) + ' have different lengths')
			music_duration_quarter_length_v[m] = music_duration_quarter_length
	for p in range(number_parts):
		interval_semitones_v[p] = interval.Interval(config.lowest_note_pitch_v[p], config.highest_note_pitch_v[p]).semitones
										# number of elements = (length of interval + 1)
		config.one_hot_size_v[p] = interval_semitones_v[p] + 1 + number_special_elements_one_hot
	config.time_step_quarter_length = smallest_note_duration_quarter_length
	if config.deep_music_analysis_verbose:
		print('Smallest note duration (quarter length): ' + str(smallest_note_duration_quarter_length) + ' ; lowest note pitches: ' + str(config.lowest_note_pitch_v) + ' ; highest note pitches: ' + str(config.highest_note_pitch_v))
					# we assume there is consistency between music pieces within the corpus
					# TODO: more generic/safe
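
# Note (illustrative summary): analyze_corpus must be called after loading and before encoding,
# since it fills config.lowest_note_pitch_v, config.highest_note_pitch_v, config.one_hot_size_v
# and config.time_step_quarter_length, which the encoding and decoding functions below rely on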

# analyze a single soprano part
# needed to possibly extend the lowest and highest note pitches - for chorale generation from an arbitrary soprano part
# Ex: analyze_soprano_part(soprano_part)
#	-> None
def analyze_soprano_part(part):
	measures_list = part.getElementsByClass('Measure')
	if len(measures_list) == 0:
		print('Warning: ' + str(part) + ' has no measures.')
		notes_list = list(part.getElementsByClass('Note'))
	else:
		notes_list = []
		for measure in measures_list:
			notes_list.extend(list(measure.getElementsByClass('Note')))
	for note in notes_list:
		if interval.Interval(config.lowest_note_pitch_v[0], note.pitch).semitones < 0:
			config.lowest_note_pitch_v[0] = note.pitch
			print('Warning: New soprano part lowest pitch: ' + str(note.pitch))
		if interval.Interval(config.highest_note_pitch_v[0], note.pitch).semitones > 0:
			config.highest_note_pitch_v[0] = note.pitch
			print('Warning: New soprano part highest pitch: ' + str(note.pitch))
	interval_semitones = interval.Interval(config.lowest_note_pitch_v[0], config.highest_note_pitch_v[0]).semitones
					# number of elements = (length of interval + 1)
	config.one_hot_size_v[0] = interval_semitones + 1 + number_special_elements_one_hot
	if config.deep_music_analysis_verbose:
		print('Soprano part new lowest note pitch: ' + str(config.lowest_note_pitch_v[0]) + ' ; highest note pitch: ' + str(config.highest_note_pitch_v[0]))

# encode corpus
# Ex: encode_corpus(corpus_list)
#	-> [encode_music(corpus_list[0]), ... encode_music(corpus_list[last])]
def encode_corpus(music_list):
	return(list(map(encode_music, music_list)))
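
# Hedged end to end sketch (illustrative only - this helper is an assumption, not part of the original code):
# load -> analyze -> encode; each music becomes a list of flat one hot sequences, one per part
# Ex: example_encoding_pipeline(2)
#	-> encoded_corpus, where encoded_corpus[m][p] is the flat one hot sequence of part p of music m
def example_encoding_pipeline(number_of_chorales):
	example_corpus = load_n_chorales(number_of_chorales)		# load the first chorales of the music21 Bach corpus
	analyze_corpus(example_corpus)								# compute pitch ranges, one hot sizes and time step
	return(encode_corpus(example_corpus))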

# encode music
# Ex: encode_music(music)
# music is a Music21 stream
#	-> [encode_part(0, music[0]), ... encode_part(last, music[last])]
def encode_music(music):
					# we assume that all music of the corpus have the same number of parts
					# TODO: check that
	parts_list = music.getElementsByClass('Part')
	parts_data = []
	for i, part in enumerate(parts_list):
		parts_data.append(encode_part(i, part))
	if config.deep_music_analysis_verbose:
		print('Has encoded ' + str(music) + ' with ' + str(len(parts_list)) + ' parts.')
	return(parts_data)

# encode part
# Ex: encode_part(index_part, part)
#	-> [[...], [...], ... [...]]
def encode_part(index_part, part):
	# rests are collected and encoded with the dedicated rest index (see encode_notes_list)
	# TODO: rests are not yet reconstructed when creating a score (see create_score_part)
	measures_list = part.getElementsByClass('Measure')
	if len(measures_list) == 0:
		print('Warning: ' + str(part) + ' has no measures.')
		notes_list = list(part.getElementsByClass(['Note', 'Rest']))			# notes or rests
	else:
		notes_list = []
		for measure in measures_list:
			notes_list.extend(list(measure.getElementsByClass(['Note', 'Rest'])))
	return(encode_notes_list(index_part, notes_list))

# encode a list of notes (and rests)
# Ex: encode_notes_list(0, [<music21.note.Note C>, ...])
#	-> [0, 0, 0, ... 0, (:) 0, 0, 1, ... 0, (:) ... (:) 0, 0, 0, ... 0]		# flat concatenation of one hot vectors - (:) marks time step boundaries
def encode_notes_list(index_part, notes_list):
	lowest_note_pitch = config.lowest_note_pitch_v[index_part]
	highest_note_pitch = config.highest_note_pitch_v[index_part]
	one_hot_size = config.one_hot_size_v[index_part]
	encoded_hold = encode_hold(one_hot_size)
	number_encoded_notes = 0
	number_time_steps = 0
	data_list = []
	for note in notes_list:
		if number_time_steps >= config.max_number_time_steps:
			break
			# we assume that rests are explicit - we do not need to consider the offsets
			# TODO: more generic/safe
		else:
			if note.isNote:
				if interval.Interval(lowest_note_pitch, note.pitch).semitones < 0:
					print(str(note.pitch) + ' is lower than lowest pitch ' + str(lowest_note_pitch))
					raise ValueError(str(note.pitch) + ' is lower than lowest pitch ' + str(lowest_note_pitch))
				elif interval.Interval(highest_note_pitch, note.pitch).semitones > 0:
					print(str(note.pitch) + ' is higher than highest pitch ' + str(highest_note_pitch))
					raise ValueError(str(note.pitch) + ' is higher than highest pitch ' + str(highest_note_pitch))
				index_note = interval.Interval(lowest_note_pitch, note.pitch).semitones + number_special_elements_one_hot
			elif note.isRest:
				index_note = rest_index
			else:
				raise ValueError(str(note) + ' is neither a Note nor a Rest')
			data_list.extend(encode_note_index(index_note, one_hot_size))
			number_time_steps = number_time_steps + 1
			number_hold_time_steps = int((note.quarterLength / config.time_step_quarter_length)) - 1
			for dummy in range(number_hold_time_steps):
				if number_time_steps >= config.max_number_time_steps:
					break
				else:
					data_list.extend(encoded_hold)
					number_time_steps = number_time_steps + 1
			number_encoded_notes = number_encoded_notes + 1				# notes or rests
	if config.deep_music_analysis_verbose:
		print('Has encoded part ' + str(index_part) + ' with ' + str(number_encoded_notes) + ' notes or rests and ' + str(number_time_steps) + ' time steps.')
	return(data_list)
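
# Worked example (illustrative): with config.time_step_quarter_length = 0.25 (a sixteenth note),
# a quarter note (quarterLength = 1.0) is encoded as 1 pitch one hot followed by
# int(1.0 / 0.25) - 1 = 3 hold one hots, i.e. 4 time steps in total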

# encode a note in one hot
# Ex: encode_note(<music21.note.Note C>, 10)
#	-> [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
def encode_note(note, size):
	note_index = interval.Interval(config.lowest_note_pitch_v[0], note.pitch).semitones + number_special_elements_one_hot
	return(encode_note_index(note_index, size))

# encode a note index in one hot
# Ex: encode_note_index(3, 10)
#	-> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
def encode_note_index(note_index, size):
	if (note_index >= size) or (note_index < 0):				# warning - bug - TODO
		print('Warning: ' + str(note_index) + ' is outside of one hot size: ' + str(size) + '.')
		return(encode_rest(size))			# replace note by a rest
	one_hot = [0] * size
	one_hot[note_index] = 1
	return(one_hot)

# encode hold in one hot
# Ex: encode_hold(10)
#	-> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
def encode_hold(size):
	return(encode_note_index(hold_index, size))

# encode rest in one hot
# Ex: encode_rest(10)
#	-> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
def encode_rest(size):
	return(encode_note_index(rest_index, size))

# create scores
# Ex: create_scores(<P>, [<data_music1>, ... <data_musicN>])
# [[<data_part1>, ... <data_partP>], ... [<data_part1>, ... <data_partP>]]
# [[[0, 0 ... 0, (:) ... (:) 0, 0 ... 0], ... [0, 0 ... 0, (:) ... (:) 0, 0 ... 0]], ... [[0, 0 ... 0, (:) ... (:) 0, 0 ... 0], ... [0, 0 ... 0, (:) ... (:) 0, 0 ... 0]]]
#	-> [create_score(<P>, <data_music1>), ... create_score(<P>, <data_musicN>)]
def create_scores(number_parts, music_parts_data_list):
	scores_list = []
	for i in range(len(music_parts_data_list)):
		score = create_score(number_parts, music_parts_data_list[i])
		scores_list.append(score)
	return(scores_list)

# create score from encoded data
# Ex: create_score(<P>, [<data_part1>, ... <data_partP>])
# [[0, 0 ... 0, (:) ... (:) 0, 0 ... 0], ... [0, 0 ... 0, (:) ... (:) 0, 0 ... 0]]
#	-> score (stream)
def create_score(number_parts, parts_data_list):
	score = stream.Score(id = 'Generated score')
	for i in range(number_parts):
		part = stream.Part(id = 'Part ' + str(i))
		score.insert(0, part)
		create_score_part(i, part, parts_data_list[i])
	return(score)

# create part of a score
# Ex: create_score_part(0, <stream>, <data>)
# [0, 0 ... 0, (:) ... (:) 0, 0 ... 0]
#	-> None
def create_score_part(index_part, part, part_data):
	one_hot_size = config.one_hot_size_v[index_part]
	lowest_note_pitch = config.lowest_note_pitch_v[index_part]
	current_index = nothing_before_index
	current_time_steps = 0
	offset_time_steps = 0
	for t in range(int(len(part_data) / one_hot_size)):
			# TODO: check consistency in data length and total (or/and maximum) number of times steps
		index = decode_one_hot(part_data[(t * one_hot_size):((t + 1) * one_hot_size)])
		if index == hold_index:
			if current_index == nothing_before_index:
#				raise ValueError('Hold index after nothing')
				print('Warning: Hold index after nothing in part ' + str(index_part + 1) + '.')
			current_time_steps = current_time_steps + 1
			offset_time_steps = offset_time_steps + 1
		elif index == rest_index:
			pass								# rests are implicitly considered via the offsets of the notes
											# TODO: consider (mark) explicitly rests
		else:
			if current_index == nothing_before_index:
				current_index = index
				current_time_steps = 1
				offset_time_steps = offset_time_steps + 1
			elif current_time_steps > 0:
				current_note = create_note_with(lowest_note_pitch, current_index, current_time_steps, offset_time_steps)
				part.append(current_note)
				current_index = index
				current_time_steps = 1
				offset_time_steps = offset_time_steps + 1
			else:
				raise ValueError('Unexpected decoding state (note index with no accumulated time steps)')
	if current_time_steps > 0:
		current_note = create_note_with(lowest_note_pitch, current_index, current_time_steps, offset_time_steps)
		part.append(current_note)
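
# Hedged round trip sketch (illustrative only - this helper is an assumption, not part of the original code):
# rebuilds a score from encoded data, ex: from the output of encode_music or from a model's predictions
# Ex: example_decode_music(encode_music(corpus_list[0]))
#	-> score (stream)
def example_decode_music(encoded_music):
	# encoded_music is a list of flat one hot sequences, one per part (as produced by encode_music)
	return(create_score(len(encoded_music), encoded_music))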

# create note
# Ex: create_note_with(<music21.pitch.Pitch A4>, 5, time_steps, offset_time_steps)
#	-> <music21.note.Note C>
def create_note_with(lowest_note_pitch, index, time_steps, offset_time_steps):
	new_note = note.Note()
	new_note.pitch = lowest_note_pitch.transpose(index - number_special_elements_one_hot)
	new_note.quarterLength = time_steps * config.time_step_quarter_length
	new_note.offset = float(offset_time_steps * config.time_step_quarter_length)
	return(new_note)
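
# Worked example (illustrative): with lowest pitch A4 and index 5, the note pitch is A4 transposed
# by 5 - number_special_elements_one_hot = 3 semitones, i.e. C5 (consistent with the Ex above)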

# decode one hot into a corresponding index
# Ex: decode_one_hot([0.2, 0.3, 0.7, 0.1 ... 0.2])
#	-> 2
def decode_one_hot(one_hot_encoding):
	if config.is_deterministic:
		return(decode_deterministic_one_hot(one_hot_encoding))
	else:
		return(decode_sample_one_hot(one_hot_encoding))

# decode one hot into a corresponding index
# deterministic interpretation: the index with the highest value (argmax)
def decode_deterministic_one_hot(one_hot_encoding):
	return(np.argmax(one_hot_encoding))

# decode one hot into a corresponding index
# non deterministic interpretation: sampling according to probabilities - warning: one hot vector must be a probability distribution
def decode_sample_one_hot(one_hot_encoding):
	one_hot_encoding = np.asarray(one_hot_encoding).astype('float64')
	try:
		one_hot_encoding = np.random.multinomial(1, one_hot_encoding)
	except ValueError:
		config.multinomial_error_number = config.multinomial_error_number + 1
	config.multinomial_sampling_number = config.multinomial_sampling_number + 1
	return(decode_deterministic_one_hot(one_hot_encoding))
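
# Note (illustrative): decode_sample_one_hot([0.1, 0.2, 0.7]) returns 2 about 70% of the time;
# if the values do not form a valid probability vector (e.g. they sum to more than 1),
# np.random.multinomial raises ValueError, the error is counted and the function
# falls back to the deterministic argmax of the original vector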

# split training dataset into training and test dataset
# Ex : split_training_test(X_train, y_train, 4)
#	-> (X_train, y_train, X_test, y_test)
# Warning: works only if X_train and y_train are structured hierarchically with respect to the list of chorales
# TODO: we should shuffle the examples
def split_training_test(X_train, y_train, train_test_ratio):
	train_test_index = int((len(X_train) * train_test_ratio) / (train_test_ratio + 1))
	return(X_train[0:train_test_index], y_train[0:train_test_index], X_train[train_test_index:], y_train[train_test_index:])
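
# Worked example (illustrative): with 100 examples and train_test_ratio = 4,
# train_test_index = int(100 * 4 / 5) = 80, i.e. an 80 / 20 train / test split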
