from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from calendar import monthrange
from datetime import datetime, timedelta

# Initialize the Spark session used by the rest of this script.
spark = (SparkSession.builder
         .appName("Belajar Spark SQL")
         .getOrCreate())

# Input schema for the CASA extract: group code, customer id (CIF),
# balance amount, and the business date of the snapshot.
casa_schema = StructType([
    StructField("GROUP_CODE", StringType(), True),
    StructField("CIF", StringType(), True),
    StructField("AMOUNT", DoubleType(), True),
    StructField("buss_date2", DateType(), True),
])

# Location of the input data on HDFS.
hdfs_dir = "hdfs://master:9000/user/bd/data/"

# Read the pipe-delimited CASA extract with an explicit schema (no
# inferSchema — types stay stable across runs); dates are parsed with
# the "ddMMMyyyy" pattern, e.g. "01JAN2020".
df_casa = spark.read.csv(hdfs_dir + 'casa.csv',
                         schema=casa_schema,
                         sep="|",
                         header=True,
                         dateFormat="ddMMMyyyy")

# print() works on both Python 2 and 3; the bare-statement form is Py2-only.
print(df_casa.dtypes)

# Order by business date, newest first (the input may not arrive sorted).
df_casa_sorted = df_casa.orderBy(F.desc("buss_date2"))

# Most recent business date in the data (a datetime.date object).
# NOTE(review): .first() returns None on empty input — assumes non-empty data.
latest_date = df_casa_sorted.first().buss_date2

def monthdelta(d1, d2):
    """Count whole calendar-month steps that fit between d1 and d2.

    Repeatedly advances d1 by the length of its current month and counts
    how many of those jumps land on or before d2. Returns 0 when d2 is
    earlier than (or within one month step of) d1.
    """
    months = 0
    cursor = d1
    while True:
        days_in_month = monthrange(cursor.year, cursor.month)[1]
        cursor = cursor + timedelta(days=days_in_month)
        if cursor > d2:
            return months
        months += 1

def delta(x):
    """Return the month offset of date x relative to the global latest_date.

    Both dates are normalized to the first of their month, so the result
    is a whole number of calendar months (0 for the latest month).
    x is a datetime.date object.
    """
    anchor = latest_date.replace(day=1)
    start = x.replace(day=1)
    return monthdelta(start, anchor)

# Wrap delta() as a Spark UDF returning an integer month offset.
func = F.udf(delta, IntegerType())

pars = df_casa_sorted.buss_date2

# Compute m_diff once with withColumn and then filter on it; the original
# pattern evaluated the UDF twice (once in filter, once in select).
# Keep only the latest 7 month offsets (0..6).
df_filter = (df_casa_sorted
             .withColumn('m_diff', func(pars))
             .filter(F.col('m_diff') <= 6)
             .select('GROUP_CODE', 'CIF', 'AMOUNT', 'buss_date2', 'm_diff'))

# df_filter is consumed repeatedly below, so cache it.
df_filter.cache()

# Monthly AUM totals per month offset 0..6 (aum_1 is the latest month).
# One groupBy/sum job replaces the original seven separate filter+sum
# jobs. A month with no rows — or whose SUM(AMOUNT) is null — defaults
# to 0, matching the original "kembalikan NOL" (return zero) handling.
aum_by_month = [0] * 7
for row in df_filter.groupBy('m_diff').sum('AMOUNT').collect():
    total = row['sum(AMOUNT)']
    if total is not None:
        aum_by_month[row['m_diff']] = total

aum_1, aum_2, aum_3, aum_4, aum_5, aum_6, aum_7 = aum_by_month

# Last six monthly balances, newest first; aum_7 is only used for the
# month-over-month increase flags further below.
aum = [aum_1, aum_2, aum_3, aum_4, aum_5, aum_6]
print(aum)

# RAT_BAL_AVG6_HIGH6: ratio of the 6-month average balance to the
# 6-month high balance. float() keeps the division "true" on Python 2.
avg_bal_6 = sum(aum) / float(len(aum))
high_bal_6 = max(aum)
# Guard: when every monthly balance is 0 the original raised
# ZeroDivisionError; report 0 instead.
RAT_BAL_AVG6_HIGH6 = avg_bal_6 / high_bal_6 if high_bal_6 else 0
print('RAT_BAL_AVG6_HIGH6: {}'.format(RAT_BAL_AVG6_HIGH6))

# F_BAL_INC75_6: number of months (out of the last six) whose balance
# exceeded the following month's balance by more than 75%. zip over
# adjacent months replaces the six hand-written comparisons.
monthly = [aum_1, aum_2, aum_3, aum_4, aum_5, aum_6, aum_7]
aum_inc = [cur > 1.75 * nxt for cur, nxt in zip(monthly, monthly[1:])]

# True counts as 1, so sum() yields the number of flagged increases.
F_BAL_INC75_6 = sum(aum_inc)
print('F_BAL_INC75_6: {}'.format(F_BAL_INC75_6))
