# Bach Chorale example
# Feedforward
# Jean-Pierre Briot
# 05/04/2019

from keras.models import Sequential
from keras.layers import Dense, Activation

from representation import *
from metrics import *

# Verbosity flags consumed by the analysis / training / generation code.
config.deep_music_analysis_verbose = False	# analysis verbose flag - default value = False
config.deep_music_training_verbose = True	# training verbose flag - default value = False
config.deep_music_generate_verbose = True	# original flag name, kept for backward compatibility
# BUG FIX: the generation functions below test config.deep_music_generation_verbose,
# but only config.deep_music_generate_verbose was being set, so the generation
# verbose flag never took effect - TODO confirm against the config module.
config.deep_music_generation_verbose = True	# generation verbose flag - default value = False

# Experiment parameters.
corpus_type = 1								# 1: bwv344.. chorales, otherwise: first N of the full corpus
number_chorales = 80						# number of chorales loaded (max 80 for type 1, 371 otherwise)
is_transposed = False						# transpose/extend the corpus into all keys
is_aligned = False							# align (transpose) the whole corpus onto alignement_key
number_epochs = 50							# training epochs
number_chorales_regenerated = 2				# how many train/test chorales are written and regenerated
alignement_key = key.Key('C')				# target key used when is_aligned is True

# File-name tag recording which corpus preprocessing was applied.
if is_transposed:
	transposed_name_extension = 'tr_'
elif is_aligned:
	transposed_name_extension = 'al_'
else:
	transposed_name_extension = ''

# Load the chorale corpus, optionally transpose/align it, then analyze it.
print('Loading and parsing the corpus (' + str(number_chorales) + ' Bach chorales)')

if corpus_type == 1:
	# Chorales bwv344 .. bwv(344 + number_chorales - 1); max: 80.
	corpus_names_list = ['bach/bwv' + str(344 + n) for n in range(number_chorales)]
	corpus_list = load_corpus(corpus_names_list)
else:		# corpus_type = 2
	corpus_list = load_n_chorales(number_chorales)	# max = 371

if is_transposed:
	print('Transpose/Extend the corpus (Bach chorales) into all keys')
	corpus_list = transpose_corpus_into_all_keys(corpus_list)

if is_aligned:
	print('Align corpus onto ' + str(alignement_key) + ' key')
	corpus_list = align_corpus(corpus_list, alignement_key)

print('Analyze the corpus (Bach chorales)')

analyze_corpus(corpus_list)

print('Durations (in quarter length): ' + str(music_duration_quarter_length_v))

# NOTE(review): despite its name this is the SHORTEST duration in the corpus;
# presumably every chorale is capped to it so all encodings share one length
# - TODO confirm intent.
max_number_quarters = min(music_duration_quarter_length_v)

config.max_number_time_steps = int(max_number_quarters / config.time_step_quarter_length)

print('Max number quarters (in quarter length): ' + str(max_number_quarters))
print('Max number time steps: ' + str(config.max_number_time_steps))

# Parse the Brazilian anthem MIDI and keep its first (soprano) part, then
# recompute the soprano pitch bounds and one-hot encode the chorale corpus.
print('Parse Brazilian hymn')

print('You have to download file National_Anthems_-_Brazil.mid and to place it within a folder named mid/')

br_score = converter.parse('mid/National_Anthems_-_Brazil.mid')
br_soprano_part = br_score.getElementsByClass('Part')[0]	# soprano = first part

print('Reanalyze/compute lowest and highest pitches of soprano part')

analyze_soprano_part(br_soprano_part)

print('Encode the corpus (Bach chorales)')

corpus_data = encode_corpus(corpus_list)	# structure: hierarchy: music/part/encoded_data

# Build the datasets: each chorale's soprano encoding is a network input (X),
# the concatenation of its alto + tenor + bass encodings is the target (y).
print('Construction of the training and validation (test) datasets')

X_train = []
y_train = []

for music_data in corpus_data:
	soprano_encoding = music_data[0]
	counterpoint_encoding = music_data[1] + music_data[2] + music_data[3]
	X_train.append(soprano_encoding)
	y_train.append(counterpoint_encoding)

# X_train: one flat vector per chorale (soprano one-hots, all time steps
#          concatenated).
# y_train: one flat vector per chorale (alto, then tenor, then bass one-hots,
#          each with all time steps concatenated).

# 1 chorale in 4 held out for validation.
X_train, y_train, X_test, y_test = split_training_test(X_train, y_train, 4)

# numpy representations

X_train = np.array(X_train)
y_train = np.array(y_train)
X_test = np.array(X_test)
y_test = np.array(y_test)

