import os
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

######
# beberapa fungsi utilitas
######

#catatan:
#kita akan menggunakan Shared Variables seperti yang sudah dijelaskan pada
#https://www.tensorflow.org/how_tos/variable_scope/. Shared variables ini
#digunakan pada Discriminator. Jadi, daripada menggunakan tf.Variable(...)
#kita akan menggunakan tf.get_variable(...)

def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky rectified linear unit, used as the discriminator's activation.

    Instead of zeroing negative inputs, it keeps a small slope `leak` for
    them, so inactive units still pass a small signal.
    `name` is accepted for API symmetry but is currently unused.
    """
    return tf.maximum(leak * x, x)

def conv2D(x, kernel_shape, strides, activation_fn, scope_name):
    """2-D convolution layer built from shared variables.

    Args:
        x: input tensor of shape [batch, rows, cols, in_channels].
        kernel_shape: [patch_rows, patch_cols, in_channels, out_channels].
        strides: usually [1, row_stride, col_stride, 1].
        activation_fn: activation applied to (conv + bias).
        scope_name: variable scope so parameters are named
            "<scope_name>/weights" and "<scope_name>/biases".

    Returns:
        The activated convolution output ('SAME' padding).
    """
    with tf.variable_scope(scope_name):
        # Parameters are created via tf.get_variable so they can be reused
        # (shared) across calls within the same scope.
        weights = tf.get_variable(
            "weights",
            kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        # One bias per output channel/filter.
        biases = tf.get_variable(
            "biases",
            [kernel_shape[-1]],
            initializer=tf.constant_initializer(0.0))

        conv = tf.nn.conv2d(x, weights, strides=strides, padding='SAME')
        return activation_fn(conv + biases)

def conv2D_trans(x, kernel_shape, output_shape, strides, activation_fn, scope_name):
    """Transposed 2-D convolution ("deconvolution") layer with shared variables.

    Args:
        x: input tensor of shape [batch, rows, cols, in_channels].
        kernel_shape: [rows, cols, out_channels, in_channels] — note the
            channel order is REVERSED compared to conv2D.
        output_shape: [batch, rows, cols, out_channels] of the result.
        strides: usually [1, row_stride, col_stride, 1].
        activation_fn: activation applied to (deconv + bias).
        scope_name: variable scope so parameters are named
            "<scope_name>/weights" and "<scope_name>/biases".

    Returns:
        The activated transposed-convolution output ('SAME' padding).
    """
    with tf.variable_scope(scope_name):
        weights = tf.get_variable(
            "weights",
            kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        # One bias per output channel.
        biases = tf.get_variable(
            "biases",
            [output_shape[-1]],
            initializer=tf.constant_initializer(0.0))

        deconv = tf.nn.conv2d_transpose(x,
                                        weights,
                                        output_shape=output_shape,
                                        strides=strides,
                                        padding='SAME')
        return activation_fn(deconv + biases)
 
def dense(x, output_unit, activation_fn, scope_name='dense'):
    """Fully-connected (dense) layer with shared variables.

    Args:
        x: input tensor of shape [batch, input_units].
        output_unit: number of output units (integer).
        activation_fn: activation applied to the raw logits.
        scope_name: variable scope so parameters are named
            "<scope_name>/weights" and "<scope_name>/biases".

    Returns:
        activation_fn(x @ weights + biases), shape [batch, output_unit].
    """
    with tf.variable_scope(scope_name):
        weights = tf.get_variable(
            "weights",
            [x.get_shape()[-1], output_unit],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        biases = tf.get_variable(
            "biases",
            [output_unit],
            initializer=tf.constant_initializer(0.0))

        logits = tf.matmul(x, weights) + biases  # raw logits
        return activation_fn(logits)



############################
# Definisi Generator dan Discriminator
############################

def discriminator(images, reuse=False):
    """Discriminator: probability that `images` come from the real data.

    Architecture: three stride-2 conv layers (channels in -> 16 -> 32 -> 64,
    each halving the spatial size, 4x4 kernels, leaky ReLU), then flatten,
    then a single sigmoid unit: 0 -> fake, 1 -> drawn from the real
    distribution.

    Args:
        images: input batch of shape [batch, rows, cols, channels].
        reuse: if True, reuse the parameters of an ALREADY-built
            discriminator instead of creating fresh ones.

    Returns:
        Tensor of shape [batch, 1] with sigmoid outputs in (0, 1).
    """
    with tf.variable_scope("discriminator") as scope:
        if reuse:
            scope.reuse_variables()

        # Stack of stride-2 conv layers; spatial size halves at each step
        # because of the (2, 2) stride.
        net = images
        for layer_idx, out_channels in enumerate((16, 32, 64), start=1):
            in_channels = net.get_shape().as_list()[-1]
            net = conv2D(net,
                         kernel_shape=[4, 4, in_channels, out_channels],
                         strides=[1, 2, 2, 1],
                         activation_fn=lrelu,
                         scope_name='d_{}_conv2d'.format(layer_idx))

        # Flatten [batch, rows, cols, channels] -> [batch, rows*cols*channels].
        _, rows, cols, channels = net.get_shape().as_list()
        net_flat = tf.reshape(net, [-1, rows * cols * channels])

        # Final dense layer with a single sigmoid output unit.
        return dense(net_flat,
                     output_unit=1,
                     activation_fn=tf.nn.sigmoid,
                     scope_name='d_1_dense')

def generator(z):
    """Generator: map a batch of random vectors to 32x32x1 images.

    Architecture: dense layer expanding z to a 4x4x64 feature map, followed
    by three stride-2 transposed convolutions (64 -> 32 -> 16 -> 8 channels,
    each doubling the spatial size) and a final stride-1 transposed
    convolution down to 1 channel with tanh output.

    Args:
        z: noise tensor of shape [batch, z_dim] (z_dim is 100 in this script).

    Returns:
        Tensor of shape [batch, 32, 32, 1] with values in (-1, 1).
    """
    # Batch size is needed for conv2d_transpose output shapes.
    batch_size = z.get_shape().as_list()[0]

    with tf.variable_scope("generator"):
        # Expand the noise vector to a larger vector, then reshape it into a
        # small image so it can be fed through 2-D transposed convolutions:
        # [batch, z_dim] -> [batch, 4*4*64] -> [batch, 4, 4, 64].
        expanded = dense(z,
                         output_unit=4 * 4 * 64,
                         activation_fn=tf.nn.relu,
                         scope_name='g_1_dense')
        net = tf.reshape(expanded, [-1, 4, 4, 64])

        # Per-layer specs: (kernel [r, c, out_ch, in_ch], output shape,
        # strides, activation, scope). 5x5 kernels throughout; stride (2, 2)
        # doubles the spatial size, the last layer keeps it at 32x32.
        deconv_specs = (
            ([5, 5, 32, 64], [batch_size, 8, 8, 32], [1, 2, 2, 1], tf.nn.relu, 'g_1_deconv2d'),
            ([5, 5, 16, 32], [batch_size, 16, 16, 16], [1, 2, 2, 1], tf.nn.relu, 'g_2_deconv2d'),
            ([5, 5, 8, 16], [batch_size, 32, 32, 8], [1, 2, 2, 1], tf.nn.relu, 'g_3_deconv2d'),
            ([5, 5, 1, 8], [batch_size, 32, 32, 1], [1, 1, 1, 1], tf.nn.tanh, 'g_4_deconv2d'),
        )
        for kernel, out_shape, strides, act_fn, scope_name in deconv_specs:
            net = conv2D_trans(net,
                               kernel_shape=kernel,
                               output_shape=out_shape,
                               strides=strides,
                               activation_fn=act_fn,
                               scope_name=scope_name)

        return net


######################
# Definisi Objective Function & Optimizer-nya
# Beberapa placeholder untuk input
######################

z_size = 100      # dimensionality of the generator's noise vector z
batch_size = 128  # samples per training step

# Placeholders feeding the generator (noise) and the discriminator (reals).
z = tf.placeholder(shape=[batch_size, z_size], dtype=tf.float32)
real_images = tf.placeholder(shape=[batch_size, 32, 32, 1], dtype=tf.float32)

# Adversarial wiring: the discriminator scores both real images and the
# generator's fakes with the SAME parameters (reuse=True on the second call).
G = generator(z)
Dx = discriminator(real_images)
Dg = discriminator(G, reuse=True)

# Objective functions from the original GAN paper (Goodfellow et al., 2014).
# A tiny epsilon keeps tf.log away from log(0) = -inf (NaN gradients) once
# the sigmoid discriminator saturates at exactly 0 or 1.
EPS = 1e-8
d_loss = -tf.reduce_mean(tf.log(Dx + EPS) + tf.log(1. - Dg + EPS))
g_loss = -tf.reduce_mean(tf.log(Dg + EPS))

# Split the trainable variables per sub-network by their variable-scope
# prefix; this is sturdier than substring matching ('d_' could accidentally
# match an unrelated variable name).
tvars = tf.trainable_variables()
d_vars = [var for var in tvars if var.name.startswith('discriminator')]
g_vars = [var for var in tvars if var.name.startswith('generator')]

# Debugging: list every trainable variable name.
print ('Variable names:')
for var in tvars:
    print (var.name)

# One Adam optimizer per player; each training op updates only its own
# sub-network's variables, leaving the opponent frozen for that step.
trainerD = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)
trainerG = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)
d_train_onestep = trainerD.minimize(d_loss, var_list=d_vars)
g_train_onestep = trainerG.minimize(g_loss, var_list=g_vars)


#############################
# Trainer
#############################
def train():
    """Train the DCGAN on MNIST.

    Each step: sample a fresh uniform noise batch, take one MNIST batch
    (rescaled to [-1, 1] and padded to 32x32), run one discriminator update
    and one generator update, and checkpoint every 1000 steps.
    """
    # Training iterations (the original comment called these "epochs", but
    # each step consumes a single batch).
    num_steps = 100000
    model_directory = './models_dcgan'

    # MNIST loader; labels are not needed for GAN training.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=False)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

        # Initialize all variables/parameters.
        sess.run(init)

        for i in range(num_steps):
            print ('epoch ke-{}'.format(i+1))

            # Noise batch zs ~ U(-1, 1), shape [batch, z_size].
            zs = np.random.uniform(-1.0, 1.0, size=[batch_size, z_size]).astype(np.float32)

            # Real batch: rescale pixel values from [0, 1] to [-1, 1] to
            # match the generator's tanh output range.
            xs, _ = mnist.train.next_batch(batch_size)  # labels unused
            xs = (np.reshape(xs, [batch_size, 28, 28, 1]) - 0.5) * 2.0
            # Pad 28x28 -> 32x32 with the background value (-1).
            xs = np.lib.pad(xs, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant', constant_values=(-1, -1))

            # Alternate updates: discriminator first, then generator.
            sess.run([d_train_onestep], feed_dict={z: zs, real_images: xs})
            sess.run([g_train_onestep], feed_dict={z: zs})

            # Checkpoint every 1000 steps (skip step 0). The single fixed
            # filename means each save overwrites the previous checkpoint.
            if i % 1000 == 0 and i != 0:
                if not os.path.exists(model_directory):
                    os.makedirs(model_directory)
                saver.save(sess, model_directory+'/model.cptk')
                print ('model saved')

#############################
# Generate images
#############################
def generate_images():
    """Generate sample images with the trained generator.

    NOTE(review): unfinished stub — it only defines output/model paths and a
    sample batch size; no graph restore or sampling is implemented yet.
    """

    sample_directory = './figs'    # where generated figures would be written
    model_directory = './models'   # NOTE(review): train() saves to './models_dcgan' — path mismatch, confirm
    batch_size_sample = 36         # number of images per sample grid, presumably 6x6 — confirm

# Guard the entry point so importing this module (e.g. to reuse the layer
# helpers or call generate_images) does not immediately launch training.
if __name__ == "__main__":
    train()
