Example #1
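A unit test for BigQueryDatasetsProvider.get_facets: job properties are loaded from a JSON fixture and served through a mocked client, then the run facet, input dataset schema, and output dataset are asserted.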
def test_bq_job_information():
    job_details = read_file_json('tests/bigquery/job_details.json')
    client = MagicMock()
    client.get_job.return_value._properties = job_details

    client.get_table.return_value = TableMock()

    statistics = BigQueryDatasetsProvider(client=client).get_facets("job_id")

    assert statistics.run_facets == {
        'bigQuery_job':
        BigQueryJobRunFacet(cached=False,
                            billedBytes=111149056,
                            properties=json.dumps(job_details))
    }
    assert statistics.inputs == [
        Dataset(source=Source(scheme='bigquery', connection_url='bigquery'),
                name='bigquery-public-data.usa_names.usa_1910_2013',
                fields=[
                    Field('state', 'STRING', [], '2-digit state code'),
                    Field('gender', 'STRING', [], 'Sex (M=male or F=female)'),
                    Field('year', 'INTEGER', [], '4-digit year of birth'),
                    Field('name', 'STRING', [],
                          'Given name of a person at birth'),
                    Field('number', 'INTEGER', [],
                          'Number of occurrences of the name')
                ])
    ]
    assert statistics.output == Dataset(
        source=Source(scheme='bigquery', connection_url='bigquery'),
        name='bq-airflow-openlineage.new_dataset.output_table',
    )
Example #2
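An extract_on_complete implementation that builds TaskMetadata from an explicitly constructed DbTableSchema input and a name-only output table.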
    def extract_on_complete(self, task_instance) -> TaskMetadata:
        inputs = [
            Dataset.from_table_schema(self.source, DbTableSchema(
                schema_name='schema',
                table_name=DbTableName('extract_on_complete_input1'),
                columns=[
                    DbColumn(
                        name='field1',
                        type='text',
                        description='',
                        ordinal_position=1
                    ),
                    DbColumn(
                        name='field2',
                        type='text',
                        description='',
                        ordinal_position=2
                    )
                ]
            )).to_openlineage_dataset()
        ]
        outputs = [
            Dataset.from_table(self.source, "extract_on_complete_output1").to_openlineage_dataset()
        ]
        return TaskMetadata(
            name=get_job_name(task=self.operator),
            inputs=inputs,
            outputs=outputs,
        )
Example #3
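A minimal extract implementation: one input and one output dataset, each created with Dataset.from_table and converted to an OpenLineage dataset.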
    def extract(self) -> TaskMetadata:
        inputs = [
            Dataset.from_table(self.source, "extract_input1").to_openlineage_dataset()
        ]
        outputs = [
            Dataset.from_table(self.source, "extract_output1").to_openlineage_dataset()
        ]
        return TaskMetadata(
            name=get_job_name(task=self.operator),
            inputs=inputs,
            outputs=outputs
        )
Example #4
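A helper that maps a PandasDataset whose batch_kwargs carry a file path to a single OpenLineage dataset, deriving fields from the pandas column dtypes.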
    def _fetch_datasets_from_pandas_source(
        self, data_asset: PandasDataset,
        validation_result_suite: ExpectationSuiteValidationResult
    ) -> List[OLDataset]:  # noqa
        """
        Generate a list of OpenLineage Datasets from a PandasDataset.
        :param data_asset: the pandas dataset under validation
        :param validation_result_suite: validation results attached as input facets
        :return: a single-element list of datasets, or an empty list if the
                 batch has no file path
        """
        if "path" in data_asset.batch_kwargs:
            path = data_asset.batch_kwargs.get("path")
            if path.startswith("/"):
                path = "file://{}".format(path)
            parsed_url = urlparse(path)
            columns = [
                Field(name=col,
                      type=str(data_asset[col].dtype)
                      if data_asset[col].dtype is not None else 'UNKNOWN')
                for col in data_asset.columns
            ]
            return [
                Dataset(source=self._source(parsed_url._replace(path='')),
                        name=parsed_url.path,
                        fields=columns,
                        input_facets=self.results_facet(
                            validation_result_suite)).to_openlineage_dataset()
            ]
        return []
Example #5
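A test that a Connection parsed from a URI produces inputs whose Source carries the expected scheme and authority.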
def test_extract_authority_uri(get_connection, mock_get_table_schemas):
    mock_get_table_schemas.side_effect = \
        [[DB_TABLE_SCHEMA], NO_DB_TABLE_SCHEMA]

    conn = Connection()
    conn.parse_from_uri(uri=CONN_URI)
    get_connection.return_value = conn

    expected_inputs = [
        Dataset(
            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.{DB_TABLE_NAME.name}",
            source=Source(
                scheme='postgres',
                authority='localhost:5432',
                connection_url=CONN_URI_WITHOUT_USERPASS
            ),
            fields=[]
        ).to_openlineage_dataset()]

    task_metadata = PostgresExtractor(TASK).extract()

    assert task_metadata.name == f"{DAG_ID}.{TASK_ID}"
    assert task_metadata.inputs == expected_inputs
    assert task_metadata.outputs == []
Example #6
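The same extraction test, but with the Connection built from explicit keyword arguments rather than parsed from a URI.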
def test_extract(get_connection, mock_get_table_schemas):
    mock_get_table_schemas.side_effect = \
        [[DB_TABLE_SCHEMA], NO_DB_TABLE_SCHEMA]

    conn = Connection(
        conn_id=CONN_ID,
        conn_type='postgres',
        host='localhost',
        port=5432,
        schema='food_delivery'
    )

    get_connection.return_value = conn

    expected_inputs = [
        Dataset(
            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.{DB_TABLE_NAME.name}",
            source=Source(
                scheme='postgres',
                authority='localhost:5432',
                connection_url=CONN_URI_WITHOUT_USERPASS
            ),
            fields=[]
        ).to_openlineage_dataset()]

    # Set the environment variable for the connection
    os.environ[f"AIRFLOW_CONN_{CONN_ID.upper()}"] = CONN_URI

    task_metadata = PostgresExtractor(TASK).extract()

    assert task_metadata.name == f"{DAG_ID}.{TASK_ID}"
    assert task_metadata.inputs == expected_inputs
    assert task_metadata.outputs == []
Example #7
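The Snowflake variant: the Source authority comes from the connection parameters returned by the task's mocked hook.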
def test_extract(get_connection, mock_get_table_schemas):
    mock_get_table_schemas.side_effect = \
        [[DB_TABLE_SCHEMA], NO_DB_TABLE_SCHEMA]

    conn = Connection()
    conn.parse_from_uri(uri=CONN_URI)
    get_connection.return_value = conn

    TASK.get_hook = mock.MagicMock()
    TASK.get_hook.return_value._get_conn_params.return_value = {
        'account': 'test_account',
        'database': DB_NAME
    }

    expected_inputs = [
        Dataset(name=f"{DB_NAME}.{DB_SCHEMA_NAME}.{DB_TABLE_NAME.name}",
                source=Source(scheme='snowflake',
                              authority='test_account',
                              connection_url=CONN_URI),
                fields=[]).to_openlineage_dataset()
    ]

    # Set the environment variable for the connection
    os.environ[f"AIRFLOW_CONN_{CONN_ID.upper()}"] = CONN_URI

    task_metadata = SnowflakeExtractor(TASK).extract()

    assert task_metadata.name == f"{DAG_ID}.{TASK_ID}"
    assert task_metadata.inputs == expected_inputs
    assert task_metadata.outputs == []
Example #8
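Resolving the query's destinationTable from BigQuery job properties into a Dataset, with a name-only fallback when the table schema cannot be fetched.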
    def _get_output_from_bq(self, properties) -> Optional[Dataset]:
        bq_output_table = get_from_nullable_chain(
            properties, ['configuration', 'query', 'destinationTable'])
        if not bq_output_table:
            return None

        output_table_name = self._bq_table_name(bq_output_table)
        source = self._source()

        table_schema = self._get_table_safely(output_table_name)
        if table_schema:
            return Dataset.from_table_schema(
                source=source,
                table_schema=table_schema,
            )
        else:
            self.logger.warning("Could not resolve output table from bq")
            return Dataset.from_table(source, output_table_name)
Example #9
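A generic SQL extract: parse the operator's SQL for input and output tables, derive a Source from the current connection, and map the tables to datasets plus a sql job facet.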
    def extract(self) -> TaskMetadata:
        # (1) Parse sql statement to obtain input / output tables.
        sql_meta: SqlMeta = SqlParser.parse(self.operator.sql,
                                            self.default_schema)

        # (2) Get database connection
        self.conn = get_connection(self._conn_id())

        # (3) Default all inputs / outputs to current connection.
        # NOTE: We'll want to look into adding support for the `database`
        # property that is used to override the one defined in the connection.
        source = Source(scheme=self._get_scheme(),
                        authority=self._get_authority(),
                        connection_url=self._get_connection_uri())

        database = self.operator.database
        if not database:
            database = self._get_database()

        # (4) Map input / output tables to dataset objects with source set
        # as the current connection. We need to also fetch the schema for the
        # input tables to format the dataset name as:
        # {schema_name}.{table_name}
        inputs = [
            Dataset.from_table(source=source,
                               table_name=in_table_schema.table_name.name,
                               schema_name=in_table_schema.schema_name,
                               database_name=database)
            for in_table_schema in self._get_table_schemas(sql_meta.in_tables)
        ]
        outputs = [
            Dataset.from_table_schema(source=source,
                                      table_schema=out_table_schema,
                                      database_name=database) for
            out_table_schema in self._get_table_schemas(sql_meta.out_tables)
        ]

        return TaskMetadata(
            name=f"{self.operator.dag_id}.{self.operator.task_id}",
            inputs=[ds.to_openlineage_dataset() for ds in inputs],
            outputs=[ds.to_openlineage_dataset() for ds in outputs],
            job_facets={'sql': SqlJobFacet(self.operator.sql)})
Example #10
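Collecting referencedTables from BigQuery job statistics as input datasets, degrading to name-only datasets if schema retrieval fails.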
    def _get_input_from_bq(self, properties):
        bq_input_tables = get_from_nullable_chain(
            properties, ['statistics', 'query', 'referencedTables'])
        if not bq_input_tables:
            return []

        input_table_names = [
            self._bq_table_name(bq_t) for bq_t in bq_input_tables
        ]
        sources = [self._source() for _ in bq_input_tables]
        try:
            return [
                Dataset.from_table_schema(source=source,
                                          table_schema=table_schema)
                for table_schema, source in zip(
                    self._get_table_schemas(input_table_names), sources)
            ]
        except Exception as e:
            self.logger.warning(f'Could not extract schema from bigquery. {e}')
            return [
                Dataset.from_table(source, table)
                for table, source in zip(input_table_names, sources)
            ]
Example #11
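A test that Dataset.from_table_schema converts to an OpenLineage dataset carrying dataSource and schema facets.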
def test_dataset_to_openlineage(table_schema):
    source_name = 'dummy://localhost:1234'
    source, columns, schema = table_schema

    dataset_schema = Dataset.from_table_schema(source, schema)
    assert dataset_schema.to_openlineage_dataset() == OpenLineageDataset(
        namespace=source_name,
        name='public.discounts',
        facets={
            'dataSource':
            DataSourceDatasetFacet(name=source_name, uri=source_name),
            'schema':
            SchemaDatasetFacet(fields=[
                SchemaField(name='id', type='int4'),
                SchemaField(name='amount_off', type='int4'),
                SchemaField(name='customer_email', type='varchar'),
                SchemaField(name='starts_on', type='timestamp'),
                SchemaField(name='ends_on', type='timestamp')
            ])
        })
Example #12
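Building an OpenLineage dataset from a SqlAlchemyDataset by reflecting the table through SQLAlchemy; for BigQuery the schema is rewritten as host.database.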
    def _get_sql_table(
        self, data_asset: SqlAlchemyDataset, meta: MetaData, schema: str,
        table_name: str,
        validation_result_suite: ExpectationSuiteValidationResult
    ) -> Optional[OLDataset]:  # noqa
        """
        Construct a Dataset from the connection url and the columns returned from the
        SqlAlchemyDataset
        :param data_asset:
        :return:
        """
        engine = data_asset.engine
        if isinstance(engine, Connection):
            engine = engine.engine
        datasource_url = engine.url
        if engine.dialect.name.lower() == "bigquery":
            schema = '{}.{}'.format(datasource_url.host,
                                    datasource_url.database)

        table = Table(table_name, meta, autoload_with=engine)

        fields = [
            Field(name=key,
                  type=str(col.type) if col.type is not None else 'UNKNOWN',
                  description=col.doc) for key, col in table.columns.items()
        ]

        name = table_name \
            if schema is None \
            else "{}.{}".format(schema, table_name)

        results_facet = self.results_facet(validation_result_suite)
        return Dataset(source=self._source(urlparse(str(datasource_url))),
                       fields=fields,
                       name=name,
                       input_facets=results_facet).to_openlineage_dataset()
Example #13
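A test that Dataset.from_table prefixes the dataset name with both database and schema when a database name is supplied.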
def test_dataset_with_db_name(source):
    dataset = Dataset.from_table(source, 'source_table', 'public',
                                 'food_delivery')
    assert dataset == Dataset(source=source,
                              name='food_delivery.public.source_table')
Example #14
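A test that, without a database name, Dataset.from_table names the dataset schema.table.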
def test_dataset_from(source):
    dataset = Dataset.from_table(source, 'source_table', 'public')
    assert dataset == Dataset(source=source, name='public.source_table')