def execute(self, context: Dict[str, Any]) -> None:
        """Run schema-tools' ``generate_db_objects`` for the matching tables of a dataset.

        Depending on ``ind_table`` and ``ind_extra_index`` this creates the tables
        and/or the identifier and 'many-to-many table' indexes, as specified in the
        dataset's JSON schema. Only tables whose full name
        ``<data_schema_name>_<table name>`` fully matches the ``data_table_name``
        regular expression are processed; when ``data_table_name`` is ``None``
        nothing is generated (the configuration is still logged).

        Args:
            context: Airflow task context; handed to ``_assign`` which may override
                operator attributes (including ``data_table_name``) with XCom values.
        """
        # Use the mixin class _assign to assign new values, if provided.
        # This needs to operate first, it can change the data_table_name.
        self._assign(context)

        # NB. data_table_name could have been changed because of xcom info, so it
        # is (re)compiled here. Any non-string value (e.g. an already compiled
        # pattern, or None) is kept as it is.
        if isinstance(self.data_table_name, str):
            self.data_table_name = re.compile(self.data_table_name)

        engine = _get_engine(self.db_conn)
        dataset_schema = schema_def_from_url(
            SCHEMA_URL, self.data_schema_name, prefetch_related=True
        )

        importer = BaseImporter(dataset_schema, engine, logger=self.log)
        self.log.info(
            "schema_name='%s', engine='%s', ind_table='%s', ind_extra_index='%s'.",
            self.data_schema_name,
            engine,
            self.ind_table,
            self.ind_extra_index,
        )

        if self.data_table_name is None:
            return

        for table in dataset_schema.tables:
            self.log.info("Considering table '%s'.", table.name)
            cur_table = f"{self.data_schema_name}_{table.name}"

            if not re.fullmatch(self.data_table_name, cur_table):
                self.log.info("Skipping table '%s' (reason: no match).", table.name)
                continue

            self.log.info("Generating PostgreSQL objects for table '%s'.", table.name)
            importer.generate_db_objects(
                table.id,
                ind_tables=self.ind_table,
                ind_extra_index=self.ind_extra_index,
            )
