
from pyspark.sql import SparkSession
import sys

# Require exactly one positional argument: the S3 bucket name.
try:
    _, s3_bucket_name = sys.argv
except ValueError:
    print("Usage: pass s3_bucket_name")
    sys.exit(1)

# Initialize (or reuse) the Spark session for this job.
spark = SparkSession.builder.appName("yellow_zone").getOrCreate()

# Load every yellow-taxi parquet file from S3 and register the
# resulting DataFrame as a temp view so it can be queried with SQL.
s3_yellow_path = f"s3://{s3_bucket_name}/input_data/yellow/*"
df_yellow = spark.read.parquet(s3_yellow_path)
df_yellow.createOrReplaceTempView("src_yellow")


# Load the taxi-zone lookup CSV from S3 (header row present, schema
# inferred) and register it as a temp view for the join below.
s3_zone_path = f"s3://{s3_bucket_name}/input_data/zone_lookup/*"
df_zone = spark.read.csv(s3_zone_path, header=True, inferSchema=True)
df_zone.createOrReplaceTempView("src_zone_lookup")

# Join yellow trips to the zone lookup (once for pickup, once for dropoff)
# and aggregate the fare columns per zone pair.
#
# BUG FIX: the original query was a single double-quoted Python string that
# itself contained unescaped double quotes (replace(z1.zone,'"','')) — a
# Python SyntaxError.  A triple-quoted string avoids the quote clash.
# BUG FIX: the SELECT list has six non-aggregated expressions but the
# original grouped only by 1,2,3,4, so Spark would reject do_service_zone
# and dolocationid as neither grouped nor aggregated.  Group by 1-6.
# Also dropped the trailing ';', which spark.sql() does not accept.
yellow_zone_sql = """
    SELECT
        replace(z1.zone, '"', '')         AS pu_zone,
        replace(z1.service_zone, '"', '') AS pu_service_zone,
        yellow.pulocationid,
        replace(z2.zone, '"', '')         AS do_zone,
        replace(z2.service_zone, '"', '') AS do_service_zone,
        yellow.dolocationid,
        sum(yellow.passenger_count)       AS passenger_count,
        sum(yellow.trip_distance)         AS trip_distance,
        sum(yellow.fare_amount)           AS fare_amount,
        sum(yellow.extra)                 AS extra,
        sum(yellow.mta_tax)               AS mta_tax,
        sum(yellow.tip_amount)            AS tip_amount,
        sum(yellow.tolls_amount)          AS tolls_amount,
        sum(yellow.improvement_surcharge) AS improvement_surcharge,
        sum(yellow.total_amount)          AS total_amount,
        sum(yellow.congestion_surcharge)  AS congestion_surcharge,
        sum(yellow.airport_fee)           AS airport_fee
    FROM src_yellow yellow
    LEFT OUTER JOIN src_zone_lookup z1 ON yellow.pulocationid = z1.locationid
    LEFT OUTER JOIN src_zone_lookup z2 ON yellow.dolocationid = z2.locationid
    GROUP BY 1, 2, 3, 4, 5, 6
"""
df_yellow_zone = spark.sql(yellow_zone_sql)
df_yellow_zone.show(10)

# Persist the aggregated result back to S3 in parquet format.
s3_yellow_zone_path = f"s3://{s3_bucket_name}/output_data/yellow_zone/"
df_yellow_zone.write.parquet(s3_yellow_zone_path)

