Example #1
def doTestMysqlHook(*args, **kwargs):
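    # query the source table, log the max id via get_records and get_first, then insert a result row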
    sql_hook = MySqlHook(mysql_conn_id="mysql_operator_test_connid")
    sql = "select * from manzeng_predict_src_table;"
    result = sql_hook.get_records(sql)
    for row in result:
        print(row)
    sql = "select max(id) as max_id from manzeng_predict_src_table"
    result = sql_hook.get_records(sql)
    print('maxid:' + str(result[0][0]))
    result = sql_hook.get_first(sql)
    print('maxid:' + str(result[0]))
    # log is a property, so call it on an instance; exception() is meant for use inside an except block
    LoggingMixin().log.exception("exception raise test")
    sql_hook.run(
        """insert into manzeng_result_v3(consignor_phone,prediction) values('122','33')"""
    )
Example #2
def get_signers(instance_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
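    # the %s placeholder is bound via the parameters argument, so the driver handles escaping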
    sql = f"""
    select
        instance_id, sign_area_id, sequence, sub_instance_id, sign_section, sign_position, sign_action, is_executed, group_culture, group_id, group_name, 
        created_date, received_date, approved_date
    from
        signers
    where 
        instance_id = %s
    """
    tasks = {}
    rows = db.get_records(sql, parameters=[instance_id])
    tasks[SIGNERS] = []
    for row in rows:
        model = {
            'instance_id': row[0],
            'sign_area_id': row[1],
            'sequence': row[2],
            'sub_instance_id': row[3],
            'sign_section': row[4],
            'sign_position': row[5],
            'sign_action': row[6],
            'is_executed': row[7],
            'group_culture': row[8],
            'group_id': row[9],
            'group_name': row[10],
            'created_date': str(row[11]),
            'received_date': str(row[12]),
            'approved_date': str(row[13])
        }
        tasks[SIGNERS].append(model)

    context['ti'].xcom_push(key=SIGNERS, value=tasks[SIGNERS])
    return list(tasks.values())
Example #3
def get_sign_activity(instance_id, context):
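    # collect every sign_activity row for the instance and publish the list to XCom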
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    sql = f"""
    select
        instance_id, sign_area_id, sequence, sign_area, sign_section, sign_position, sign_action, is_comment, is_executed, group_id, user_id, host_address
    from
        sign_activity
    where 
        instance_id = %s
    """
    tasks = {}
    tasks[SIGN_ACTIVITY] = []
    rows = db.get_records(sql, parameters=[instance_id])
    for row in rows:
        model = {
            'instance_id': row[0],
            'sign_area_id': row[1],
            'sequence': row[2],
            'sign_area': row[3],
            'sign_section': row[4],
            'sign_position': row[5],
            'sign_action': row[6],
            'is_comment': row[7],
            'is_executed': row[8],
            'group_id': row[9],
            'user_id': row[10],
            'host_address': row[11]
        }
        tasks[SIGN_ACTIVITY].append(model)

    context['ti'].xcom_push(key=SIGN_ACTIVITY, value=tasks[SIGN_ACTIVITY])
    return list(tasks.values())
Example #4
 def objects(self):
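     # build an optional IN/NOT IN filter on database name from the whitelist/blacklist settings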
     where_clause = ''
     if DB_WHITELIST:
         dbs = ",".join(["'" + db + "'" for db in DB_WHITELIST])
         where_clause = "AND b.name IN ({})".format(dbs)
     if DB_BLACKLIST:
         dbs = ",".join(["'" + db + "'" for db in DB_BLACKLIST])
         where_clause = "AND b.name NOT IN ({})".format(dbs)
     sql = """
     SELECT CONCAT(b.NAME, '.', a.TBL_NAME), TBL_TYPE
     FROM TBLS a
     JOIN DBS b ON a.DB_ID = b.DB_ID
     WHERE
         a.TBL_NAME NOT LIKE '%tmp%' AND
         a.TBL_NAME NOT LIKE '%temp%' AND
         b.NAME NOT LIKE '%tmp%' AND
         b.NAME NOT LIKE '%temp%'
     {where_clause}
     LIMIT {LIMIT};
     """.format(where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT)
     h = MySqlHook(METASTORE_MYSQL_CONN_ID)
     d = [
         {'id': row[0], 'text': row[0]}
         for row in h.get_records(sql)]
     return json.dumps(d)
Example #5
    def execute(self, context):
        log.info("이찬호")
        db = MySqlHook(mysql_conn_id=self.db_conn_id, schema=self.db_schema)
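        # fetch all rows once, then spread them round-robin across per-task buckets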
        sql = """
        select 
            o.id,
            o.name,
            o.`desc`
        from 
            test o
        """
        # initialize the task list buckets
        tasks = {}
        rows = db.get_records(sql)
        index = len(rows)
        for bucket in range(1, index + 1):
            tasks[f'get_workflow_{bucket}'] = []

        resultCounter = 0
        for row in rows:
            resultCounter += 1
            # map the counter onto buckets 1..index rather than 0..index-1
            bucket = (resultCounter % index) + 1
            model = {'id': str(row[0]), 'name': str(row[1])}
            tasks[f'get_workflow_{bucket}'].append(model)

        # Push the order lists into xcom
        for task in tasks:
            if len(tasks[task]) > 0:
                logging.info(f'Task {task} has {len(tasks[task])} orders.')
                context['ti'].xcom_push(key=task, value=tasks[task])

        return list(tasks.values())
Example #6
class MysqlQueryOperatorWithTemplatedParams(BaseOperator):
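    """Run a templated SQL query against MySQL and return the fetched records."""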
    template_fields = ('sql', 'parameters')
    template_ext = ('.sql',)
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self, sql,
            mysql_conn_id='mysql_dwh',
            autocommit=False,
            parameters=None,
            *args, **kwargs):
        super(MysqlQueryOperatorWithTemplatedParams,
              self).__init__(*args, **kwargs)
        self.sql = sql
        self.mysql_conn_id = mysql_conn_id
        self.autocommit = autocommit
        self.parameters = parameters

    def execute(self, context):
        logging.info('Executing: %s with parameters: %s', str(self.sql), self.parameters)
        self.hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        data = self.hook.get_records(self.sql,
                                     parameters=self.parameters)
        #logging.info('Executed: ' + str(x))
        return data
Example #8
    def index(self):
        end_date = datetime.now().date()
        start_date = end_date - timedelta(days=1)
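        # pull yesterday's task instances and shape them into Gantt-chart rows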
        sql = """
            select 
                a.dag_id, 
                a.state, 
                min(a.start_date) as start_date, 
                max(a.end_date) as end_date,
                max(a.end_date)-min(a.start_date) as duration,
                b.job_type, 
                a.job_id
            from task_instance as a
            join job as b 
            ON a.job_id = b.id
            where 
                a.start_date >= "{start_date}" 
                and  a.start_date < "{end_date}" 
                and a.state != 'failed'
            group by a.dag_id, a.job_id 
            order by start_date;
        """.format(start_date=start_date, end_date=end_date)
        h = MySqlHook(METASTORE_MYSQL_CONN_ID)
        rows = h.get_records(sql)
        tasks = []
        taskNames = []
        name_set = set()
        time_format = "%Y-%m-%dT%H:%M:%S"
        for row in rows:
            dag_id = row[0]
            state = row[1]
            start_date = row[2]
            end_date = row[3]
            if not end_date:
                end_date = datetime.now()
            duration = str(row[4])
            task = {
                'status': state,
                'taskName': dag_id,
                'startDate': time.mktime(start_date.timetuple()) * 1000,
                'endDate': time.mktime(end_date.timetuple()) * 1000,
                'executionDate': start_date.strftime(time_format),
                'isoStart': start_date.strftime(time_format),
                'isoEnd': end_date.strftime(time_format),
                'duration': duration
            }
            taskNames.append(dag_id)
            name_set.add(dag_id)
            tasks.append(task)

        data = {
            'height': 20 * len(name_set),
            'tasks': tasks,
            'taskNames': taskNames,
            'taskStatus': {
                'success': 'success'
            }
        }

        return self.render("scheduler_browser/gantt.html", data=data)
Example #9
def get_groups(group_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dbo")
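    # group_id is expected to be unique, so the loop leaves at most one row in task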
    sql = f"""
    select
        group_id, parent_id, name, culture, group_code, sequence, is_childs, depth, is_display, email, interface_id, remark, created_date, modified_date 
    from
        groups
    where 
        group_id = %s
    """
    task = {}
    rows = db.get_records(sql, parameters=[group_id])
    for row in rows:
        model = {
            'group_id': row[0],
            'parent_id': row[1],
            'name': row[2],
            'culture': row[3],
            'group_code': row[4],
            'sequence': row[5],
            'is_childs': row[6],
            'depth': row[7],
            'is_display': row[8],
            'email': row[9],
            'remark': row[10],
            'created_date': str(row[11]),
            'modified_date': str(row[12])
        }
        task = model

    context['ti'].xcom_push(key=GROUPS, value=task)
    return task
Example #10
def get_max_mysql(connection_name, schema, table_name, column):
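    # NOTE: schema, table and column names are concatenated directly into the SQL, so they must be trusted values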
    logging.info('Executing: SELECT max(' + str(column) + ') FROM ' +
                 str(schema) + '.' + str(table_name))
    hook = MySqlHook(mysql_conn_id=connection_name)
    output = hook.get_records('SELECT max(' + str(column) + ') FROM ' +
                              str(schema) + '.' + str(table_name))
    logging.info(output[0][0])
    return output[0][0]
Example #11
def filter_db():
    api = MySqlHook()
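    # no conn_id supplied, so the hook falls back to Airflow's default MySQL connection (mysql_default)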
    data = api.get_records(sql='select * from movie where vote_average > 7')

    # truncate table filter
    api.run(sql='truncate table movie_filter')

    # insert ke table filter
    api.insert_rows(table='movie_filter', rows=data)
Example #12
def move_mysql_to_redshift(tablename):
    mysql_hook = MySqlHook(mysql_conn_id='mysql_baseball')
    redshift_hook = PostgresHook(postgres_conn_id='redshift_host')
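    # read the rows from MySQL, then bulk-insert them into the same table on Redshift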

    sql = "select league_id, name, abbr from "
    sql += tablename

    cur = mysql_hook.get_records(sql)

    redshift_hook.insert_rows(tablename,cur)

    return cur
Example #13
    def copy_table(self, mysql_conn_id, postgres_conn_id):

        print("### fetching records from MySQL table ###")
        mysqlserver = MySqlHook(mysql_conn_id)
        sql_query = "SELECT * from clean_store_transactions "
        data = mysqlserver.get_records(sql_query)

        print("### inserting records into Postgres table ###")
        postgresserver = PostgresHook(postgres_conn_id)
        # insert_rows builds the INSERT statement itself, so no hand-written query is needed
        postgresserver.insert_rows(table='clean_store_transactions', rows=data)

        return True
Example #14
def step2(ds, **kargs):
    mysql_hook = MySqlHook(mysql_conn_id = 'cloudsql-test')
    items = mysql_hook.get_records("SELECT policyID FROM sample_db_1.SAMPLE_TABLE_4  limit 20")
    # mysql_hook = MySqlHook(mysql_conn_id='local_mysql')
    # items = mysql_hook.get_records("SELECT samepleid FROM cloudtest.SAMPLE_TABLE_5 limit 20")
    mail_list = []
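    # each record is a 1-tuple, so '%d' % r formats its single integer element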

    for r in items:
        # print 'mail:%s ' % r
        mail_list.append('%d' % r)
    # print(mail_list)
    params = ','.join(mail_list)
    logging.info("params:{}".format(params))
Example #15
def get_Mysql_data():
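    # convert() flattens a 1-tuple row such as ('S1',) into the bare string S1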
    def convert(a):
        return str(a).replace('(',
                              '').replace(')',
                                          '').replace(',',
                                                      '').replace("'", '')

    mysql_hook = MySqlHook(mysql_conn_id='test_mysql')
    sql = "SELECT sample_id FROM actg_samples where active_state=1"
    records = mysql_hook.get_records(sql=sql)
    final = []
    for a in records:
        final.append(convert(a))
    return final
Example #16
def loggin_updateids_from_kv(**kwargs):
    mysql_hook = MySqlHook(mysql_conn_id='cloudsql-test')
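    # select only the ids that have not been synced yet (sync_status = 0)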
    items = mysql_hook.get_records(
        "SELECT samepleid FROM {}.{}  where sync_status = 0".format(
            kwargs['export_database'], kwargs['export_table']))
    # mysql_hook = MySqlHook(mysql_conn_id='local_mysql')
    # items = mysql_hook.get_records("SELECT samepleid FROM cloudtest.SAMPLE_TABLE_5 limit 20")
    mail_list = []

    for r in items:
        # print 'mail:%s ' % r
        mail_list.append('%d' % r)
    # print(mail_list)
    kwargs['update_id'] = ','.join(mail_list)
    logging.info("loggin_updateids_from_kv_ids:{}".format(kwargs['update_id']))
Example #17
def checkRecords(**kwargs):
    # Check for data, if present write results to file
    # If empty, log information to different file
    mysql_hook = MySqlHook(mysql_conn_id='local_mysql')
    # Table test1 contains list of postive integers
    sql = """
    select * from test1
    """
    records = mysql_hook.get_records(sql)
    # Pushing to Task instance
    if len(records) == 0:
        kwargs['task_instance'].xcom_push(key='check', value=False)
    else:
        kwargs['task_instance'].xcom_push(key='check', value=True)
        kwargs['task_instance'].xcom_push(key='recordcount',
                                          value=len(records))
Example #18
def get_max_updated_at(conn_id, table_name):
    """
    Gets the max updated_at timestamp. If table is empty it returns None
    :param conn_id: Connection ID from the DB to search for
    :type conn_id: str
    :param table_name: name of the table to extract the max updated_at
    :type table_name: str
    :return: str with timestamp or None
    """
    mysql_hook = MySqlHook(conn_id)

    sql_query = "SELECT MAX(updated_at) FROM {}".format(table_name)
    max_updated_at = mysql_hook.get_records(sql_query)[0][
        0]  # row 0, column 0. Only one record returned

    return max_updated_at
Example #19
def get_workflow(**context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
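    # claim at most one ready workflow row, push it to XCom, then mark it started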

    sql = """
    select
        workflow_process_id,ngen,site_id,application_id,instance_id,schema_id,name,workflow_instance_id,state,retry_count,ready,
        execute_date,created_date,bookmark,version,request,reserved,message
    from
        workflow_process
    where 
        ready > 0 and retry_count < 10
    limit 1
    """
    task = {}
    rows = db.get_records(sql)
    for row in rows:
        model = {
            'workflow_process_id': row[0],
            'ngen': row[1],
            'site_id': row[2],
            'application_id': row[3],
            'instance_id': row[4],
            'schema_id': row[5],
            'name': row[6],
            'workflow_instance_id': row[7],
            'state': row[8],
            'retry_count': row[9],
            'ready': row[10],
            'execute_date': str(row[11]),
            'created_date': str(row[12]),
            'bookmark': row[13],
            'version': row[14],
            'request': row[15],
            'reserved': row[16],
            'message': row[17]
        }
        task = model

    # proceed only when a row was found
    if task != {}:
        context['ti'].xcom_push(key=WORKFLOWS, value=task)
        sql = f"""
        update workflow_process
            set ready = 0, bookmark = 'start'
        where workflow_process_id = %s
        """
        db.run(sql, autocommit=True, parameters=[task['workflow_process_id']])
Example #20
    def poke(self, context):
        db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
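        # the sensor succeeds as soon as at least one ready workflow row exists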
        
        sql = """
        select
            workflow_process_id,ngen,site_id,application_id,instance_id,schema_id,name,workflow_instance_id,state,retry_count,ready,
            execute_date,created_date,bookmark,version,request,reserved,message
        from
            workflow_process
        where 
            ready > 0 and retry_count < 10
        """
        tasks = {}
        tasks[WORKFLOW_PROCESS] = []
        rows = db.get_records(sql)
        for row in rows:
            model = {
                'workflow_process_id':row[0],
                'ngen':row[1],
                'site_id':row[2],
                'application_id':row[3],
                'instance_id':row[4],
                'schema_id':row[5],
                'name':row[6],
                'workflow_instance_id':row[7],
                'state':row[8],
                'retry_count':row[9],
                'ready':row[10],
                'execute_date':str(row[11]),
                'created_date':str(row[12]),
                'bookmark':row[13],
                'version':row[14],
                'request':row[15],
                'reserved':row[16],
                'message':row[17]
            }
            tasks[WORKFLOW_PROCESS].append(model)

        # proceed only when rows were found
        if tasks[WORKFLOW_PROCESS] != []:
            log.info('workflow_process find data')
            context['ti'].xcom_push(key=WORKFLOW_PROCESS, value=tasks[WORKFLOW_PROCESS])
            return True
        else:
            log.info('workflow_process empty data')
            return False
Example #21
def get_instance(**context):
    workflow = context['ti'].xcom_pull(task_ids=WORKFLOW_START_TASK,
                                       key=WORKFLOWS)
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
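    # instance_id comes from the workflow dict pushed by the upstream start task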
    instance_id = int(workflow["instance_id"])
    sql = f"""
    select
        instance_id,state,form_id,parent_id,workflow_id,subject,creator_culture,creator_name,group_culture,group_name,is_urgency,is_comment,is_related_document,
        attach_count,summary,re_draft_group_id,sub_proc_group_id,interface_id,created_date,completed_date,
        creator_id,group_id
    from
        instances
    where 
        instance_id = %s
    """
    task = {}
    rows = db.get_records(sql, parameters=[instance_id])
    for row in rows:
        model = {
            'instance_id': row[0],
            'state': row[1],
            'form_id': row[2],
            'parent_id': row[3],
            'workflow_id': row[4],
            'subject': row[5],
            'creator_culture': row[6],
            'creator_name': row[7],
            'group_culture': row[8],
            'group_name': row[9],
            'is_urgency': row[10],
            'is_comment': row[11],
            'is_related_document': row[12],
            'attach_count': row[13],
            'summary': row[14],
            're_draft_group_id': row[15],
            'sub_proc_group_id': row[16],
            'interface_id': row[17],
            'created_date': str(row[18]),
            'completed_date': str(row[19]),
            'creator_id': row[20],
            'group_id': row[21]
        }
        task = model

    context['ti'].xcom_push(key=INSTANCES, value=task)
    return task
Example #22
def check_previous_runs(**kwargs):
    context = kwargs
    current_run_id = context['dag_run'].run_id
    current_dag_id = context['dag_run'].dag_id
    # Connect to mysql and check for any errors for this DAG
    airflow_conn = MySqlHook(mysql_conn_id='deliverbi_mysql_airflow')
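    # count failed runs of this DAG other than the current one, straight from the Airflow metadata db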
    l_error_count = 0
    cmd_sql = f"select count(1) from airflow.dag_run where dag_id = '{current_dag_id}' "
    cmd_sql += f"and run_id <> '{current_run_id}' and state = 'failed'"
    print(cmd_sql)
    airflow_data = airflow_conn.get_records(sql=cmd_sql)
    for row in airflow_data:
        l_error_count = int((str(row[0])))

    print("Found Previous Errors:" + str(l_error_count))
    if l_error_count != 0:
        raise AirflowException(
            "Previous Run in Error so Failing the Current Run")
Example #23
def get_data_from_mysql(filename, tablename):
    hook = MySqlHook(mysql_conn_id='mysql_baseball')

    sql = "select * from "
    sql += tablename

    cur = hook.get_records(sql)

    # csv.writer has no close() method and Python 3 needs text mode,
    # so write through a context manager instead
    with open(filename, "w", newline="") as f:
        c = csv.writer(f, quoting=csv.QUOTE_NONNUMERIC)
        for row in cur:
            c.writerow(row)

    return cur
Example #24
def get_users(user_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dbo")
    sql = f"""
    select
        user_id, name, culture, group_id, employee_num, anonymous_name, email, theme_code, date_format_code, time_format_code, time_zone, row_count, language_code, 
        interface_id, phone, mobile, fax, icon, addsign_img, is_plural, is_notification, is_absence, is_deputy 
    from
        users
    where 
        user_id = %s
    """
    task = {}
    rows = db.get_records(sql, parameters=[user_id])
    for row in rows:
        model = {
            'user_id': row[0],
            'name': row[1],
            'culture': row[2],
            'group_id': row[3],
            'employee_num': row[4],
            'anonymous_name': row[5],
            'email': row[6],
            'theme_code': row[7],
            'date_format_code': row[8],
            'time_format_code': row[9],
            'row_count': row[10],
            'language_code': row[11],
            'interface_id': row[12],
            'phone': row[13],
            'mobile': row[14],
            'fax': row[15],
            'icon': row[16],
            'addsign_img': row[17],
            'is_plural': row[18],
            'is_notification': row[19],
            'is_absence': row[20],
            'is_deputy': row[21]
        }
        task = model

    context['ti'].xcom_push(key=USERS, value=task)
    return task
Example #25
def get_columns_and_exclude(conn_id, table_name, l_columns_exclude):
    """

    :param conn_id: connection id to connect to
    :param table_name: table to get the columns
    :param l_columns_exclude: list of strings of columns to exclude
    :return: list of strings without the excluded columns
    """
    mysql_hook = MySqlHook(conn_id)

    sql_query = "SHOW COLUMNS FROM {}".format(table_name)
    all_records = mysql_hook.get_records(sql_query)

    l_columns_after_exclude = [
        "t.`{}`".format(l_row[0]) for l_row in all_records
        if l_row[0] not in l_columns_exclude
    ]
    logging.debug(
        "Columns after exclude: '{}'".format(l_columns_after_exclude))

    return l_columns_after_exclude
Example #26
def send_aggregate_to_requestbin():
    target = 'http://requestbin.net/r/zorarbzo'

    connection = MySqlHook(mysql_conn_id='mysql_default')
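    # read the aggregated rows and POST them as JSON to the requestbin endpoint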
    sql = '''
        SELECT 
            film_name, name, birth_year 
        FROM
            `swapi_data`.`swapi_people_aggregate`;
    '''
    result = connection.get_records(sql)
    data = []
    for item in result:
        data.append({
            "film_name": item[0],
            "name": item[1],
            "birth_year": str(item[2])
        })

    result = requests.post(target, data=json.dumps(data))

    return result
Example #27
def get_group_users(gid, uid, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dbo")
    sql = f"""
    select 
        b.*, a.relation_type, a.is_master
    from 
        dbo.group_users a
        inner join dbo.groups b
        on a.group_id = b.group_id
    where 
        a.user_id = %s and a.parent_id = %s
    """
    tasks = {}
    tasks[GROUP_USERS] = []
    rows = db.get_records(sql, parameters=[uid, gid])
    for row in rows:
        model = {
            'group_id': row[0],
            'parent_id': row[1],
            'name': row[2],
            'culture': row[3],
            'group_code': row[4],
            'sequence': row[5],
            'is_childs': row[6],
            'depth': row[7],
            'is_display': row[8],
            'email': row[9],
            'remark': row[10],
            'created_date': str(row[11]),
            'modified_date': str(row[12]),
            'relation_type': row[13],
            'is_master': row[14]
        }
        tasks[GROUP_USERS].append(model)

    context['ti'].xcom_push(key=GROUPS, value=tasks[GROUP_USERS])
    return list(tasks.values())
Example #28
def getWorkflows(**context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
    # initialize the task list buckets
    sql = """
    select 
        o.id,
        o.name,
        o.`desc`
    from 
        test o
    """
    tasks = {}
    rowCount = 0
    rows = db.get_records(sql)
    for row in rows:
        rowCount += 1
        tasks[f'order_processing_task_{rowCount}'] = []

    # populate the task list buckets
    # distribute them evenly across the set of buckets
    # records: List[List[Optional[Any]]] = db.get_records(get_orders_query)

    resultCount = 0
    for row in rows:
        resultCount += 1
        model = {'id': str(row[0]), 'name': str(row[1])}
        tasks[f'order_processing_task_{resultCount}'].append(model)

    # Push the order lists into xcom
    for task in tasks:
        if len(tasks[task]) > 0:
            logging.info(f'Task {task} has {len(tasks[task])} orders.')
            context['ti'].xcom_push(key=task, value=tasks[task])

    return list(tasks.values())
Example #29
    def execute(self, context=None):
        metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
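        # read the table definition from the Hive metastore to learn each column's type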
        table = metastore.get_table(table_name=self.table)
        field_types = {col.name: col.type for col in table.sd.cols}

        exprs = {
            ('', 'count'): 'COUNT(*)'
        }
        for col, col_type in list(field_types.items()):
            d = {}
            if self.assignment_func:
                d = self.assignment_func(col, col_type)
                if d is None:
                    d = self.get_default_exprs(col, col_type)
            else:
                d = self.get_default_exprs(col, col_type)
            exprs.update(d)
        exprs.update(self.extra_exprs)
        exprs = OrderedDict(exprs)
        exprs_str = ",\n        ".join([
            v + " AS " + k[0] + '__' + k[1]
            for k, v in exprs.items()])

        where_clause = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
        where_clause = " AND\n        ".join(where_clause)
        sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
            exprs_str=exprs_str, table=self.table, where_clause=where_clause)

        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info('Executing SQL check: %s', sql)
        row = presto.get_first(hql=sql)
        self.log.info("Record: %s", row)
        if not row:
            raise AirflowException("The query returned None")

        part_json = json.dumps(self.partition, sort_keys=True)

        self.log.info("Deleting rows from previous runs if they exist")
        mysql = MySqlHook(self.mysql_conn_id)
        sql = """
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}'
        LIMIT 1;
        """.format(table=self.table, part_json=part_json, dttm=self.dttm)
        if mysql.get_records(sql):
            sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{part_json}' AND
                dttm='{dttm}';
            """.format(table=self.table, part_json=part_json, dttm=self.dttm)
            mysql.run(sql)

        self.log.info("Pivoting and loading cells into the Airflow db")
        rows = [(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
                for r in zip(exprs, row)]
        mysql.insert_rows(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ]
        )
Example #30
    def execute(self, context=None):
        metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
        table = metastore.get_table(table_name=self.table)
        field_types = {col.name: col.type for col in table.sd.cols}

        exprs = {('', 'count'): 'COUNT(*)'}
        for col, col_type in list(field_types.items()):
            d = {}
            if self.assignment_func:
                d = self.assignment_func(col, col_type)
                if d is None:
                    d = self.get_default_exprs(col, col_type)
            else:
                d = self.get_default_exprs(col, col_type)
            exprs.update(d)
        exprs.update(self.extra_exprs)
        exprs = OrderedDict(exprs)
        exprs_str = ",\n        ".join(
            [v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

        where_clause = [
            "{} = '{}'".format(k, v) for k, v in self.partition.items()
        ]
        where_clause = " AND\n        ".join(where_clause)
        sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
            exprs_str=exprs_str, table=self.table, where_clause=where_clause)

        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info('Executing SQL check: %s', sql)
        row = presto.get_first(hql=sql)
        self.log.info("Record: %s", row)
        if not row:
            raise AirflowException("The query returned None")

        part_json = json.dumps(self.partition, sort_keys=True)

        self.log.info("Deleting rows from previous runs if they exist")
        mysql = MySqlHook(self.mysql_conn_id)
        sql = """
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}'
        LIMIT 1;
        """.format(table=self.table, part_json=part_json, dttm=self.dttm)
        if mysql.get_records(sql):
            sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{part_json}' AND
                dttm='{dttm}';
            """.format(table=self.table, part_json=part_json, dttm=self.dttm)
            mysql.run(sql)

        self.log.info("Pivoting and loading cells into the Airflow db")
        rows = [(self.ds, self.dttm, self.table, part_json) +
                (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)]
        mysql.insert_rows(table='hive_stats',
                          rows=rows,
                          target_fields=[
                              'ds',
                              'dttm',
                              'table_name',
                              'partition_repr',
                              'col',
                              'metric',
                              'value',
                          ])
Example #31
    def poke(self, context):
        # Setting default args for ngap2 airflow db
        db_type = 'mysql'
        internal_db = True

        # Reading connection to check db type
        if self.cluster_id is not None:
            conn = BaseHook.get_connection(self.cluster_id)
            logging.info(
                "checking for Extra field in connections to determine db type mysql or postgres or lambda"
            )
            # check for db type from extra args
            for arg_name, arg_val in conn.extra_dejson.items():
                if arg_name in ['db_type']:
                    db_type = arg_val
                if arg_name in ['internal_db']:
                    # the key's presence alone marks the connection as external
                    internal_db = False

        if self.execution_delta:
            dttm = context['execution_date'] - self.execution_delta
        elif self.execution_delta_json:
            hour = context['execution_date'].strftime('%H')
            delta = self.execution_delta_json[hour]
            hour_d = int(delta.split(':')[0])
            minute_d = int(
                delta.split(':')[1] if len(delta.split(':')) > 1 else '00')

            final_minutes = 0
            if hour_d < 0:
                final_minutes = (hour_d * 60) - minute_d
            else:
                final_minutes = (hour_d * 60) + minute_d

            dttm = context['execution_date'] - timedelta(minutes=final_minutes)
        else:
            dttm = context['execution_date']

        allowed_states = tuple(self.allowed_states)
        if len(allowed_states) == 1:
            sql = " SELECT ti.task_id FROM task_instance ti WHERE ti.dag_id = '{self.external_dag_id}' " \
                  "AND ti.task_id = '{self.external_task_id}' AND ti.state = ('{allowed_states[0]}') " \
                  "AND ti.execution_date = '{dttm}';".format(**locals())
        else:
            sql = "SELECT  ti.task_id FROM task_instance ti WHERE ti.dag_id = '{self.external_dag_id}' " \
                  "AND ti.task_id = '{self.external_task_id}' AND ti.state in {allowed_states} " \
                  "AND ti.execution_date = '{dttm}';".format(**locals())

        if self.cluster_id is None:
            logging.info('Poking for '
                         '{self.external_dag_id}.'
                         '{self.external_task_id} on '
                         '{dttm} ... '.format(**locals()))

            TI = TaskInstance

            session = settings.Session()
            count = session.query(TI).filter(
                TI.dag_id == self.external_dag_id,
                TI.task_id == self.external_task_id,
                TI.state.in_(self.allowed_states),
                TI.execution_date == dttm,
            ).count()

            session.commit()
            session.close()

        elif internal_db:
            logging.info('Poking for '
                         '{self.external_dag_id}.'
                         '{self.external_task_id} on '
                         '{dttm} on {self.cluster_id} ... '.format(**locals()))
            hook = None

            if db_type == 'mysql':
                from airflow.hooks.mysql_hook import MySqlHook
                hook = MySqlHook(mysql_conn_id=self.cluster_id)

            elif db_type == 'postgres':
                from airflow.hooks.postgres_hook import PostgresHook
                hook = PostgresHook(postgres_conn_id=self.cluster_id)

            else:
                raise Exception("Please specify correct db type")

            records = hook.get_records(sql)

            if not records:
                count = 0
            else:
                if str(records[0][0]) in (
                        '0',
                        '',
                ):
                    count = 0
                else:
                    count = len(records)
                logging.info('task record found')
        else:
            host = conn.host
            user = conn.login
            password = conn.password
            dbname = conn.schema
            port = conn.port
            url = None
            for arg_name, arg_val in conn.extra_dejson.items():
                if arg_name in ['endpoint']:
                    url = arg_val
            if not url:
                raise KeyError(
                    'Lambda endpoint is not specified in Extra args')

            logging.info('Poking for '
                         '{self.external_dag_id}.'
                         '{self.external_task_id} on '
                         '{dttm} on {self.cluster_id} ... '.format(**locals()))

            import requests
            from requests.adapters import HTTPAdapter
            from requests.packages.urllib3.util.retry import Retry
            payload = "{\"ENDPOINT\":\"" + host + "\",\"PORT\":\"" + str(
                port
            ) + "\",\"DBUSER\":\"" + user + "\",\"DBPASSWORD\":\"" + password + "\",\"DATABASE\":\"" + dbname + "\", \"DB_TYPE\":\"" + db_type + "\",\"QUERY\":\"" + sql + "\"}"
            headers = {
                'Content-Type': "application/json",
                'cache-control': "no-cache",
            }
            session = requests.Session()
            retries = Retry(total=5,
                            backoff_factor=1,
                            status_forcelist=[502, 503, 504, 500])
            session.mount('https://', HTTPAdapter(max_retries=retries))
            response = session.post(url.replace(' ', ''),
                                    headers=headers,
                                    data=payload)
            if response.status_code == 200:
                count = int(response.text)
            else:
                raise Exception(response.content)
        logging.info(count)
        return count
Example #32
 def execute(self, context):
     self.log.info('Executing: %s', self.sql)
     hook = MySqlHook(mysql_conn_id=self.mysql_conn_id,
                      schema=self.database)
     return hook.get_records(self.sql, parameters=self.parameters)