
from pyspark.sql import SparkSession
import sys

# Expect exactly one positional argument: the target S3 bucket name.
if len(sys.argv) == 2:
    s3_bucket_name = sys.argv[1]
else:
    print("Usage: pass s3_bucket_name")
    sys.exit(1)

# Obtain (or reuse) a Spark session for this job.
builder = SparkSession.builder
spark = builder.appName("green_zone").getOrCreate()

# Ingest the green-taxi parquet drop from the bucket and expose it to
# Spark SQL under the view name "src_green" for the join below.
green_input_path = "s3://{}/input_data/green/*".format(s3_bucket_name)
green_df = spark.read.parquet(green_input_path)
green_df.createOrReplaceTempView("src_green")


# Ingest the taxi-zone lookup CSV (header row present, schema inferred)
# and expose it to Spark SQL under the view name "src_zone_lookup".
zone_lookup_path = "s3://{}/input_data/zone_lookup/*".format(s3_bucket_name)
zone_df = spark.read.csv(zone_lookup_path, header=True, inferSchema=True)
zone_df.createOrReplaceTempView("src_zone_lookup")

# Aggregate green-taxi trips by pickup/dropoff zone, enriching the numeric
# location IDs with zone names via two left joins on the lookup view.
#
# Fixes vs. the original:
#  * The SQL was embedded in a double-quoted Python string that itself
#    contained double-quote characters (replace(...,'"','')), which is a
#    Python SyntaxError. The statement now lives in a triple-quoted string.
#  * GROUP BY listed only ordinals 1-4, but the SELECT has six
#    non-aggregated columns (do_service_zone and dolocationid were
#    missing), which Spark's analyzer rejects. All six are grouped now.
#  * Dropped the trailing ';' — spark.sql() rejects it on many versions.
green_zone_sql = """
    SELECT
        replace(z1.zone, '"', '')         AS pu_zone,
        replace(z1.service_zone, '"', '') AS pu_service_zone,
        green.pulocationid                AS pulocationid,
        replace(z2.zone, '"', '')         AS do_zone,
        replace(z2.service_zone, '"', '') AS do_service_zone,
        green.dolocationid                AS dolocationid,
        sum(green.passenger_count)        AS passenger_count,
        sum(green.trip_distance)          AS trip_distance,
        sum(green.fare_amount)            AS fare_amount,
        sum(green.extra)                  AS extra,
        sum(green.mta_tax)                AS mta_tax,
        sum(green.tip_amount)             AS tip_amount,
        sum(green.tolls_amount)           AS tolls_amount,
        sum(green.improvement_surcharge)  AS improvement_surcharge,
        sum(green.total_amount)           AS total_amount,
        sum(green.congestion_surcharge)   AS congestion_surcharge
    FROM src_green green
    LEFT OUTER JOIN src_zone_lookup z1 ON green.pulocationid = z1.locationid
    LEFT OUTER JOIN src_zone_lookup z2 ON green.dolocationid = z2.locationid
    GROUP BY 1, 2, 3, 4, 5, 6
"""
df_green_zone = spark.sql(green_zone_sql)

df_green_zone.show(10)

# Persist the aggregated green_zone result back to S3 in parquet format.
output_path = "s3://{}/output_data/green_zone/".format(s3_bucket_name)
df_green_zone.write.parquet(output_path)

