コード例 #1
0
 def execute(self, context: Dict[str, Any]) -> None:
     """Load the configured CSV files into their PostgreSQL tables.

     For each entry ``ft`` in ``self.data``, ``ft.file`` (resolved against
     the base directory when it is relative) is read as a CSV file whose
     first row is the header.  A server-side prepared ``INSERT`` into
     ``ft.table`` is created for the header's columns, and every row is
     inserted through it with ``execute_batch`` in pages of
     ``self.page_size``.  Empty string values are passed as ``None``.
     The transaction is committed at the end unless autocommit is on, and
     any server notices are logged.

     Args:
         context: Airflow task context.  When ``self.base_dir_task_id`` is
             set, the base directory for relative file paths is pulled from
             that task's XCom via ``context["task_instance"]``; otherwise
             ``Path()`` (the current working directory) is used.

     Raises:
         ValueError: If a CSV file has no header row (e.g. it is empty).
     """
     base_dir = Path()
     if self.base_dir_task_id is not None:
         base_dir = Path(context["task_instance"].xcom_pull(
             task_ids=self.base_dir_task_id))
         self.log.info("Setting base_dir to '%s'.", base_dir)
     hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
     with closing(hook.get_conn()) as conn:
         if hook.supports_autocommit:
             self.log.debug("Setting autocommit to '%s'.", self.autocommit)
             hook.set_autocommit(conn, self.autocommit)
         with closing(conn.cursor()) as cursor:
             for ft in self.data:
                 file = ft.file if ft.file.is_absolute() else base_dir / ft.file
                 self.log.info("Processing file '%s' for table '%s'.", file,
                               ft.table)
                 with file.open(newline="") as csv_file:
                     reader = csv.DictReader(csv_file,
                                             dialect=csv.unix_dialect)
                     # Reading ``fieldnames`` consumes the header row and
                     # yields None for an empty file.  Checked explicitly
                     # rather than with ``assert``, which ``python -O`` strips.
                     fieldnames = reader.fieldnames
                     if fieldnames is None:
                         raise ValueError(
                             f"CSV file '{file}' needs to have a header.")
                     # One prepared INSERT per file; the positional parameters
                     # $1..$n line up with the header columns in order.
                     cursor.execute(
                         sql.SQL("""
                             PREPARE csv_ins_stmt AS INSERT INTO {table} ({columns})
                                                     VALUES ({values})
                             """).format(
                             table=sql.Identifier(ft.table),
                             columns=sql.SQL(", ").join(
                                 map(sql.Identifier, fieldnames)),
                             values=sql.SQL(", ").join(
                                 sql.SQL(f"${pos}")
                                 for pos in range(1, len(fieldnames) + 1)),
                         ))
                     execute_batch(
                         cursor,
                         sql.SQL("EXECUTE csv_ins_stmt ({params})").format(
                             params=sql.SQL(", ").join(
                                 sql.Placeholder(col) for col in fieldnames)),
                         # csv.DictReader yields empty fields as ""; pass them
                         # as None instead so they are inserted as NULL.
                         ({
                             k: v if v != "" else None
                             for k, v in row.items()
                         } for row in reader),
                         self.page_size,
                     )
                     # Release the per-file prepared statement so the next
                     # file can PREPARE a new one under the same name.  No
                     # try/finally: on error the connection is closed by
                     # ``closing`` (discarding the statement), and DEALLOCATE
                     # inside an aborted transaction would mask the real error.
                     cursor.execute("DEALLOCATE csv_ins_stmt")
         if not hook.get_autocommit(conn):
             self.log.debug("Committing transaction.")
             conn.commit()
         # Read notices from the local connection while it is still open,
         # instead of through the hook's ``conn`` attribute after closing.
         for output in conn.notices:
             self.log.info(output)
