# Bach Chorale example
# Autoencoder decoding
# Variational code copied and adapted from Keras example (François Chollet), 2016
# Jean-Pierre Briot
# 05/04/2019

import random
import matplotlib.pyplot as plt

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model

from keras import backend as K

from representation import *
from metrics import *

# Verbosity switches of the music framework (`config` comes from the
# `representation` module imported above).
config.deep_music_analysis_verbose = False	# analysis verbose flag - default value = False
config.deep_music_training_verbose = True	# training verbose flag - default value = False
config.deep_music_generate_verbose = True	# generation verbose flag - default value = False

# Experiment parameters.
corpus_type = 1				# 1: explicit BWV list, otherwise first-n chorales
number_chorales = 80
is_transposed = False			# transpose the corpus into all keys?
number_epochs = 50
number_chorales_regenerated = 2					# to be doubled (training and test)

# File-name tag used when the corpus has been transposed into all keys.
transposed_name_extension = 'tr_' if is_transposed else ''

print('Loading and parsing the corpus (' + str(number_chorales) + ' Bach chorales)')

# Two ways to pick the corpus: an explicit list of consecutive BWV names
# starting at bwv344 (corpus_type == 1, max 80), or simply the first
# number_chorales chorales of the library corpus (max 371).
if corpus_type == 1:
	corpus_names_list = ['bach/bwv' + str(344 + offset) for offset in range(number_chorales)]
	corpus_list = load_corpus(corpus_names_list)
else:
	corpus_list = load_n_chorales(number_chorales)	# max = 371

if is_transposed:
	print('Transpose/Extend the corpus (Bach chorales) into all keys')
	corpus_list = transpose_corpus_into_all_keys(corpus_list)

print('Analyze the corpus (Bach chorales)')

# Analysis fills module-level statistics (from `representation`), e.g.
# music_duration_quarter_length_v: the duration of each piece.
analyze_corpus(corpus_list)

print('Durations (in quarter length): ' + str(music_duration_quarter_length_v))

# NOTE(review): despite the "max" name this takes the MINIMUM duration —
# presumably so every piece can be cropped to a common length; confirm.
max_number_quarters = min(music_duration_quarter_length_v)

# Number of discrete time steps per piece at the configured resolution.
config.max_number_time_steps = int(max_number_quarters / config.time_step_quarter_length)

print('Max number quarters (in quarter length): ' + str(max_number_quarters))
print('Max number time steps: ' + str(config.max_number_time_steps))

print('Encode the corpus (Bach chorales)')

# One-hot encode every piece.
corpus_data = encode_corpus(corpus_list)
# structure: hierarchy: music/part/encoded_data
print('Construction of the training and validation (test) datasets')

# Keep only the first part (voice) of each encoded piece.  Each element is
# one flat one-hot vector concatenating all time steps of a chorale.
X_train = [music[0] for music in corpus_data]
y_train = []

# Autoencoder setting: the target IS the input, hence X_train twice.
# Presumably a 4:1 train/test split — confirm against split_training_test.
X_train, y_train, X_test, y_test = split_training_test(X_train, X_train, 4)

# numpy representations; the targets alias the inputs (reconstruction).
X_train = np.array(X_train)
y_train = X_train
X_test = np.array(X_test)
y_test = X_test

print('X_train shape: ' + str(X_train.shape))		# (64, 3520)
print('y_train shape: ' + str(y_train.shape))			# (64, 3520)
print('X_test shape: ' + str(X_test.shape))			# (16, 3520)
print('y_test shape: ' + str(y_test.shape))			# (16, 3520)

# Define the deep learning model/architecture

# Flattened one-hot size of a whole piece: number of time steps times the
# one-hot vocabulary size of the first part (voice).
input_size = config.max_number_time_steps * config.one_hot_size_v[0]

# Autoencoder: the output reconstructs the input, so the sizes match.
output_size = input_size

print('Input size: ' + str(input_size))
print('Output size: ' + str(output_size))

# Input size: 3520 = 22 * 4 * 4 * 10
#				one_hot_size_[v0] * (time_slice / 4) * max_number_quarters
#											(4 * max number measures)
# Output size = Input size = 3520

