def csv_load_to_db(destination_folder,
                   filename,
                   insert_query_file,
                   by_rows_batch=10000):
    """
    parse csv file and execute query to load into database.
    
    Arg: 
        1. filename = name of csv file 'filename.csv'
        2. destination_folder = downloaded files directory 'data/'
        3. insert_query = dir or .sql file ,'path/local/query.sql'
    
    """
    csv_file = open(destination_folder + filename, 'r')
    sql_file = open(insert_query_file, 'r')
    sql = sql_file.read()
    sql_file.close()
    # the insert statement is expected to be the second ';'-separated statement in the .sql file
    insert_query = sql.split(';')[1]

    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    cur = conn.cursor()
    cur.execute('use sales_records_airflow')
    cur.execute('select count(*) from sales')
    # add one so the header row is skipped when slicing the csv file below
    row_count = cur.fetchone()[0] + 1

    if row_count == 50001:
        # the table already holds all 50000 source rows, nothing left to load
        pass
    else:
        print('empty' if row_count == 1 else 'not empty')
        # islice skips the header plus the rows already loaded and returns the
        # next by_rows_batch rows, e.g. rows 1-10000 on the first run,
        # 10001-20000 on the second run, and so on
        for row in islice(csv_file, row_count, row_count + by_rows_batch):
            val = row.rstrip().split(',')
            # columns 5 and 7 are order_date and ship_date, parse them into date objects
            val[5] = datetime.strptime(val[5], '%m/%d/%Y').date()
            val[7] = datetime.strptime(val[7], '%m/%d/%Y').date()
            cur.execute(insert_query, val)
            conn.commit()
    conn.close()
    csv_file.close()
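
# A minimal sketch of how csv_load_to_db might be wired into the DAG with a
# PythonOperator (Airflow 2.x import path). The dag object, task_id and the
# csv filename below are hypothetical placeholders; only the function signature
# and the 'data/' / 'path/local/query.sql' examples come from the docstring above.
#
# from airflow.operators.python import PythonOperator
#
# load_csv = PythonOperator(
#     task_id='csv_load_to_db',
#     python_callable=csv_load_to_db,
#     op_kwargs={
#         'destination_folder': 'data/',
#         'filename': 'sales_records.csv',
#         'insert_query_file': 'path/local/query.sql',
#     },
#     dag=dag,
# )
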
def mysql_to_pq(source_transform,
                name_of_dataset='project_four_airflow',
                by_row_batch=1000):
    '''
    Extract rows from the MySQL database and save them to a local parquet file, ``tmp/sales-date.pq``.
    The function reads the last row id of the BigQuery dataset and compares it against the current
    MySQL database so that only new rows are extracted and loaded, avoiding duplicates. If the
    dataset does not exist it is created using the given name.

    Args:

        1. source_transform = path of the local parquet file, e.g. 'path/local/file.pq'

        2. name_of_dataset = name of the BigQuery dataset to check or create

        3. by_row_batch = number of rows you want to extract ``int``

    return:
        ``str`` of local pq file path
    '''
    client = BigQueryHook(gcp_conn_id='google_cloud_default').get_client()
    row_id = client.query(
        'select id from project_four_airflow.sales order by id desc limit 1')
    try:
        # read the highest id already loaded into BigQuery
        for i in row_id:
            last_row_id = i[0]
            print(i[0])
    except GoogleAPIError:
        # the dataset/table does not exist yet ('notFound'): start from id 0 and create the dataset
        last_row_id = 0
        print('no dataset.table')
        client.create_dataset(name_of_dataset)
        print('new dataset, {} created'.format(name_of_dataset))
    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    cur = conn.cursor()
    cur.execute('use sales_records_airflow')
    cur.execute('select * from sales where id>={} and id<={}'.format(
        last_row_id + 1, last_row_id + by_row_batch))

    list_row = cur.fetchall()
    rows_of_extracted_mysql = []
    for i in list_row:
        rows_of_extracted_mysql.append(list(i))
    print('extracting from mysql')
    df = pd.DataFrame(rows_of_extracted_mysql,
                      columns=[
                          'id', 'region', 'country', 'item_type',
                          'sales_channel', 'order_priority', 'order_date',
                          'order_id', 'ship_date', 'units_sold', 'unit_price',
                          'unit_cost', 'total_revenue', 'total_cost',
                          'total_profit'
                      ])
    df.to_parquet(source_transform)
    print('task complete check,', source_transform)
    return source_transform
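
# mysql_to_pq only writes the parquet file; the load into BigQuery that its
# docstring alludes to is not shown in this module. The helper below is a
# minimal sketch of one way to do that step with the google-cloud-bigquery
# client returned by the hook. The function name, WRITE_APPEND disposition and
# default table id are assumptions, not part of the original pipeline.
def pq_to_bq(source_transform, table_id='project_four_airflow.sales'):
    from google.cloud import bigquery
    client = BigQueryHook(gcp_conn_id='google_cloud_default').get_client()
    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.PARQUET,
        write_disposition=bigquery.WriteDisposition.WRITE_APPEND)
    with open(source_transform, 'rb') as pq_file:
        # stream the local parquet file into the target table
        job = client.load_table_from_file(pq_file,
                                          table_id,
                                          job_config=job_config)
    job.result()  # wait for the load job to finish
    print('loaded {} rows into {}'.format(job.output_rows, table_id))
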
def check_data(task_instance, create_table_query_file):
    '''
    Check whether the sales table exists and how many rows it currently holds.
    Pushes the row count to XCom as 'mysql_total_rows' and returns one of
    'csv_file_exist', 'check_dataset' or 'csv_file_not_exist' to pick the next step.
    '''
    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    cur = conn.cursor()
    try:
        cur.execute('use sales_records_airflow')
        cur.execute('select count(*) from sales')
        total_rows = cur.fetchone()[0]
        task_instance.xcom_push(key='mysql_total_rows', value=total_rows)
        if total_rows == 50000:
            # all 50000 source rows are already loaded
            print('up to date')
            return 'check_dataset'
        else:
            print('appending new data')
            return 'csv_file_exist'
    except cur.OperationalError:
        # database or table does not exist yet: run the create statements from the .sql file
        print('sql_file execute')
        sql_file = open(create_table_query_file, 'r')
        sql_query = sql_file.read()
        sql_file.close()
        # the .sql file is expected to hold at most three ';'-separated statements
        for query in sql_query.split(';', maxsplit=2):
            cur.execute(query)
            conn.commit()
        return 'csv_file_not_exist'
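
# check_data returns 'csv_file_exist', 'check_dataset' or 'csv_file_not_exist',
# which look like downstream task_ids for a BranchPythonOperator. A minimal
# sketch of that wiring (assuming Airflow 2.x, where the task_instance context
# variable is passed to a callable parameter of the same name) is shown below;
# the dag object, task_id and the .sql path are hypothetical placeholders.
#
# from airflow.operators.python import BranchPythonOperator
#
# branch_on_mysql_state = BranchPythonOperator(
#     task_id='check_data',
#     python_callable=check_data,
#     op_kwargs={'create_table_query_file': 'path/local/create_table.sql'},
#     dag=dag,
# )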