示例#2
0
    def execute(self, context=None):
        """Move each dataset table to ``to_pg_schema`` and prefix it with the dataset id.

        For every table in the dataset schema (ids in snake_case):
        drop ``<to_pg_schema>.<dataset_id>_<table_id>`` if it exists, move
        ``<from_pg_schema>.<table_id>`` to ``to_pg_schema`` and rename it there to
        ``<dataset_id>_<table_id>``. All statements are run in one ``pg_hook.run``.

        Args:
            context: Airflow task context (unused).
        """
        dataset = schema_def_from_url(SCHEMA_URL, self.dataset_name)
        pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)

        sqls = []
        dataset_id = to_snake_case(dataset.id)
        for table in dataset.tables:
            table_id = to_snake_case(table.id)
            # The rename is schema-qualified so it targets the table that was just
            # moved, not whatever ``table_id`` happens to resolve to via search_path.
            # (The RENAME TO target itself must stay unqualified in PostgreSQL.)
            sqls.append(f"""
                DROP TABLE IF EXISTS {self.to_pg_schema}.{dataset_id}_{table_id};
                ALTER TABLE {self.from_pg_schema}.{table_id} SET SCHEMA {self.to_pg_schema};
                ALTER TABLE {self.to_pg_schema}.{table_id}
                    RENAME TO {dataset_id}_{table_id}; """)
        pg_hook.run(sqls)
    def execute(self, context=None):
        """Translate table, column and index names based on the schema's provenance.

        Builds SQL ``ALTER`` statements for the tables returned by
        ``_get_existing_tables`` for ``pg_schema`` and runs them in one
        ``pg_hook.run`` call:

        * when ``rename_indexes`` is set, every occurrence of the existing table
          name inside an index name is replaced by the snake_cased
          ``<dataset id>_<table id>``;
        * for each field with a ``provenance`` whose lowercased value is one of the
          table's existing columns, that column is renamed to the snake_cased field
          name (with ``_id`` appended for relation fields);
        * tables with a table-level ``provenance`` are renamed to their snake_cased id.

        Args:
            context: Airflow task context (unused).
        """
        dataset = schema_def_from_url(SCHEMA_URL, self.dataset_name)
        # Debug output goes through the task logger instead of stdout.
        self.log.debug("Dataset: %s", dataset)
        pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
        sqls = []
        existing_tables_lookup = self._get_existing_tables(
            pg_hook, dataset.tables, pg_schema=self.pg_schema)
        snaked_tablenames = existing_tables_lookup.keys()
        existing_columns = self._get_existing_columns(pg_hook,
                                                      snaked_tablenames,
                                                      pg_schema=self.pg_schema)

        if self.rename_indexes:
            for table_name, index_names in self._get_existing_indexes(
                    pg_hook, snaked_tablenames,
                    pg_schema=self.pg_schema).items():
                if table_name not in existing_tables_lookup:
                    continue
                # The replacement is the same for every index of this table,
                # so compute it once per table.
                new_table_name = existing_tables_lookup[table_name].id
                new_prefix = to_snake_case(f"{dataset.id}_{new_table_name}")
                for index_name in index_names:
                    new_index_name = index_name.replace(table_name, new_prefix)
                    if index_name != new_index_name:
                        sqls.append(
                            f"""ALTER INDEX {self.pg_schema}.{index_name}
                                RENAME TO {new_index_name}""")

        for snaked_tablename, table in existing_tables_lookup.items():
            for field in table.fields:
                provenance = field.get("provenance")
                if provenance is not None:
                    snaked_field_name = to_snake_case(field.name)
                    if "relation" in field:
                        snaked_field_name += "_id"
                    # NOTE(review): the membership test lowercases the provenance
                    # but the rename below quotes it verbatim; for a provenance
                    # containing uppercase letters these may disagree - confirm
                    # against what _get_existing_columns returns.
                    if provenance.lower() in existing_columns[snaked_tablename]:
                        # quotes are applied on the provenance name in case the
                        # source uses a space in the name
                        sqls.append(
                            f"""ALTER TABLE {self.pg_schema}.{snaked_tablename}
                                RENAME COLUMN "{provenance}" TO {snaked_field_name}"""
                        )

            provenance = table.get("provenance")
            if provenance is not None:
                sqls.append(
                    f"""ALTER TABLE IF EXISTS {self.pg_schema}.{snaked_tablename}
                            RENAME TO {to_snake_case(table.id)}""")

        pg_hook.run(sqls)
示例#4
0
        def _create_dataset_info(dataset_id: str,
                                 table_id: str) -> DatasetInfo:
            """Build the ``DatasetInfo`` describing one table of a dataset.

            The dataset schema is fetched from ``SCHEMA_URL`` (with related schemas
            prefetched) only to look up the table's database name. The result
            carries ids instead of the dataset object: the
            ``methodtools.lru_cache`` decorator is not pickleable, and Airflow uses
            pickle for (de)serialization of what goes through XCom.

            Returns:
                ``DatasetInfo(SCHEMA_URL, dataset_id, table_id,
                "<dataset_id>_<table_id>", <db name of the table>)``.
            """
            schema = schema_def_from_url(
                SCHEMA_URL, dataset_id, prefetch_related=True
            )
            table_schema = schema.get_table_by_id(table_id)
            db_table_name = table_schema.db_name()

            # Fully qualified table id, provided for the caller's convenience.
            qualified_table_id = f"{dataset_id}_{table_id}"
            return DatasetInfo(
                SCHEMA_URL,
                dataset_id,
                table_id,
                qualified_table_id,
                db_table_name,
            )