print('Define and create the autoencoder model/network')

# network parameters
input_shape = (input_size, )
intermediate_dim = 512			# width of the single hidden layer
batch_size = 20
latent_dim = 2				# 2-D latent space (z1, z2), easy to plot/explore
# Consistency fix: reuse the number_epochs constant declared at the top of
# the file (same value, 50) instead of a second independent literal.
epochs = number_epochs

# Reparameterization trick:
# instead of sampling from Q(z|X) directly, sample eps = N(0, I) and set
# z = z_mean + sqrt(var) * eps, which keeps the graph differentiable.
def sampling(args):
	"""Reparameterization trick: sample from an isotropic unit Gaussian.

	# Arguments
	args (tensor): mean and log of variance of Q(z|X)

	# Returns
	z (tensor): sampled latent vector
	"""
	mean, log_var = args
	# One standard-normal draw per batch element;
	# random_normal defaults to mean=0 and std=1.0.
	batch_size_tensor = K.shape(mean)[0]
	latent_dimension = K.int_shape(mean)[1]
	eps = K.random_normal(shape = (batch_size_tensor, latent_dimension))
	sigma = K.exp(0.5 * log_var)
	return mean + sigma * eps

# VAE model = encoder + decoder
# build encoder model

inputs = Input(shape = input_shape,
	name = 'encoder_input')
x = Dense(intermediate_dim,
	activation = 'relu')(inputs)
# Two parallel heads produce the parameters of Q(z|X).
z_mean = Dense(latent_dim,
	name = 'z_mean')(x)
z_log_var = Dense(latent_dim,
	name = 'z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling,
	output_shape = (latent_dim,),
	name = 'z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model(inputs,
	[z_mean, z_log_var, z],
	name = 'encoder')
encoder.summary()

# build decoder model
latent_inputs = Input(shape = (latent_dim,),
	name = 'z_sampling')
x = Dense(intermediate_dim,
	activation = 'relu')(latent_inputs)
# Sigmoid output: one independent probability per one-hot cell.
outputs = Dense(input_size,
	activation = 'sigmoid')(x)

# instantiate decoder model
decoder = Model(latent_inputs,
	outputs,
	name = 'decoder')
decoder.summary()

# instantiate VAE model
# encoder(inputs) returns [z_mean, z_log_var, z]; index 2 is the sample z.
outputs = decoder(encoder(inputs)[2])

vae = Model(inputs,
	outputs,
	name = 'vae_mlp')

models = (encoder, decoder)
data = (X_test, y_test)

# VAE loss = reconstruction term + Kullback-Leibler regularization term.
reconstruction_loss = mse(inputs, outputs)

# mse averages over the feature axis; scale back up by the vector size.
reconstruction_loss *= input_size
# Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I):
# -0.5 * sum(1 + log(var) - mean^2 - var), summed over the latent axis.
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss,
			axis = -1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
# add_loss attaches the custom loss tensor, so compile() needs no loss arg.
vae.add_loss(vae_loss)

vae.compile(optimizer = 'adam')

vae.summary()

# Training

print('Training the model/network')

# No explicit targets: the reconstruction loss was attached via add_loss
# above, so fit() only needs the inputs; likewise the validation target
# is None.
vae.fit(X_train,
	epochs = epochs,
	batch_size = batch_size,
	verbose = config.deep_music_training_verbose,	# NOTE(review): a bool; Keras expects 0/1/2 — True acts as 1
	validation_data = (X_test, None))
#	vae.save_weights('vae_mlp_mnist.h5')

print('Model trained')

def create_melodies_from_labels(labels_list):
	"""Decode latent-space points into melody data structures.

	# Arguments
	labels_list: list of [z1, z2] latent coordinates

	# Returns
	list with one element per label, each wrapping the decoded flat
	vector in a singleton list (the music/part/encoded_data hierarchy)
	"""
	decoded_vectors = decoder.predict(np.array(labels_list))
	# Wrap each decoded flat vector as a one-part music entry.
	return [[vector] for vector in decoded_vectors]

