Example no. 1
from airflow.hooks.postgres_hook import PostgresHook  # in Airflow 2+: airflow.providers.postgres.hooks.postgres


def run_and_push(**kwargs):
    # Run the templated SQL script and push the first result row as an XCom (key, value) pair.
    hook = PostgresHook('postgres_default')
    conn = hook.get_conn()
    hook.set_autocommit(conn, True)
    cur = conn.cursor()
    cur.execute(kwargs['templates_dict']['script'])
    result = cur.fetchall()
    row = result[0]
    kwargs['ti'].xcom_push(key=row[0], value=row[1])
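
For context, a minimal sketch of how a callable like this could be wired into a DAG with a PythonOperator. The DAG id, schedule, and SQL below are placeholders, and provide_context=True assumes Airflow 1.x behaviour (Airflow 2 passes the context automatically).

from datetime import datetime

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

with DAG(dag_id="metrics_demo", start_date=datetime(2021, 1, 1),
         schedule_interval="@daily") as dag:
    push_metric = PythonOperator(
        task_id="run_and_push",
        python_callable=run_and_push,
        provide_context=True,  # not needed in Airflow 2
        # The query must return at least one (key, value) row for the XCom push.
        templates_dict={"script": "SELECT 'row_count', count(*) FROM my_table"},
    )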
Example no. 2
 def execute(self, context: Dict[str, Any]) -> None:
     base_dir = Path()
     if self.base_dir_task_id is not None:
         base_dir = Path(context["task_instance"].xcom_pull(
             task_ids=self.base_dir_task_id))
         self.log.info("Setting base_dir to '%s'.", base_dir)
     hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
     with closing(hook.get_conn()) as conn:
         if hook.supports_autocommit:
             self.log.debug("Setting autocommit to '%s'.", self.autocommit)
             hook.set_autocommit(conn, self.autocommit)
         with closing(conn.cursor()) as cursor:
             for ft in self.data:
                 file = ft.file if ft.file.is_absolute(
                 ) else base_dir / ft.file
                 self.log.info("Processing file '%s' for table '%s'.", file,
                               ft.table)
                 with file.open(newline="") as csv_file:
                     reader = csv.DictReader(csv_file,
                                             dialect=csv.unix_dialect)
                     assert (reader.fieldnames is not None
                             ), f"CSV file '{file}' needs to have a header."
                     cursor.execute(
                         sql.SQL("""
                             PREPARE csv_ins_stmt AS INSERT INTO {table} ({columns})
                                                     VALUES ({values})
                             """).format(
                             table=sql.Identifier(ft.table),
                             columns=sql.SQL(", ").join(
                                 map(sql.Identifier, reader.fieldnames)),
                             values=sql.SQL(", ").join(
                                 sql.SQL(f"${pos}") for pos in range(
                                     1,
                                     len(reader.fieldnames) + 1)),
                         ))
                     execute_batch(
                         cursor,
                         sql.SQL("EXECUTE csv_ins_stmt ({params})").format(
                             params=sql.SQL(", ").join(
                                 sql.Placeholder(col)
                                 for col in reader.fieldnames)),
                         # Nasty hack to treat empty string values as None
                         ({
                             k: v if v != "" else None
                             for k, v in row.items()
                         } for row in reader),
                         self.page_size,
                     )
                     cursor.execute("DEALLOCATE csv_ins_stmt")
         if not hook.get_autocommit(conn):
             self.log.debug("Committing transaction.")
             conn.commit()
     for output in hook.conn.notices:
         self.log.info(output)
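
The execute() above assumes a small mapping type behind self.data (each item exposes .file and .table) plus the usual psycopg2 helpers. A minimal sketch of those supporting pieces; the name FileTable is inferred from the attribute accesses, not taken from the original code.

import csv
from contextlib import closing
from pathlib import Path
from typing import NamedTuple

from psycopg2 import sql
from psycopg2.extras import execute_batch


class FileTable(NamedTuple):
    """Maps a CSV file on disk to the Postgres table it should be loaded into."""
    file: Path  # path to the CSV, possibly relative to base_dir
    table: str  # name of the target table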
Example no. 3
    def execute(self, context):
        '''
        :param context: DAG/task metadata provided by Airflow.
        '''
        env = Variable.get('environment')

        # Get the path to the CSV produced by the previous task
        # in our DAG, using Airflow's message passing.
        csv_path = context['task_instance'].xcom_pull(
            dag_id=context['dag'].dag_id,
            task_ids=self.process_csv_from,
            key='csv_path')

        # Our connection to the database (explained in an earlier operator)
        pg_hook = PostgresHook('pg_onethree_demo', supports_autocommit=True)
        conn = pg_hook.get_conn()
        pg_hook.set_autocommit(conn, True)
        cursor = conn.cursor()

        # This contains the DAG name, task name, and date for uniqueness
        temp_table = context['task_instance_key_str']

        # Create temporary table. For simplicity I'm hardcoding the structure
        # here but a production operator would determine the structure dynamically
        # (since the operator would support different CSVs and target tables).
        #
        # Also, a more sophisticated database structure would have a table
        # dedicated to targets, and a separate table simply mapping between the
        # drug and target tables. Then the target table could store metadata on
        # the targets without duplication.
        create_temp_table = '''
            CREATE TEMP TABLE {} (drugbank_drug INTEGER, 
            drugbank_target CHARACTER VARYING)
        '''.format(temp_table)
        log.info(create_temp_table)
        cursor.execute(create_temp_table)

        # Load our data into the temporary table. This would let us perform
        # aggregate validations in the future (i.e. validation SQL queries
        # performed on the temporary table before the data is loaded into
        # the target table), with a new validation Airflow operator.
        #
        # The Postgres COPY command requires a local file.
        temp_csv = task_temp_file(context)
        append_to_local(env, csv_path, temp_csv)
        copy_sql = "COPY {} FROM STDIN DELIMITER ',' ".format(temp_table)
        log.info(copy_sql)
        cursor.copy_expert(sql=copy_sql, file=open(temp_csv, 'r'))
        os.remove(temp_csv)

        # First step of the merge (an upsert, i.e. insert-or-update).
        #
        # In this case we don't actually perform an update, but in a
        # production operator we could have an operator parameter
        # specifying columns to match to determine when to perform
        # an update
        # (i.e. ON CONFLICT (matching columns) [UPDATE remaining columns] )
        #
        upsert = '''
            INSERT INTO {} (drugbank_drug, drugbank_target)
            SELECT * FROM {}
            ON CONFLICT (drugbank_drug, drugbank_target) DO NOTHING
            '''.format(self.target_table, temp_table)
        log.info(upsert)
        cursor.execute(upsert)

        # Second step of the merge (delete targets no longer tied
        # to the specified drug)
        delete = '''
            DELETE FROM {} WHERE (drugbank_drug, drugbank_target) 
            NOT IN (SELECT drugbank_drug, drugbank_target FROM {})
            '''.format(self.target_table, temp_table)
        cursor.execute(delete)
        conn.commit()
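
The comment above mentions parametrizing the conflict handling; a rough sketch of how an INSERT ... ON CONFLICT ... DO UPDATE could be composed safely with psycopg2.sql. The function name and the match_columns/update_columns parameters are hypothetical, not part of the original operator.

from psycopg2 import sql


def build_upsert(target_table, temp_table, match_columns, update_columns):
    """Compose an INSERT ... ON CONFLICT (...) DO UPDATE statement with quoted identifiers."""
    return sql.SQL(
        "INSERT INTO {target} SELECT * FROM {source} "
        "ON CONFLICT ({match}) DO UPDATE SET {updates}"
    ).format(
        target=sql.Identifier(target_table),
        source=sql.Identifier(temp_table),
        match=sql.SQL(", ").join(map(sql.Identifier, match_columns)),
        updates=sql.SQL(", ").join(
            sql.SQL("{col} = EXCLUDED.{col}").format(col=sql.Identifier(col))
            for col in update_columns
        ),
    )

# e.g. cursor.execute(build_upsert(self.target_table, temp_table,
#                                  ["drugbank_drug"], ["drugbank_target"]))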
Example no. 4
    def execute(self, context: Dict[str, Any]) -> None:  # noqa: C901
        hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
                            schema=self.database)

        # Use the mixin class _assign to assign new values, if provided.
        self._assign(context)

        with closing(hook.get_conn()) as conn:
            if hook.supports_autocommit:
                self.log.debug("Setting autocommit to '%s'.", self.autocommit)
                hook.set_autocommit(conn, self.autocommit)

            # Start a list to hold copy information; it stays empty when no
            # explicit source/target pair was provided.
            table_copies: List[TableMapping] = []
            if self.source_table_name is not None and self.target_table_name is not None:
                table_copies.append(
                    TableMapping(source=self.source_table_name,
                                 target=self.target_table_name))

            # Find the cross-tables for n-m relations; we assume their names
            # start with f"{source_table_name}_".

            with closing(conn.cursor()) as cursor:
                # The underscore must be escaped because of its special meaning in a
                # LIKE pattern; the exclamation mark is used as the escape character
                # because a backslash was not interpreted as an escape.
                cursor.execute(
                    """
                        SELECT tablename FROM pg_tables
                        WHERE schemaname = 'public' AND tablename like %(table_name)s ESCAPE '!'
                    """,
                    dict(table_name=f"{self.source_table_name}!_%"),
                )

                junction_tables = tuple(
                    map(operator.itemgetter("tablename"), cursor.fetchall()))
                if junction_tables:
                    self.log.info(
                        f"Found the following junction tables: '{', '.join(junction_tables)}'."
                    )
                else:
                    self.log.info("Did not found any junction tables.")

                junction_table_copies: List[TableMapping] = []
                for source_table_name in junction_tables:
                    target_table_name = source_table_name.replace("_new", "")
                    junction_table_copies.append(
                        TableMapping(source_table_name, target_table_name))

                statements: List[Statement] = [
                    Statement(
                        sql="""
                    CREATE TABLE IF NOT EXISTS {target_table_name}
                    (
                        LIKE {source_table_name} INCLUDING CONSTRAINTS INCLUDING INDEXES
                    )
                    """,
                        log_msg="Creating new table '{target_table_name}' "
                        "using table '{source_table_name}' as a template.",
                    )
                ]
                if self.truncate_target:
                    statements.append(
                        Statement(
                            sql="TRUNCATE TABLE {target_table_name} CASCADE",
                            log_msg="Truncating table '{target_table_name}'.",
                        ))
                if self.copy_data:
                    statements.append(
                        Statement(
                            sql="""
                        INSERT INTO {target_table_name}
                        SELECT *
                        FROM {source_table_name}
                        """,
                            log_msg=
                            "Copying all data from table '{source_table_name}' "
                            "to table '{target_table_name}'.",
                        ))
                if self.drop_source:
                    statements.append(
                        Statement(
                            sql=
                            "DROP TABLE IF EXISTS {source_table_name} CASCADE",
                            log_msg="Dropping table '{source_table_name}'.",
                        ))
                for table_mapping in itertools.chain(table_copies,
                                                     junction_table_copies):
                    for stmt in statements:
                        self.log.info(
                            stmt.log_msg.format(
                                source_table_name=table_mapping.source,
                                target_table_name=table_mapping.target,
                            ))
                        cursor.execute(
                            sql.SQL(stmt.sql).format(
                                source_table_name=sql.Identifier(
                                    table_mapping.source),
                                target_table_name=sql.Identifier(
                                    table_mapping.target),
                            ))
            if not hook.get_autocommit(conn):
                self.log.debug("Committing transaction.")
                conn.commit()
        for output in hook.conn.notices:
            self.log.info(output)
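
This operator also relies on two small helper types that are not shown in the example. A plausible sketch, inferred from the TableMapping(source=..., target=...) and Statement(sql=..., log_msg=...) call sites above; the original definitions may differ.

from typing import NamedTuple


class TableMapping(NamedTuple):
    """Pairs a source table with the target table it is copied or renamed to."""
    source: str
    target: str


class Statement(NamedTuple):
    """An SQL template plus the log message to emit when it is executed."""
    sql: str
    log_msg: str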
Example no. 5
class RedshiftToS3(BaseOperator):
    """ UNLOAD from Redshift to S3 """

    template = """
    UNLOAD
    ('{{ sql }}')
    TO
    '{{ s3_to }}'
    credentials AS 'aws_access_key_id={{ aws_key }};aws_secret_access_key={{ aws_secret }}'
    {% for cmd in extra_commands %} {{ cmd }}{% endfor %}
    """

    template_fields = ('s3_to', 'sql', 'load_sql')
    template_ext = ('.sql',)
    ui_color = '#FBDA34'

    @apply_defaults
    def __init__(
            self,
            load_sql,
            s3_to,
            postgres_conn_id='postgres_default',
            s3_conn_id='s3_default',
            extra_commands=[],
            autocommit=False,
            parameters=None,
            dummy=False,
            *args, **kwargs):
        super(RedshiftToS3, self).__init__(*args, **kwargs)
        bucket = translate_bucket_name(re.findall(r"s3://([a-zA-Z\-]*)/", s3_to)[0])
        key_base = re.findall(r"s3://[a-zA-Z\-]*/(.*)", s3_to)[0]
        self.s3_to = "s3://" + bucket + "/" + key_base
        # Escape possible single quotes
        self.load_sql = load_sql.replace("'", "\\'")
        self.postgres_conn_id = postgres_conn_id
        self.s3_conn_id = s3_conn_id
        self.autocommit = autocommit
        self.parameters = parameters
        self.extra_commands = extra_commands
        self.dummy = dummy
        self.sql = Template(self.template).render(sql=self.load_sql, s3_to=self.s3_to,
                                                  extra_commands=self.extra_commands,
                                                  aws_key="#" * 10, aws_secret="#" * 10)

    def execute(self, context):
        self.ps_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
        login, password = GenericHook(self.s3_conn_id).get_credentials()
        sql = Template(self.template).render(sql=self.load_sql, s3_to=self.s3_to,
                                             extra_commands=self.extra_commands,
                                             aws_key=login, aws_secret=password)
        # Do not log the sql as it contains the actual secrets, use self.sql
        # instead
        logging.info("Executing: " + self.sql)
        if not self.dummy:
            self._run(sql, autocommit=self.autocommit, parameters=self.parameters)
        else:
            logging.info("Dummy flag is on, skipping actual execution")

    def _run(self, sql, autocommit=False, parameters=None):
        """ Copy of the run in PostgresHook without logging """
        conn = self.ps_hook.get_conn()
        if isinstance(sql, str):
            sql = [sql]

        if self.ps_hook.supports_autocommit:
            self.ps_hook.set_autocommit(conn, autocommit)

        cur = conn.cursor()
        for s in sql:
            if parameters is not None:
                cur.execute(s, parameters)
            else:
                cur.execute(s)
        cur.close()
        conn.commit()
        conn.close()
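
A minimal, hypothetical usage sketch for this operator inside an existing DAG; the connection ids, bucket, and query are placeholders.

unload_orders = RedshiftToS3(
    task_id="unload_orders",
    load_sql="SELECT * FROM orders WHERE status = 'shipped'",
    s3_to="s3://my-data-bucket/exports/orders/",
    postgres_conn_id="redshift_default",
    s3_conn_id="s3_default",
    extra_commands=["DELIMITER ','", "GZIP", "ALLOWOVERWRITE"],
    dag=dag,  # assumes a DAG object named `dag` is already defined
)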
Example no. 6
class S3ToRedshift(BaseOperator):
    """ COPY data from S3 to Redshift """

    template = """
    COPY {{ table }} FROM '{{ s3_from }}' WITH
    credentials AS 'aws_access_key_id={{ aws_key }};aws_secret_access_key={{ aws_secret }}'
    {% for cmd in extra_commands %} {{ cmd }}{% endfor %}
    """

    template_fields = ('table', 's3_from', 'sql')
    template_ext = ('.sql',)
    ui_color = '#E3C62C'

    @apply_defaults
    def __init__(
            self,
            table,
            s3_from,
            postgres_conn_id='postgres_default',
            s3_conn_id='s3_default',
            extra_commands=[],
            dummy=False,
            parameters=None,
            *args, **kwargs):
        super(S3ToRedshift, self).__init__(*args, **kwargs)
        self.table = table
        bucket = translate_bucket_name(re.findall(r"s3://([a-zA-Z\-]*)/", s3_from)[0])
        key_base = re.findall(r"s3://[a-zA-Z\-]*/(.*)", s3_from)[0]
        self.s3_from = "s3://" + bucket + "/" + key_base
        self.postgres_conn_id = postgres_conn_id
        self.s3_conn_id = s3_conn_id
        self.parameters = parameters
        self.extra_commands = extra_commands
        self.dummy = dummy
        self.sql = Template(self.template).render(table=self.table, s3_from=self.s3_from,
                                                  extra_commands=self.extra_commands,
                                                  aws_key="#" * 10, aws_secret="#" * 10)

    def execute(self, context):
        self.ps_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
        login, password = GenericHook(self.s3_conn_id).get_credentials()
        sql = Template(self.template).render(table=self.table, s3_from=self.s3_from,
                                             extra_commands=self.extra_commands,
                                             aws_key=login, aws_secret=password)
        # Do not log the sql as it contains the actual secrets, use self.sql
        # instead
        logging.info("Executing: " + self.sql)
        if not self.dummy:
            self._run(sql, parameters=self.parameters)
        else:
            logging.info("Dummy flag is on, skipping actual execution")

    def _run(self, sql, autocommit=True, parameters=None):
        """ Copy of the run in PostgresHook without logging """
        conn = self.ps_hook.get_conn()
        if isinstance(sql, str):
            sql = [sql]

        if self.ps_hook.supports_autocommit:
            self.ps_hook.set_autocommit(conn, autocommit)

        cur = conn.cursor()
        for s in sql:
            if parameters is not None:
                cur.execute(s, parameters)
            else:
                cur.execute(s)
        cur.close()
        conn.commit()
        conn.close()
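
And a matching hypothetical sketch for loading those files back into Redshift; again, the connection ids, bucket, and table name are placeholders.

load_orders = S3ToRedshift(
    task_id="load_orders",
    table="orders_staging",
    s3_from="s3://my-data-bucket/exports/orders/",
    postgres_conn_id="redshift_default",
    s3_conn_id="s3_default",
    extra_commands=["DELIMITER ','", "GZIP", "EMPTYASNULL"],
    dag=dag,  # assumes a DAG object named `dag` is already defined
)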