示例#5
0
    def execute(self, context=None):
        """Drop the dataset's tables, plus any extra tables, from ``pg_schema``.

        Dropped (``IF EXISTS ... CASCADE``, names converted to snake_case), in
        this order: the operator's ``additional_table_names``, then for each
        table in the dataset schema its id and, when present, its table-level
        ``provenance`` name.

        Args:
            context: Airflow task context (unused).
        """
        dataset = schema_def_from_url(SCHEMA_URL, self.dataset_name)
        pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)

        # Copy, so the appends below never mutate ``self.additional_table_names``
        # (otherwise the list would keep growing if execute runs more than once
        # on the same operator instance).
        table_names = list(self.additional_table_names or [])
        for table in dataset.tables:
            table_names.append(table.id)
            provenance_tablename = table.get("provenance")
            if provenance_tablename is not None:
                table_names.append(provenance_tablename)

        sqls = [
            f"DROP TABLE IF EXISTS {self.pg_schema}.{to_snake_case(table_name)} CASCADE"
            for table_name in table_names
        ]

        pg_hook.run(sqls)
    def execute(self, context: Optional[Dict] = None) -> None:
        """Move dataset tables to another schema and prefix them with the dataset id.

        For every selected table (ids converted to snake_case): drop
        ``<to_pg_schema>.<dataset_id>_<table_id>`` if it exists, move
        ``<from_pg_schema>.<table_id>`` to ``to_pg_schema`` and rename it there to
        ``<dataset_id>_<table_id>`` (a.k.a. schema swapping). The ``ALTER``
        statements use ``IF EXISTS``, so a missing source table does not raise.

        When ``subset_tables`` is set, only tables whose snake_cased id is in it
        are processed; otherwise all tables of the dataset are.

        Args:
            context: Airflow task context; not used by this method.
        """
        dataset = schema_def_from_url(SCHEMA_URL, self.dataset_name)
        pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)

        sqls = []
        dataset_id = to_snake_case(dataset.id)
        tables = dataset.tables

        if self.subset_tables:
            wanted_ids = {to_snake_case(table) for table in self.subset_tables}
            tables = [
                table for table in tables
                if to_snake_case(table.id) in wanted_ids
            ]

        for table in tables:
            table_id = to_snake_case(table.id)
            # The rename is schema-qualified so it targets the table that was just
            # moved, not whatever ``table_id`` happens to resolve to via search_path.
            # NOTE(review): the DROP runs even when the source table is missing
            # (the ALTERs are then no-ops), leaving no table under the target
            # name - confirm this is intended.
            sqls.append(f"""
                DROP TABLE IF EXISTS {self.to_pg_schema}.{dataset_id}_{table_id};
                ALTER TABLE IF EXISTS {self.from_pg_schema}.{table_id}
                    SET SCHEMA {self.to_pg_schema};
                ALTER TABLE IF EXISTS {self.to_pg_schema}.{table_id}
                    RENAME TO {dataset_id}_{table_id}; """)
        pg_hook.run(sqls)
