# MNIST handwritten digit Autoencoding
# Copied with small adaptation from Keras example, Francois Chollet, 2016
# Attribute arithmetic
# Jean-Pierre Briot
# 06/05/2019

import numpy as np
import random
from statistics import mean

from keras.datasets import mnist

import matplotlib.pyplot as plt

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model

from keras import backend as K

# Dataset: MNIST handwritten digits (60k train / 10k test, 28x28 grayscale)
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Flatten each 28x28 image into a 784-pixel vector and scale the
# grayscale range (0, 255) down to [0, 1] to help training.
X_train = X_train.reshape(60000, 784).astype('float32') / 255
X_test = X_test.reshape(10000, 784).astype('float32') / 255

# Network hyper-parameters
input_size = 28 * 28        # flattened image length (784)
input_shape = (input_size, )
intermediate_dim = 512      # width of the single hidden layer
batch_size = 20
latent_dim = 2              # 2-D latent space so it can be plotted directly
epochs = 10

# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling(args):
	"""Reparameterization trick: sample from an isotropic unit Gaussian.

	# Arguments
	args (tensor): tuple (z_mean, z_log_var) — mean and log variance of Q(z|X)

	# Returns
	z (tensor): sampled latent vector z_mean + sigma * eps
	"""

	mu, log_var = args
	n_rows = K.shape(mu)[0]         # dynamic batch size
	n_cols = K.int_shape(mu)[1]     # static latent dimension
	# random_normal defaults to mean=0, std=1.0
	eps = K.random_normal(shape=(n_rows, n_cols))
	# exp(0.5 * log_var) turns the log variance into a standard deviation
	return mu + K.exp(0.5 * log_var) * eps

# VAE model = encoder + decoder
# Encoder: 784 -> 512 (relu) -> two latent_dim heads (mean and log variance)
inputs = Input(shape=input_shape, name='encoder_input')
hidden = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(hidden)
z_log_var = Dense(latent_dim, name='z_log_var')(hidden)

# Reparameterization trick pushes the sampling out as an extra layer;
# "output_shape" isn't necessary with the TensorFlow backend.
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model: exposes mean, log variance and a sampled z
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

# Decoder: latent_dim -> 512 (relu) -> 784 sigmoid pixel intensities
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
decoder_hidden = Dense(intermediate_dim, activation='relu')(latent_inputs)
decoder_outputs = Dense(input_size, activation='sigmoid')(decoder_hidden)

# instantiate decoder model
decoder = Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

# instantiate VAE model: encode, take the sampled z (index 2), decode
outputs = decoder(encoder(inputs)[2])

vae = Model(inputs, outputs, name='vae_mlp')

models = (encoder, decoder)
data = (X_test, y_test)

# Reconstruction term: per-pixel MSE, rescaled by the number of pixels
reconstruction_loss = mse(inputs, outputs)
reconstruction_loss *= input_size

# KL divergence between Q(z|X) and the unit Gaussian prior,
# summed over the latent dimensions
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
			axis = -1)
vae_loss = K.mean(reconstruction_loss + kl_loss)
# Loss is attached via add_loss, so compile() needs no loss argument
vae.add_loss(vae_loss)

vae.compile(optimizer = 'adam')

vae.summary()

# Training

print('Training the model/network')

# No y targets: the loss was attached with add_loss, so only inputs are fed
vae.fit(X_train,
	epochs = epochs,
	batch_size = batch_size,
	validation_data = (X_test, None))
#	vae.save_weights('vae_mlp_mnist.h5')

print('Model trained')

# Scatter the test set in the 2-D latent space, colored by digit class
z_mean_test, _, _ = encoder.predict(X_test, batch_size = batch_size)

plt.figure(figsize = (8, 8))
plt.scatter(z_mean_test[:, 0], z_mean_test[:, 1], c = y_test)
plt.colorbar()
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()

# Display an n x n grid of digits decoded from a regular grid in latent space
n = 30
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# linearly spaced coordinates corresponding to the 2D plot of digit classes
# in the latent space (y reversed so the figure reads top-down)
grid_x = np.linspace(-4, 4, n)
grid_y = np.linspace(-4, 4, n)[::-1]