コード例 #2
0
    def execute(self, context: Dict[str, Any]) -> None:  # noqa: C901
        """Copy the source table and its junction tables to target tables.

        Table mappings processed, in order:

        * ``self.source_table_name`` -> ``self.target_table_name``, when both
          are set;
        * every table in the ``public`` schema whose name starts with
          ``f"{source_table_name}_"`` (the n-m junction tables), mapped to
          its own name with every occurrence of ``"_new"`` removed.

        For each mapping these statements run in order:

        1. ``CREATE TABLE IF NOT EXISTS <target> (LIKE <source> INCLUDING
           CONSTRAINTS INCLUDING INDEXES)``;
        2. ``TRUNCATE TABLE <target> CASCADE`` if ``self.truncate_target``;
        3. ``INSERT INTO <target> SELECT * FROM <source>`` if
           ``self.copy_data``;
        4. ``DROP TABLE IF EXISTS <source> CASCADE`` if ``self.drop_source``.

        The transaction is committed at the end unless autocommit is on, and
        any server notices are logged.

        Args:
            context: Airflow task context, handed to ``self._assign``, which
                presumably overrides operator attributes from it when
                provided (TODO confirm against the mixin).
        """
        hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
                            schema=self.database)

        # Use the mixin class _assign to assign new values, if provided.
        self._assign(context)

        with closing(hook.get_conn()) as conn:
            if hook.supports_autocommit:
                self.log.debug("Setting autocommit to '%s'.", self.autocommit)
                hook.set_autocommit(conn, self.autocommit)

            # Always bind the list.  It used to exist only when both table
            # names were set, so the copy loop below raised
            # UnboundLocalError otherwise.
            table_copies: List[TableMapping] = []
            if (self.source_table_name is not None
                    and self.target_table_name is not None):
                table_copies.append(
                    TableMapping(source=self.source_table_name,
                                 target=self.target_table_name))

            with closing(conn.cursor()) as cursor:
                junction_tables = ()
                if self.source_table_name is not None:
                    # Find the cross-tables for n-m relations; we assume their
                    # names start with f"{source_table_name}_".
                    # In LIKE, "_" and "%" are wildcards, so escape them (and
                    # the escape character itself) in the table name before
                    # appending the literal "_" separator and the trailing "%".
                    # The exclamation mark is used as the escape character
                    # because a backslash was not interpreted as an escape.
                    name_pattern = (self.source_table_name
                                    .replace("!", "!!")
                                    .replace("_", "!_")
                                    .replace("%", "!%"))
                    cursor.execute(
                        """
                            SELECT tablename FROM pg_tables
                            WHERE schemaname = 'public' AND tablename like %(table_name)s ESCAPE '!'
                        """,
                        dict(table_name=f"{name_pattern}!_%"),
                    )
                    # Rows are indexed by column name, which presumes a
                    # dict-like cursor (e.g. RealDictCursor); TODO confirm
                    # against the connection's cursor configuration.
                    junction_tables = tuple(
                        map(operator.itemgetter("tablename"),
                            cursor.fetchall()))
                if junction_tables:
                    self.log.info("Found the following junction tables: '%s'.",
                                  ", ".join(junction_tables))
                else:
                    self.log.info("Did not found any junction tables.")

                # Target name = junction table name with every "_new" removed.
                junction_table_copies: List[TableMapping] = [
                    TableMapping(name, name.replace("_new", ""))
                    for name in junction_tables
                ]

                statements: List[Statement] = [
                    Statement(
                        sql="""
                    CREATE TABLE IF NOT EXISTS {target_table_name}
                    (
                        LIKE {source_table_name} INCLUDING CONSTRAINTS INCLUDING INDEXES
                    )
                    """,
                        log_msg="Creating new table '{target_table_name}' "
                        "using table '{source_table_name}' as a template.",
                    )
                ]
                if self.truncate_target:
                    statements.append(
                        Statement(
                            sql="TRUNCATE TABLE {target_table_name} CASCADE",
                            log_msg="Truncating table '{target_table_name}'.",
                        ))
                if self.copy_data:
                    statements.append(
                        Statement(
                            sql="""
                        INSERT INTO {target_table_name}
                        SELECT *
                        FROM {source_table_name}
                        """,
                            log_msg=
                            "Copying all data from table '{source_table_name}' "
                            "to table '{target_table_name}'.",
                        ))
                if self.drop_source:
                    statements.append(
                        Statement(
                            sql=
                            "DROP TABLE IF EXISTS {source_table_name} CASCADE",
                            log_msg="Dropping table '{source_table_name}'.",
                        ))

                for table_mapping in itertools.chain(table_copies,
                                                     junction_table_copies):
                    for stmt in statements:
                        self.log.info(
                            stmt.log_msg.format(
                                source_table_name=table_mapping.source,
                                target_table_name=table_mapping.target,
                            ))
                        # Table names are injected as quoted identifiers,
                        # never as raw SQL text.
                        cursor.execute(
                            sql.SQL(stmt.sql).format(
                                source_table_name=sql.Identifier(
                                    table_mapping.source),
                                target_table_name=sql.Identifier(
                                    table_mapping.target),
                            ))
            if not hook.get_autocommit(conn):
                self.log.debug("Committing transaction.")
                conn.commit()
            # Read notices from the local connection while it is still open,
            # instead of through the hook's ``conn`` attribute after closing.
            for output in conn.notices:
                self.log.info(output)