示例#7
0
    def execute(self,
                context: Optional[Dict[str, Any]] = None) -> None:  # NoQA C901
        """Rename tables, columns and indexes according to the schema's provenance.

        Only tables reported by ``_get_existing_tables`` for ``pg_schema`` are
        considered. The generated statements, run in one ``pg_hook.run`` call, are:

        * (only if ``rename_indexes``) ``ALTER INDEX ... RENAME TO`` for every
          index whose name changes when the existing table name in it is replaced
          by the snake_cased ``<dataset id>_<table id>``;
        * ``ALTER TABLE ... RENAME COLUMN`` for each field with a ``provenance``
          whose lowercased value is an existing column of the table, to the
          snake_cased field name (plus ``_id`` for relation fields);
        * ``ALTER TABLE IF EXISTS ... RENAME TO`` the snake_cased table id, for
          tables with a table-level ``provenance``.

        Args:
            context: Airflow task context; not used by this method.
        """
        dataset = schema_def_from_url(SCHEMA_URL, self.dataset_name)
        pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
        statements = []

        tables_by_name = self._get_existing_tables(
            pg_hook, dataset.tables, pg_schema=self.pg_schema
        )
        table_names = tables_by_name.keys()
        columns_by_table = self._get_existing_columns(
            pg_hook, table_names, pg_schema=self.pg_schema
        )

        if self.rename_indexes:
            indexes_by_table = self._get_existing_indexes(
                pg_hook, table_names, pg_schema=self.pg_schema
            )
            for table_name, index_names in indexes_by_table.items():
                if table_name not in tables_by_name:
                    continue
                for index_name in index_names:
                    replacement = to_snake_case(
                        f"{dataset.id}_{tables_by_name[table_name].id}"
                    )
                    renamed_index = index_name.replace(table_name, replacement)
                    if renamed_index == index_name:
                        continue
                    statements.append(
                        f"""ALTER INDEX {self.pg_schema}.{index_name}
                                RENAME TO {renamed_index}""")

        for table_name, table in tables_by_name.items():
            for field in table.fields:
                field_provenance = field.get("provenance")
                if field_provenance is None:
                    continue
                column_name = to_snake_case(field.name)
                if "relation" in field:
                    column_name += "_id"
                # NOTE(review): lowercase is used for the lookup but the original
                # spelling for the quoted rename - confirm these agree.
                if field_provenance.lower() not in columns_by_table[table_name]:
                    continue
                # The provenance name is quoted because the source may use
                # spaces in it.
                statements.append(
                    f"""ALTER TABLE {self.pg_schema}.{table_name}
                                RENAME COLUMN "{field_provenance}" TO {column_name}"""
                )

            table_provenance = table.get("provenance")
            if table_provenance is not None:
                statements.append(
                    f"""ALTER TABLE IF EXISTS {self.pg_schema}.{table_name}
                            RENAME TO {to_snake_case(table.id)}""")

        pg_hook.run(statements)