for row, yi in enumerate(grid_y):
	for col, xi in enumerate(grid_x):
		decoded = decoder.predict(np.array([[xi, yi]]))
		tile = decoded[0].reshape(digit_size, digit_size)
		figure[row * digit_size: (row + 1) * digit_size,
			col * digit_size: (col + 1) * digit_size] = tile

plt.figure(figsize = (8, 8))
# Put one tick at the center of each tile, labeled with its latent coordinate
start_range = digit_size // 2
end_range = n * digit_size + start_range + 1
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap='Greys_r')
plt.show()

# Take a small sample of training digits for the attribute-arithmetic experiment
number_elements = 100

X_elements = X_train[:number_elements]
y_elements = y_train[:number_elements]

# Digit classes considered visually "round" vs "angular"
# (2 and 5 belong to neither group and are simply skipped)
round_numbers = [3, 6, 8, 9]
angular_numbers = [1, 4, 7]

# Partition the sampled images by label group. The two label lists are
# disjoint, so these comprehensions match the original if/elif loop exactly.
round_elements = [x for x, y in zip(X_elements, y_elements) if y in round_numbers]
angular_elements = [x for x, y in zip(X_elements, y_elements) if y in angular_numbers]

# Latent codes (sampled z, index 2 of the encoder outputs) for each group
_, _, z_round_elements = encoder.predict(np.array(round_elements))

_, _, z_angular_elements = encoder.predict(np.array(angular_elements))

# Split each 2-D latent code into its two coordinates, per group.
# (Comprehensions replace the original enumerate loops with unused indices.)
z1_round_elements = [element[0] for element in z_round_elements]
z2_round_elements = [element[1] for element in z_round_elements]
z1_angular_elements = [element[0] for element in z_angular_elements]
z2_angular_elements = [element[1] for element in z_angular_elements]

# Local mean helper (the original author avoided statistics.mean, citing
# division problems on their Python version)
def mean_dm(list_values):
	"""Return the arithmetic mean of list_values as a float.

	Raises ZeroDivisionError on an empty list, like the original loop version.
	"""
	# Built-in sum replaces the original manual accumulation loop
	return float(sum(list_values) / len(list_values))

# Centroid (per-coordinate mean latent position) of each digit group
z1_mean_round_elements, z2_mean_round_elements = (
	mean_dm(z1_round_elements), mean_dm(z2_round_elements))
z1_mean_angular_elements, z2_mean_angular_elements = (
	mean_dm(z1_angular_elements), mean_dm(z2_angular_elements))

def round_into_angular(z):
	"""Shift latent code z by the angular-group centroid and decode the result.

	NOTE(review): this adds the angular mean without first subtracting the
	round mean — presumably intentional for this experiment; verify.
	"""
	shifted = [z[0] + z1_mean_angular_elements,
		z[1] + z2_mean_angular_elements]
	return decoder.predict(np.array([shifted]))[0]

def angular_into_round(z):
	"""Shift latent code z by the round-group centroid and decode the result.

	NOTE(review): this adds the round mean without first subtracting the
	angular mean — presumably intentional for this experiment; verify.
	"""
	shifted = [z[0] + z1_mean_round_elements,
		z[1] + z2_mean_round_elements]
	return decoder.predict(np.array([shifted]))[0]

def _show_digit(image_vector):
	"""Display a flattened 784-pixel digit as a 28x28 image."""
	fig, ax = plt.subplots()
	ax.imshow(image_vector.reshape(28, 28))
	plt.show()

print('Original round digit')

ex1_round_digit = round_elements[3]
_show_digit(ex1_round_digit)

print('Original round digit made angular')

# Bug fix: use the latent code of the SAME element displayed above (index 3).
# The original used z_round_elements[0], so the "made angular" image was
# derived from a different digit than the one shown as the original.
ex1_z_round = z_round_elements[3]

ex1_round_angularized_digit = round_into_angular(ex1_z_round)
_show_digit(ex1_round_angularized_digit)

print('Original angular digit')

ex1_angular_digit = angular_elements[0]
_show_digit(ex1_angular_digit)

print('Original angular digit made round')

ex1_z_angular = z_angular_elements[0]

ex1_angular_rounded_digit = angular_into_round(ex1_z_angular)
_show_digit(ex1_angular_rounded_digit)
