# MNIST handwritten digit classification
# Inspired by existing Keras example code
# Simplified version (without convolutions)
# Adapted by Jean-Pierre Briot
# 28/08/2018

import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.utils import np_utils

import matplotlib.pyplot as plt

# Dataset: MNIST images (X_*) together with their integer digit labels (y_*).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Flatten each image into one row of pixels (28x28 = 784 values for MNIST).
# The row count is taken from the arrays rather than hard-coded (60000/10000),
# so this keeps working if the data is subsampled, e.g. for a quick test run.
X_train = X_train.reshape(X_train.shape[0], -1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], -1).astype('float32')

# Scale grayscale intensities from [0, 255] down to [0, 1] to help training.
X_train /= 255
X_test /= 255

# One class per distinct label value seen in the training set.
number_classes = np.unique(y_train).size

# Convert integer labels into one-hot rows of width number_classes, the target
# format expected by the categorical cross-entropy loss used below.
Y_train, Y_test = (
    np_utils.to_categorical(labels, number_classes)
    for labels in (y_train, y_test)
)

# Fully connected network: two hidden layers of 512 ReLU units, each followed
# by 20% dropout, then a softmax output with one unit per class.
model = Sequential([
    Dense(512, input_shape=(784,)),
    Activation('relu'),
    Dropout(0.2),
    Dense(512),
    Activation('relu'),
    Dropout(0.2),
    Dense(number_classes),
    Activation('softmax'),
])

# Configure training: Adam optimiser, categorical cross-entropy on the one-hot
# targets, and accuracy reported as an additional metric.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train for 10 epochs in mini-batches of 128 samples; the returned History
# object keeps the per-epoch metrics that are plotted at the end of the script.
# NOTE(review): the test split is passed as validation data, so the 'val_*'
# curves are measured on the test set rather than a held-out validation split.
history = model.fit(
    X_train, Y_train,
    validation_data=(X_test, Y_test),
    epochs=10,
    batch_size=128,
    verbose=2,
)

# Score the trained model on both splits. With one compiled metric, each result
# is indexed as [loss, accuracy].
train_loss_and_metrics, test_loss_and_metrics = (
    model.evaluate(features, targets, verbose=2)
    for features, targets in ((X_train, Y_train), (X_test, Y_test))
)

for label, value in (
    ("Train Loss", train_loss_and_metrics[0]),
    ("Train Accuracy", train_loss_and_metrics[1]),
    ("Test Loss", test_loss_and_metrics[0]),
    ("Test Accuracy", test_loss_and_metrics[1]),
):
    print(label, value)

# Predicted class for each test sample: the index of the highest softmax
# probability. Sequential.predict_classes() no longer exists in current Keras
# (removed in TensorFlow 2.6), so take the argmax of predict() explicitly; for
# a multi-class softmax output this is exactly what predict_classes returned.
# verbose=0 matches predict_classes' default (no progress bar).
predicted_classes = np.argmax(model.predict(X_test, verbose=0), axis=1)

# Split test indices into correctly and incorrectly classified samples using a
# single boolean mask.
is_correct = predicted_classes == y_test
correct_indices = np.flatnonzero(is_correct)
incorrect_indices = np.flatnonzero(~is_correct)
print()
# Format explicitly: print(n, " text") would insert a second space.
print('{} classified correctly'.format(len(correct_indices)))
print('{} classified incorrectly'.format(len(incorrect_indices)))

def _history_series(metrics, *candidate_keys):
    """Return the first per-epoch series in *metrics* stored under one of *candidate_keys*.

    Older Keras releases record accuracy as 'acc'/'val_acc', while newer ones
    (2.3+) use 'accuracy'/'val_accuracy'. Trying each spelling keeps the plots
    working on either version.

    Raises:
        KeyError: if none of the candidate keys is present.
    """
    for key in candidate_keys:
        if key in metrics:
            return metrics[key]
    raise KeyError('none of {} found in training history (available: {})'.format(
        candidate_keys, sorted(metrics)))


def _plot_train_vs_test(train_series, test_series, title, ylabel):
    """Plot per-epoch training and test curves on one figure and display it."""
    plt.plot(train_series)
    plt.plot(test_series)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()


# plot accuracy
_plot_train_vs_test(_history_series(history.history, 'acc', 'accuracy'),
                    _history_series(history.history, 'val_acc', 'val_accuracy'),
                    'model accuracy', 'accuracy')

# plot loss
_plot_train_vs_test(history.history['loss'],
                    history.history['val_loss'],
                    'model loss', 'loss')