print('Compute the ranges of the variational autoencoder latent variables')

# Encode the whole dataset (training + test) and take the bounding box of
# the sampled latent coordinates z = (z1, z2).
z_mean_data, z_log_var_data, z_data = encoder.predict(np.concatenate([X_train, X_test]))

min_z1, max_z1 = z_data[:, 0].min(), z_data[:, 0].max()
min_z2, max_z2 = z_data[:, 1].min(), z_data[:, 1].max()

print('z1: [' + str(min_z1) + ', ' + str(max_z1) + '] z2: [' + str(min_z2) + ', ' + str(max_z2) + ']')

print('Create specific melodies')

# Latent points at the four corners of the observed range plus the origin.
labels_list_extreme = [[min_z1, min_z2], [min_z1, max_z2], [max_z1, max_z2], [max_z1, min_z2], [0, 0]]

def generate_interpolation(min, max, steps):
	"""Return steps + 1 evenly spaced values from min to max inclusive.

	# Arguments
	min, max: interval bounds (parameter names kept for backward
		compatibility, even though they shadow the builtins)
	steps (int): number of interpolation intervals (> 0); the result
		contains steps + 1 points, including both endpoints

	# Returns
	list of floats: the linear interpolation

	# Raises
	ValueError: if steps is not a positive integer
	"""
	if steps <= 0:
		raise ValueError('steps must be a positive integer')
	# Fix: do not shadow the builtin `list` with a local variable; build
	# the points with a comprehension instead of an append loop.
	return [min + ((max - min) * i) / steps for i in range(steps + 1)]

# 5 evenly spaced latent points along z1 (holding z2 at its maximum)...
list_interpolation_z1 = generate_interpolation(min_z1, max_z1, 4)

labels_list_interpolation_z1 = []

for i in range(len(list_interpolation_z1)):
	labels_list_interpolation_z1.append([list_interpolation_z1[i], max_z2])

# ...and along z2 (holding z1 at its maximum).
list_interpolation_z2 = generate_interpolation(min_z2, max_z2, 4)

labels_list_interpolation_z2 = []

for i in range(len(list_interpolation_z2)):
	labels_list_interpolation_z2.append([max_z1, list_interpolation_z2[i]])

# NOTE(review): the three hard-coded lists below overwrite everything
# computed above (including labels_list_extreme built from the measured
# latent ranges), making that computation dead code — presumably a quick
# experiment with a fixed [-20, 20] grid; confirm which version is intended.
labels_list_extreme = [[-20, -20], [-20, 20], [20, 20], [20, -20], [0, 0]]
labels_list_interpolation_z1 = [[20, 20], [10, 20], [0, 20], [-10, 20], [-20, 20]]
labels_list_interpolation_z2 = [[20, 20], [20, 10], [20, 0], [20, -10], [20, -20]]

# Decode each list of latent points into 1-part scores.
scores_extreme = create_scores(1, create_melodies_from_labels(labels_list_extreme))

scores_interpolation_z1 = create_scores(1, create_melodies_from_labels(labels_list_interpolation_z1))

scores_interpolation_z2 = create_scores(1, create_melodies_from_labels(labels_list_interpolation_z2))

print('Write the scores')

# File names encode the latent corner: -/+ for the min/max of (z1, z2).
scores_extreme[0].write('midi', 'mid/auto_var_mel_-_-.mid')
scores_extreme[1].write('midi', 'mid/auto_var_mel_-_+.mid')
scores_extreme[2].write('midi', 'mid/auto_var_mel_+_+.mid')
scores_extreme[3].write('midi', 'mid/auto_var_mel_+_-.mid')
scores_extreme[4].write('midi', 'mid/auto_var_mel_0_0.mid')

# One MIDI file per interpolation step along each latent axis.
for i in range(len(scores_interpolation_z1)):
	scores_interpolation_z1[i].write('midi', 'mid/auto_var_mel_inter_z1_' + str(i) + '.mid')

for i in range(len(scores_interpolation_z2)):
	scores_interpolation_z2[i].write('midi', 'mid/auto_var_mel_inter_z2_' + str(i) + '.mid')
