Code example #1
def test_infer_schema_parquet():
    with tempfile.TemporaryFile(mode="w+b") as file:
        # Write the test table as parquet, then rewind so the inferrer
        # reads from the start of the buffer.
        test_table.to_parquet(file)
        file.seek(0)

        fields = parquet.ParquetInferrer().infer_schema(file)
        # Sort by field path so the assertions are order-independent.
        fields.sort(key=lambda x: x.fieldPath)

        assert_field_paths_match(fields, expected_field_paths)
        assert_field_types_match(fields, expected_field_types)
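
The test relies on module-level fixtures (test_table, expected_field_paths, expected_field_types) that are not shown here. A minimal sketch of what such fixtures might look like, assuming a pandas DataFrame as the test table (the column names and values are illustrative, not taken from the project):

import pandas as pd

# Hypothetical fixtures for the test above: a small DataFrame with a
# known schema, and the field paths the inferrer is expected to report,
# listed in sorted order to match the test's sort.
test_table = pd.DataFrame({
    "count": [1, 2],
    "name": ["a", "b"],
})
expected_field_paths = ["count", "name"]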
Code example #2
    def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
        if table_data.is_s3:
            if self.source_config.aws_config is None:
                raise ValueError("AWS config is required for S3 file sources")

            s3_client = self.source_config.aws_config.get_s3_client()

            # Stream the object from S3 via smart_open rather than
            # downloading it to disk first.
            file = smart_open(table_data.full_path,
                              "rb",
                              transport_params={"client": s3_client})
        else:
            file = open(table_data.full_path, "rb")

        fields = []

        # Fall back to the path spec's default extension when the file
        # name itself has none.
        extension = pathlib.Path(table_data.full_path).suffix
        if extension == "" and path_spec.default_extension:
            extension = path_spec.default_extension

        try:
            if extension == ".parquet":
                fields = parquet.ParquetInferrer().infer_schema(file)
            elif extension == ".csv":
                fields = csv_tsv.CsvInferrer(
                    max_rows=self.source_config.max_rows).infer_schema(file)
            elif extension == ".tsv":
                fields = csv_tsv.TsvInferrer(
                    max_rows=self.source_config.max_rows).infer_schema(file)
            elif extension == ".json":
                fields = json.JsonInferrer().infer_schema(file)
            elif extension == ".avro":
                fields = avro.AvroInferrer().infer_schema(file)
            else:
                self.report.report_warning(
                    table_data.full_path,
                    f"file {table_data.full_path} has unsupported extension",
                )
        except Exception as e:
            self.report.report_warning(
                table_data.full_path,
                f"could not infer schema for file {table_data.full_path}: {e}",
            )
        finally:
            # Close the file handle on both success and failure paths.
            file.close()
        logger.debug(f"Extracted fields in schema: {fields}")
        # Return fields in a deterministic order.
        fields = sorted(fields, key=lambda f: f.fieldPath)

        return fields
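
The if/elif chain above maps a file extension to an inferrer. The same dispatch can be sketched as a lookup table; build_inferrer_registry is a hypothetical helper, and the sketch assumes the parquet, csv_tsv, json, and avro inferrer modules from the example are importable in the current scope:

from typing import Callable, Dict

def build_inferrer_registry(max_rows: int) -> Dict[str, Callable]:
    # Map each supported extension to a zero-argument inferrer factory.
    return {
        ".parquet": lambda: parquet.ParquetInferrer(),
        ".csv": lambda: csv_tsv.CsvInferrer(max_rows=max_rows),
        ".tsv": lambda: csv_tsv.TsvInferrer(max_rows=max_rows),
        ".json": lambda: json.JsonInferrer(),
        ".avro": lambda: avro.AvroInferrer(),
    }

# Usage: fall back to the unsupported-extension warning when the
# extension is missing from the registry.
# factory = build_inferrer_registry(max_rows).get(extension)
# fields = factory().infer_schema(file) if factory else []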
Code example #3
File: __init__.py  Project: swaroopjagadish/datahub
    def get_table_schema(self, file_path: str, table_name: str,
                         is_aws: bool) -> Iterable[MetadataWorkUnit]:

        data_platform_urn = make_data_platform_urn(self.source_config.platform)
        dataset_urn = make_dataset_urn(self.source_config.platform, table_name,
                                       self.source_config.env)

        dataset_name = os.path.basename(file_path)

        dataset_snapshot = DatasetSnapshot(
            urn=dataset_urn,
            aspects=[],
        )

        dataset_properties = DatasetPropertiesClass(
            description="",
            customProperties={},
        )
        dataset_snapshot.aspects.append(dataset_properties)

        if is_aws:
            if self.source_config.aws_config is None:
                raise ValueError("AWS config is required for S3 file sources")

            s3_client = self.source_config.aws_config.get_s3_client()

            # Stream the object from S3 via smart_open rather than
            # downloading it to disk first.
            file = smart_open(f"s3://{file_path}",
                              "rb",
                              transport_params={"client": s3_client})
        else:
            file = open(file_path, "rb")

        fields = []

        try:
            if file_path.endswith(".parquet"):
                fields = parquet.ParquetInferrer().infer_schema(file)
            elif file_path.endswith(".csv"):
                fields = csv_tsv.CsvInferrer(
                    max_rows=self.source_config.max_rows).infer_schema(file)
            elif file_path.endswith(".tsv"):
                fields = csv_tsv.TsvInferrer(
                    max_rows=self.source_config.max_rows).infer_schema(file)
            elif file_path.endswith(".json"):
                fields = json.JsonInferrer().infer_schema(file)
            elif file_path.endswith(".avro"):
                fields = avro.AvroInferrer().infer_schema(file)
            else:
                self.report.report_warning(
                    file_path, f"file {file_path} has unsupported extension")
        except Exception as e:
            self.report.report_warning(
                file_path, f"could not infer schema for file {file_path}: {e}")
        finally:
            # Close the file handle on both success and failure paths.
            file.close()

        # Sort fields for a deterministic schema ordering.
        fields = sorted(fields, key=lambda f: f.fieldPath)
        schema_metadata = SchemaMetadata(
            schemaName=dataset_name,
            platform=data_platform_urn,
            version=0,
            hash="",
            fields=fields,
            platformSchema=OtherSchemaClass(rawSchema=""),
        )

        dataset_snapshot.aspects.append(schema_metadata)

        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = MetadataWorkUnit(id=file_path, mce=mce)
        self.report.report_workunit(wu)
        yield wu
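
Because get_table_schema is a generator, callers drive it by iterating; each file yields one MetadataWorkUnit. A minimal consumption sketch (source stands in for an instance of the class above, and the argument values are placeholders):

# Hypothetical driver loop: emit workunits for a single S3 object.
for wu in source.get_table_schema(
        file_path="my-bucket/data/users.parquet",
        table_name="users",
        is_aws=True,
):
    print(wu.id)  # the workunit id is the file path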