print('X_train shape: ' + str(X_train.shape))		# e.g. (64, 3520)
print('y_train shape: ' + str(y_train.shape))		# e.g. (64, 11520)
print('X_test shape: ' + str(X_test.shape))		# e.g. (16, 3520) - original comment said 11520, inconsistent with input size
print('y_test shape: ' + str(y_test.shape))		# e.g. (16, 11520)

# Define the feedforward architecture: soprano encoding in, ATB encodings out.

# Input size = soprano one-hot size * number of time steps.
input_size = config.max_number_time_steps * config.one_hot_size_v[0]

# Output size = (alto + tenor + bass one-hot sizes) * number of time steps.
output_size = config.max_number_time_steps * (config.one_hot_size_v[1] + config.one_hot_size_v[2] + config.one_hot_size_v[3])

print('Input size: ' + str(input_size))
print('Output size: ' + str(output_size))

hidden_layer_1_size = 500
hidden_layer_2_size = 500

print('Define and create the model/network')

# Two sigmoid hidden layers; sigmoid output matches the multi-label
# (binary cross-entropy) one-hot targets.
model = Sequential([
	Dense(hidden_layer_1_size, input_shape = (input_size, )),
	Activation('sigmoid'),
	Dense(hidden_layer_2_size),
	Activation('sigmoid'),
	Dense(output_size),
	Activation('sigmoid')])

model.compile(loss = 'binary_crossentropy',
	optimizer = 'rmsprop',
	metrics = ['accuracy'])

# Training

print('Training the model/network')

# Validation loss/accuracy are tracked on the held-out test split.
history = model.fit(X_train, y_train,
	batch_size = 30,
	epochs = number_epochs,
	verbose = config.deep_music_training_verbose,
	validation_data = (X_test, y_test))

print('Model trained')

if config.deep_music_training_verbose:
	print('Show metrics')
	show_metrics(model, history, X_train, y_train, X_test, y_test)

# create scores of chorales from a list of soprano parts and associated data
# new definition - more refactoring - but generates non sense score
# TODO
# Ex: create_chorales_counterpoints_from_soprano_data([data1, ... dataN], [sop_part1, ... sop_partN])
#	-> [score1, ... scoreN]
# NOTE(review): this definition is DEAD CODE - it is immediately shadowed by the
# re-definition of the same name below; kept only as a refactoring draft until
# the "non sense score" issue above is resolved.
def create_chorales_counterpoints_from_soprano_data(soprano_data_list, soprano_parts_list):
	# Predict all counterpoints in one batch, then slice each flat prediction
	# vector into its alto / tenor / bass segments and build a 3-part score.
	if config.deep_music_generation_verbose:
		print('Create counterpoints of ' + str(len(soprano_parts_list)) + ' soprano parts')
	counterpoint_data_list = model.predict(soprano_data_list)
	list_scores = []
	for i in range(len(soprano_data_list)):
		counterpoint_data = counterpoint_data_list[i]
		# End indices of each voice's segment within the flat prediction vector.
		alto_data_end_index = config.max_number_time_steps * config.one_hot_size_v[1]
		tenor_data_end_index = alto_data_end_index + (config.max_number_time_steps * config.one_hot_size_v[2])
		bass_data_end_index = tenor_data_end_index + (config.max_number_time_steps * config.one_hot_size_v[3])
		score = create_score(3, [counterpoint_data[0:alto_data_end_index],
							counterpoint_data[alto_data_end_index:tenor_data_end_index],
							counterpoint_data[tenor_data_end_index:bass_data_end_index]])
#		score.insert(0, soprano_parts_list[i])
		list_scores.append(score)
	return(list_scores)

# previous definition - less refactoring - but works
def create_chorales_counterpoints_from_soprano_data(soprano_data_list, soprano_parts_list):
	"""Predict ATB counterpoints for each encoded soprano and build 4-part scores.

	soprano_data_list: array of encoded soprano parts (network inputs).
	soprano_parts_list: matching music21 soprano parts, inserted unchanged.
	Returns a list of music21 scores (soprano + generated alto/tenor/bass).
	"""
	if config.deep_music_generation_verbose:
		print('Create counterpoints of ' + str(len(soprano_parts_list)) + ' soprano parts')
	predictions = model.predict(soprano_data_list)
	# Segment boundaries of each voice inside a flat prediction vector
	# (loop-invariant, computed once).
	alto_end = config.max_number_time_steps * config.one_hot_size_v[1]
	tenor_end = alto_end + (config.max_number_time_steps * config.one_hot_size_v[2])
	bass_end = tenor_end + (config.max_number_time_steps * config.one_hot_size_v[3])
	list_scores = []
	for i in range(len(soprano_data_list)):
		prediction = predictions[i]
		score = stream.Score(id = 'Deep Chorale ' + str(i))
		alto_part = stream.Part(id = 'Alto')
		tenor_part = stream.Part(id = 'Tenor')
		bass_part = stream.Part(id = 'Bass')
		# Insert soprano first so the part order in the score is S, A, T, B.
		for part in (soprano_parts_list[i], alto_part, tenor_part, bass_part):
			score.insert(0, part)
		# Decode each voice's segment of the prediction into its part.
		create_score_part(1, alto_part, prediction[0:alto_end])
		create_score_part(2, tenor_part, prediction[alto_end:tenor_end])
		create_score_part(3, bass_part, prediction[tenor_end:bass_end])
		list_scores.append(score)
	return list_scores

