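"""A small convolutional-network MNIST classifier.

Two convolution + max-pool layers, a fully connected layer with dropout and a
softmax readout, trained with Adam (the "Deep MNIST for Experts" model).
See also the beginner MNIST tutorial at
http://tensorflow.org/tutorials/mnist/beginners/index.md
"""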
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

### Weight-parameter initialization ###
def weight_variable(shape, stddev=0.1):
  """Create a trainable weight tensor with small random initial values.

  Values come from `tf.truncated_normal` (mean 0), so units start out
  different from each other instead of identical.

  Args:
    shape: shape of the variable, e.g. [5, 5, 1, 32] for a conv kernel or
      [1024, 10] for a fully connected layer.
    stddev: standard deviation of the truncated normal. Defaults to 0.1,
      the value this script has always used.

  Returns:
    A `tf.Variable` initialized with the sampled values.
  """
  initial = tf.truncated_normal(shape, stddev=stddev)
  return tf.Variable(initial)

def bias_variable(shape, value=0.1):
  """Create a trainable float32 bias tensor filled with a constant.

  With ReLU activations it is good practice to start biases slightly
  positive, so units begin in the active region and are less likely to
  become "dead" neurons.

  Args:
    shape: shape of the bias, e.g. [32] for a layer with 32 filters.
    value: initial value of every element. Defaults to 0.1.

  Returns:
    A `tf.Variable` of the given shape.
  """
  # Pin the dtype: an int `value` would otherwise yield an int32 bias that
  # cannot be added to the float32 layer outputs.
  initial = tf.constant(value, dtype=tf.float32, shape=shape)
  return tf.Variable(initial)

### Convolution and pooling helpers ###

def conv2d(x, W):
  """Convolve `x` with the filter bank `W`, stride 1, 'SAME' padding.

  Shapes:
    x: [batch_size, rows, cols, channels]
    W: [patch_rows, patch_cols, input_channels, output_channels], where
       output_channels is the number of filters (features) to learn.

  `strides` is the sliding-window step for each input dimension
  [batch, row, col, channel]; the most common form is [1, s, s, 1], and here
  every step is 1. With 'SAME' padding the input is zero-padded so that
    out_height = ceil(in_height / strides[1])
    out_width  = ceil(in_width / strides[2])
  so with stride 1 the output has the same spatial size as the input.
  """
  step_per_dim = [1, 1, 1, 1]
  return tf.nn.conv2d(x, W, strides=step_per_dim, padding='SAME')

def max_pool_2x2(x):
  """Downsample `x` by taking the maximum over 2x2 windows with stride 2.

  `x` is [batch, rows, cols, channels]. Both `ksize` (window size) and
  `strides` (window step) are [1, 2, 2, 1]: a 2x2 window over rows and
  columns moved 2 positions at a time, and size/step 1 over batch and
  channels. This [1, k, k, 1] form is the usual one.

  Along a single dimension:
    output[i] = max(value[strides * i : strides * i + ksize])

  With 'SAME' padding an M x N input becomes ceil(M/2) x ceil(N/2), i.e.
  exactly M/2 x N/2 when M and N are even (28x28 -> 14x14 -> 7x7).
  """
  window = [1, 2, 2, 1]
  return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')


def _build_model():
  """Build the convolutional network graph and return its inputs and logits.

  Returns:
    A tuple (x, keep_prob, y_conv):
      x: float32 placeholder of shape [None, 784] for flattened 28x28
        images; None lets any number of examples be fed at once.
      keep_prob: float32 placeholder for the dropout keep probability.
      y_conv: raw (unnormalized) logits for the 10 classes, [None, 10].
  """
  x = tf.placeholder(tf.float32, [None, 784])

  # Convolutions need a 4-D tensor [batch, row, col, channel], so reshape the
  # flat 784-vectors back into single-channel 28x28 images.
  x_image = tf.reshape(x, [-1, 28, 28, 1])

  # First conv layer: 32 filters, each a 5x5 window over 1 input channel,
  # so W has shape [5, 5, 1, 32]. Pooling reduces 28x28 to 14x14.
  W_conv1 = weight_variable([5, 5, 1, 32])
  b_conv1 = bias_variable([32])
  h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  h_pool1 = max_pool_2x2(h_conv1)

  # Second conv layer: 64 filters, each 5x5 over the 32 channels produced by
  # the first layer, so W has shape [5, 5, 32, 64]. Pooling: 14x14 -> 7x7.
  W_conv2 = weight_variable([5, 5, 32, 64])
  b_conv2 = bias_variable([64])
  h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
  h_pool2 = max_pool_2x2(h_conv2)

  # Fully connected layer: flatten each example's 7x7x64 feature maps back
  # into a vector and map it to 1024 ReLU units.
  W_fc1 = weight_variable([7 * 7 * 64, 1024])
  b_fc1 = bias_variable([1024])
  h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

  # Dropout before the readout layer to reduce overfitting.
  keep_prob = tf.placeholder(tf.float32)
  h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

  # Readout layer producing the raw logits.
  W_fc2 = weight_variable([1024, 10])
  b_fc2 = bias_variable([10])
  y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
  return x, keep_prob, y_conv


def _evaluate_accuracy(sess, accuracy, x, y_, keep_prob, images, labels,
                       batch_size):
  """Return the accuracy over all of `images`/`labels`, fed in chunks.

  Chunking bounds how many examples go through the graph per run. Each
  chunk's mean accuracy is weighted by the chunk's size, so the result equals
  the accuracy over the whole set. Dropout is disabled (keep_prob = 1.0).

  Raises:
    ValueError: if `images` is empty.
  """
  total = len(images)
  if total == 0:
    raise ValueError('no examples to evaluate')
  correct = 0.0
  for start in range(0, total, batch_size):
    end = min(start + batch_size, total)
    batch_accuracy = sess.run(accuracy, feed_dict={x: images[start:end],
                                                   y_: labels[start:end],
                                                   keep_prob: 1.0})
    correct += batch_accuracy * (end - start)
  return correct / total


def main(data_dir="MNIST_data", num_steps=2000, batch_size=128,
         learning_rate=1e-4, log_every=200, eval_batch_size=1000):
  """Build, train and evaluate the convolutional MNIST classifier.

  Args:
    data_dir: directory passed to `input_data.read_data_sets`.
    num_steps: number of Adam training steps.
    batch_size: training examples per step.
    learning_rate: Adam learning rate.
    log_every: print the loss on the first 1000 training examples every
      `log_every` steps; 0 disables logging.
    eval_batch_size: chunk size used when evaluating on the full test set.
  """
  # Import data with one-hot labels. Images are flat 28x28 = 784 vectors,
  # e.g. mnist.train.images is [55000, 784], mnist.train.labels [55000, 10].
  mnist = input_data.read_data_sets(data_dir, one_hot=True)

  x, keep_prob, y_conv = _build_model()

  # Placeholder for the one-hot true labels.
  y_ = tf.placeholder(tf.float32, [None, 10])

  # The raw formulation of cross-entropy,
  #
  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
  #                                 reduction_indices=[1]))
  #
  # can be numerically unstable, so apply the fused op to the raw logits and
  # average across the batch. The op must be called with keyword arguments:
  # TF 1.0+ rejects positional calls.
  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
  train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

  correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  # The context manager closes the session even if training raises.
  with tf.Session() as sess:
    sess.run(init)

    for i in range(num_steps):
      batch = mnist.train.next_batch(batch_size)

      if log_every and i % log_every == 0:
        # Monitor progress on a fixed 1000-example training subset with
        # dropout disabled.
        train_loss = sess.run(cross_entropy,
                              feed_dict={x: mnist.train.images[:1000],
                                         y_: mnist.train.labels[:1000],
                                         keep_prob: 1.0})
        print('loss: {}'.format(train_loss))

      sess.run(train_step,
               feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    # Evaluate on the whole test set, not just its first 1000 examples.
    test_accuracy = _evaluate_accuracy(sess, accuracy, x, y_, keep_prob,
                                       mnist.test.images, mnist.test.labels,
                                       eval_batch_size)
    print("test accuracy %g" % test_accuracy)


# Train only when run as a script, not when this module is imported.
if __name__ == '__main__':
  main()