示例#8
0
    def execute(self, context):  # NoQA
        """Load GOB GraphQL data in batches into the ``<db_table_name>_new`` table.

        ``max_records``, ``cursor_pos`` and ``batch_size`` are taken from
        ``dag_run.conf`` (full DAG runs) or ``context["params"]`` (``airflow test``),
        falling back to the operator's attributes. Without a ``cursor_pos`` param,
        the Airflow Variable ``<db_table_name>.cursor_pos`` (default 0) is used,
        which a previously failed run may have set.

        Each batch is POSTed to ``endpoint`` (up to 3 attempts; after a failure the
        headers are fetched again with ``force_refresh=True``), streamed into a
        temporary NDJSON file and loaded with the schema-tools ``NDJSONImporter``.
        The loop stops on an empty response, when the importer returns no last
        record, or once ``max_records`` is reached. On success the cursor Variable
        is deleted.

        Raises:
            AirflowException: when all request attempts for a batch fail, or when
                loading a batch raises SQLAlchemyError, ProtocolError or
                UnicodeDecodeError. In both cases the current cursor is first
                saved in the ``<db_table_name>.cursor_pos`` Variable.
        """
        # When doing 'airflow test' there is a context['params']
        # For full dag runs, there is dag_run["conf"]
        dag_run = context["dag_run"]
        if dag_run is None:
            params = context["params"]
        else:
            params = dag_run.conf or {}
        self.log.debug("PARAMS: %s", params)
        max_records = params.get("max_records", self.max_records)
        # Falls back to the cursor saved by a previous failed run (see the
        # Variable.set calls below). NOTE(review): Variable.get is evaluated even
        # when params already provides cursor_pos.
        cursor_pos = params.get("cursor_pos", Variable.get(f"{self.db_table_name}.cursor_pos", 0))
        batch_size = params.get("batch_size", self.batch_size)
        with TemporaryDirectory() as temp_dir:
            # Each batch response is written to this file (overwritten per batch).
            tmp_file = Path(temp_dir) / "out.ndjson"
            http = HttpParamsHook(http_conn_id=self.http_conn_id, method="POST")

            self.log.info("Calling GOB graphql endpoint")

            # we know the schema, can be an input param (schema_def_from_url function)
            # We use the ndjson importer from schematools, give it a tmp tablename
            pg_hook = PostgresHook()
            schema_def = schema_def_from_url(SCHEMA_URL, self.dataset)
            importer = NDJSONImporter(schema_def, pg_hook.get_sqlalchemy_engine(), logger=self.log)

            # Create the "<db_table_name>_new" target table; no extra indexes.
            importer.generate_db_objects(
                table_name=self.schema,
                db_table_name=f"{self.db_table_name}_new",
                ind_tables=True,
                ind_extra_index=False,
            )
            # For GOB content, cursor value is exactly the same as
            # the record index. If this were not true, the cursor needed
            # to be obtained from the last content record
            records_loaded = 0

            with self.graphql_query_path.open() as gql_file:
                query = gql_file.read()

            # Sometimes the GOB-API fails with a 500 error, which surfaces here as
            # an AirflowException; each batch request is retried a few times.
            while True:

                force_refresh_token = False
                # Up to 3 attempts. The for-else branch below runs only when no
                # attempt succeeded (i.e. the loop never reached `break`).
                for i in range(3):
                    try:
                        request_start_time = time.time()
                        headers = self._fetch_headers(force_refresh=force_refresh_token)
                        response = http.run(
                            self.endpoint,
                            self._fetch_params(),
                            json.dumps(
                                dict(
                                    query=self.add_batch_params_to_query(
                                        query, cursor_pos, batch_size
                                    )
                                )
                            ),
                            headers=headers,
                            extra_options={"stream": True},
                        )
                    except AirflowException:
                        self.log.exception("Cannot reach %s", self.endpoint)
                        # Next attempt re-fetches the headers; note this also
                        # sleeps after the final failed attempt.
                        force_refresh_token = True
                        time.sleep(1)
                    else:
                        break
                else:
                    # Save cursor_pos in a variable
                    Variable.set(f"{self.db_table_name}.cursor_pos", cursor_pos)
                    raise AirflowException("All retries on GOB-API have failed.")

                # Counts the requested batch size, not the number of records
                # actually returned by this request.
                records_loaded += batch_size
                # No records returns one newline and a Content-Length header
                # If records are available, there is no Content-Length header
                if int(response.headers.get("Content-Length", "2")) < 2:
                    break
                # When content is encoded (gzip etc.) we need this:
                # response.raw.read = functools.partial(response.raw.read, decode_content=True)
                try:
                    with tmp_file.open("wb") as wf:
                        shutil.copyfileobj(response.raw, wf, self.copy_bufsize)

                    request_end_time = time.time()
                    self.log.info(
                        "GOB-API request took %s seconds, cursor: %s",
                        request_end_time - request_start_time,
                        cursor_pos,
                    )
                    last_record = importer.load_file(tmp_file)
                except (SQLAlchemyError, ProtocolError, UnicodeDecodeError) as e:
                    # Save last imported file for further inspection
                    shutil.copy(
                        tmp_file,
                        f"/tmp/{self.db_table_name}-{datetime.now().isoformat()}.ndjson",
                    )
                    Variable.set(f"{self.db_table_name}.cursor_pos", cursor_pos)
                    raise AirflowException("A database error has occurred.") from e

                self.log.info(
                    "Loading db took %s seconds",
                    time.time() - request_end_time,
                )
                if last_record is None or (
                    max_records is not None and records_loaded >= max_records
                ):
                    break
                # The next batch starts from the cursor of the last loaded record.
                cursor_pos = last_record["cursor"]

        # On successful completion, remove cursor_pos variable
        Variable.delete(f"{self.db_table_name}.cursor_pos")