# create scores of chorales from a list of soprano parts
# Ex: create_chorales_counterpoints_from_soprano_parts([sop_part1, ... sop_partN])
#	-> [score1, ... scoreN]
def create_chorales_counterpoints_from_soprano_parts(soprano_parts_list):
	"""Encode each soprano part, then delegate to the data-based generator."""
	encoded_sopranos = np.array([encode_part(0, part) for part in soprano_parts_list])
	return create_chorales_counterpoints_from_soprano_data(encoded_sopranos, soprano_parts_list)

print('Write ' + str(number_chorales_regenerated) + ' first (training) and ' + str(number_chorales_regenerated) + ' last (test) chorales')

# Warning/TODO: we assume that first are training and last are test - does not hold if shuffling when training/test split

train_list = corpus_list[0:number_chorales_regenerated]
test_list = corpus_list[- number_chorales_regenerated:]

# Common prefix of every MIDI file written below (tags corpus type and preprocessing).
file_prefix = 'mid/forward_ch_' + str(corpus_type) + '_' + str(transposed_name_extension)

# Write the selected original chorales as MIDI references.
for i, train_music in enumerate(train_list):
	train_music.write('midi', file_prefix + 'train_' + str(i + 1) + '.mid')

for i, test_music in enumerate(test_list):
	test_music.write('midi', file_prefix + 'test_' + str(i + 1) + '.mid')

print('Regenerate chorales counterpoints from soprano melodies - From training dataset and from test dataset')

# Collect the soprano (first) part of each selected chorale.
corpus_soprano_parts_train_list = [music.getElementsByClass('Part')[0] for music in train_list]
corpus_soprano_parts_test_list = [music.getElementsByClass('Part')[0] for music in test_list]

regenerated_chorales_train_scores_list = create_chorales_counterpoints_from_soprano_data(X_train[0:number_chorales_regenerated], corpus_soprano_parts_train_list)

regenerated_chorales_test_scores_list = create_chorales_counterpoints_from_soprano_data(X_test[- number_chorales_regenerated:], corpus_soprano_parts_test_list)

print('Write all regenerated chorales')

for i, regenerated_score in enumerate(regenerated_chorales_train_scores_list):
	regenerated_score.write('midi', file_prefix + 'train_' + str(i + 1) + '_re.mid')

for i, regenerated_score in enumerate(regenerated_chorales_test_scores_list):
	regenerated_score.write('midi', file_prefix + 'test_' + str(i + 1) + '_re.mid')

# Generate a chorale counterpoint for the Brazilian anthem's soprano line,
# first as-is, then slowed down with its introduction removed.

print('Generate chorale counterpoint from Brazilian hymn original soprano voice')

br_chorale_score = create_chorales_counterpoints_from_soprano_parts([br_soprano_part])[0]

print('Write chorale generated from Brazilian hymn original soprano voice')

br_chorale_score.write('midi', 'mid/forward_ch_br.mid')

print('Generate chorale counterpoint from Brazilian hymn original soprano voice slower and removing intro')

# Halve the tempo by doubling every offset.
slower_br_soprano_part = br_soprano_part.scaleOffsets(2)

# Remove the introduction: drop every element past the first 3 (presumably
# clef/key/time signature - TODO confirm) whose offset is below 18.
# FIX: both loops used enumerate() with an unused `dummy` index; direct
# iteration is equivalent. Iterating the [3:] slice (a copy, not the live
# stream) keeps removal during iteration safe.
for element in slower_br_soprano_part[3:]:
	if element.offset >= 18:
		break
	slower_br_soprano_part.remove(element, shiftOffsets = True)

# Shift the remaining elements back toward the start by 8.0 quarter lengths
# (NOTE(review): offsets have already been shifted by the removals above -
# confirm the 8.0 constant against the anthem's actual intro length).
for element in slower_br_soprano_part[3:]:
	element.setOffsetBySite(slower_br_soprano_part, element.offset - 8.0)

slower_br_chorale_score = create_chorales_counterpoints_from_soprano_parts([slower_br_soprano_part])[0]

print('Write chorale generated from Brazilian hymn original soprano voice slower and removing intro')

slower_br_chorale_score.write('midi', 'mid/forward_ch_br_2.mid')
