def mapXmlAsHadoopFile(spark: SparkSession, appconfig: ConfigParser,
                       location) -> DataFrame:
    """Read XML records from ``location`` into a single-column DataFrame.

    The input is split with ``com.databricks.spark.xml.XmlInputFormat`` via
    ``newAPIHadoopFile``; each record is the text between the configured
    start and end tags.

    Args:
        spark: Active Spark session.
        appconfig: Parsed configuration. Reads ``XmlRecordStart``,
            ``XmlRecordEnd`` and ``EliminateNewLines`` from the ``[Xml]``
            section.
        location: Input path; it is converted with ``str()``.

    Returns:
        A DataFrame with one string column named ``line``, one stripped XML
        record per row. When ``EliminateNewLines`` is exactly ``true``
        (after stripping), ``'\\n'`` characters are removed from each record.
    """
    xml_config = appconfig["Xml"]
    xml: RDD = spark.sparkContext.newAPIHadoopFile(
        str(location),
        'com.databricks.spark.xml.XmlInputFormat',
        'org.apache.hadoop.io.LongWritable',
        'org.apache.hadoop.io.Text',
        conf={
            'xmlinput.start': str(xml_config["XmlRecordStart"]).strip(),
            'xmlinput.end': str(xml_config["XmlRecordEnd"]).strip(),
            'xmlinput.encoding': 'utf-8'
        })

    # Compute the flag once on the driver so the mapper closure only captures
    # a plain bool, not the config object.
    eliminate_newlines = (
        str(xml_config['EliminateNewLines']).strip() == 'true')

    def _to_line(record) -> str:
        # record is a (key, value) pair; the value (Text) holds the XML record.
        text = str(record[1]).strip()
        return text.replace('\n', '') if eliminate_newlines else text

    return spark.createDataFrame(xml.map(_to_line),
                                 StringType()).withColumnRenamed(
                                     'value', 'line')
示例#2
0
def df_base(spark: SparkSession) -> DataFrame:
    """Build a department DataFrame whose six seed rows repeat 2**20 times.

    Columns: ``ID``, ``dept_name``, ``dept_id``.
    """
    seed_rows = [(1, "Finance", 10), (2, "Marketing", 20), (3, "Sales", 30),
                 (4, "IT", 40), (5, "CTS", 41), (6, "CTS", 42)]
    # Doubling a list 20 times is the same as repeating it 2**20 times.
    rows = seed_rows * (2 ** 20)
    return spark.createDataFrame(data=rows,
                                 schema=["ID", "dept_name", "dept_id"])
def test_load_processor(spark_session: SparkSession):
    """LoadProcessor loads the CSV fixture (header row, no type inference)."""
    load_options = {
        'header': 'true'
    }

    default_props = PropertyGroup()
    default_props.set_property(
        LoadProcessor.PATH, f'{FIXTURE_DIR}/sample_load.csv')
    default_props.set_property(LoadProcessor.FORMAT, 'csv')

    # The processor is configured solely through the property groups
    # registered on the context it is run with.
    processor_context = ProcessorContext(spark_session)
    processor_context.set_property_group(
        LoadProcessor.LOAD_OPTIONS_GROUP, load_options)
    processor_context.set_property_group(
        LoadProcessor.DEFAULT_PROPS_GROUP, default_props
    )
    processor = LoadProcessor()
    output = processor.run(processor_context)
    actual = output.df.collect()
    # Without inferSchema every CSV column is read as a string.
    expected_data = [{'name': 'xyz', 'contact': '123'}]
    expected = spark_session.createDataFrame(expected_data) \
        .select('name', 'contact').collect()
    assert actual == expected
示例#4
0
def test_can_run_framework_pipeline(spark_session: SparkSession) -> None:
    """Running MyPipeline over flights.csv leaves rows in ``flights2``."""
    # Arrange
    here: Path = Path(__file__).parent.joinpath("./")
    parameters = {"flights_path": f"file://{here.joinpath('flights.csv')}"}

    empty_df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), StructType([]))

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    with ProgressLogger() as progress_logger:
        fitted = MyPipeline(parameters=parameters,
                            progress_logger=progress_logger).fit(empty_df)
        fitted.transform(empty_df)

    # Assert
    flights2: DataFrame = spark_session.sql("SELECT * FROM flights2")
    flights2.show()

    assert flights2.count() > 0
示例#5
0
def test_correctly_loads_csv_with_clean_flag_on(
        spark_session: SparkSession) -> None:
    """clean_column_names=True keeps the data and rewrites awkward headers."""
    # Arrange
    clean_spark_session(spark_session)

    here: Path = Path(__file__).parent.joinpath("./")
    csv_path: str = f"{here.joinpath('column_name_test.csv')}"

    empty_df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), StructType([]))

    # Act
    loader = FrameworkCsvLoader(
        view="my_view",
        filepath=csv_path,
        delimiter=",",
        clean_column_names=True,
    )
    loader.transform(empty_df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    # Assert
    assert_results(result)
    second_row = result.collect()[1]
    assert second_row[0] == "2"
    expected_column = "Ugly_column_with_chars_that_parquet_does_not_like_much_-"
    assert result.columns[2] == expected_column
示例#6
0
def test_simple_csv_and_sql_pipeline(spark_session: SparkSession) -> None:
    """A CSV loader followed by FeaturesCarriersV1 produces rows in flights2."""
    # Arrange
    here: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{here.joinpath('flights.csv')}"

    empty_df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), StructType([]))

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters: Dict[str, Any] = {}
    loader = FrameworkCsvLoader(view="flights", filepath=flights_path)
    carriers = FeaturesCarriersV1(parameters=parameters)
    stages: List[Transformer] = create_steps([loader, carriers])

    pipeline: Pipeline = Pipeline(stages=stages)  # type: ignore
    pipeline.fit(empty_df).transform(empty_df)

    # Assert
    flights2: DataFrame = spark_session.sql("SELECT * FROM flights2")
    flights2.show()

    assert flights2.count() > 0
示例#7
0
def test_can_load_parquet(spark_session: SparkSession):
    """FrameworkParquetLoader exposes a parquet file built from test.csv."""
    # Arrange
    SparkTestHelper.clear_tables(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath('./')
    test_file_path: str = f"{data_dir.joinpath('test.csv')}"

    # Start from a clean ``temp`` directory (presumably where ParquetHelper
    # writes its output — confirm).
    temp_dir = data_dir.joinpath('temp')
    if path.isdir(temp_dir):
        shutil.rmtree(temp_dir)

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    parquet_file_path: str = ParquetHelper.create_parquet_from_csv(
        spark_session=spark_session, file_path=test_file_path)

    # Act
    FrameworkParquetLoader(view="my_view",
                           file_path=parquet_file_path).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    # Collect once: each collect()/count() would launch another Spark job.
    rows = result.collect()
    assert len(rows) == 3

    second = rows[1]
    assert second[0] == 2
    assert second[1] == "bar"
    assert second[2] == "bar2"
def test_can_convert_json_folder_to_jsonl(spark_session: SparkSession) -> None:
    """Each JSON document in test_files becomes one line in the output file."""
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test_files')}"

    # Recreate an empty output folder for every run.
    temp_folder = data_dir.joinpath("temp")
    if path.isdir(temp_folder):
        rmtree(temp_folder)
    makedirs(temp_folder)

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkJsonToJsonlConverter(file_path=test_file_path,
                                  output_folder=temp_folder).transform(df)

    # Assert
    # Read-only access: "r+" would needlessly require write permission.
    with open(temp_folder.joinpath("test.json"), "r",
              encoding="utf-8") as file:
        lines: List[str] = file.readlines()
    assert len(lines) == 2
    assert (
        lines[0] ==
        '{"title":"A Philosophy of Software Design","authors":[{"given":["John"],"surname":"Ousterhout"}],"edition":null}\n'
    )
    assert (
        lines[1] ==
        '{"title":"Essentials of Programming Languages","authors":[{"given":["Dan","P."],"surname":"Friedman"},{"given":["Mitchell"],"surname":"Wand"}],"edition":3}\n'
    )
示例#9
0
def test_simple_csv_loader_pipeline(spark_session: SparkSession) -> None:
    """A CSV loader plus a SELECT-all SQLTransformer returns the flight rows."""
    # Arrange
    data_dir: Path = Path(__file__).parent.joinpath('./')
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # noinspection SqlDialectInspection,SqlNoDataSourceInspection
    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    stages: List[Union[Estimator, Transformer]] = [
        # Same keyword (`filepath`) as every other FrameworkCsvLoader call.
        FrameworkCsvLoader(
            view="flights",
            filepath=flights_path
        ),
        SQLTransformer(statement="SELECT * FROM flights"),
    ]

    pipeline: Pipeline = Pipeline(stages=stages)

    transformer = pipeline.fit(df)
    result_df: DataFrame = transformer.transform(df)

    # Assert
    result_df.show()

    assert result_df.count() > 0
示例#10
0
def test_validation_recurses_query_dir(spark_session: SparkSession) -> None:
    """Validation runs both a single query file and a nested query directory.

    Expects three validation rows in ``pipeline_validation``, one of them
    failing.
    """
    clean_spark_session(spark_session)
    query_dir: Path = Path(__file__).parent.joinpath("./queries")
    more_queries_dir: str = "more_queries"
    data_dir: Path = Path(__file__).parent.joinpath("./data")
    test_data_file: str = f"{data_dir.joinpath('test.csv')}"
    validation_query_file: str = "validate.sql"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    FrameworkCsvLoader(view="my_view", filepath=test_data_file).transform(df)

    FrameworkValidationTransformer(
        validation_source_path=str(query_dir),
        validation_queries=[validation_query_file, more_queries_dir],
    ).transform(df)

    # Read through the session directly; DataFrame.sql_ctx is deprecated.
    df_validation = spark_session.table("pipeline_validation")
    df_validation.show(truncate=False)
    assert df_validation.count() == 3, \
        "Expected 3 total rows in pipeline_validation"
    assert df_validation.filter("is_failed == 1").count() == 1, \
        "Expected one failing row in the validation table"
示例#11
0
def test_can_load_xml_file_with_schema(spark_session: SparkSession) -> None:
    """FrameworkXmlLoader applies an explicit schema to the <book> rows."""
    # Arrange
    clean_spark_session(spark_session)

    here: Path = Path(__file__).parent.joinpath("./")
    xml_path: str = f"{here.joinpath('test.xml')}"

    empty_df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), StructType([]))

    # Every field is a nullable string except the numeric price.
    field_names = ("_id", "author", "description", "genre", "price",
                   "publish_date", "title")
    xml_schema = StructType([
        StructField(name,
                    DoubleType() if name == "price" else StringType(),
                    True)
        for name in field_names
    ])
    # Act
    loader = FrameworkXmlLoader(view="my_view",
                                filepath=xml_path,
                                row_tag="book",
                                schema=xml_schema)
    loader.transform(empty_df)

    books: DataFrame = spark_session.sql("SELECT * FROM my_view")
    books.show()
    assert books.count() == 12
    assert len(books.columns) == 7
示例#12
0
def test_spark_ml_vectors(spark: SparkSession, tmp_path: Path):
    """Spark ML dense vectors round-trip through parquet as float64 arrays."""
    test_dir = str(tmp_path)
    df = spark.createDataFrame([
        {
            "name": "a",
            "vec": Vectors.dense([1, 2])
        },
        {
            "name": "b",
            "vec": Vectors.dense([10])
        },
    ])
    df.write.mode("overwrite").parquet(test_dir)

    spark.read.parquet(test_dir).show()

    records = sorted(_read_parquets(test_dir), key=lambda x: x["name"])

    expected = [
        {
            "name": "a",
            "vec": np.array([1, 2], dtype=np.float64)
        },
        {
            "name": "b",
            "vec": np.array([10], dtype=np.float64)
        },
    ]

    # zip() stops at the shorter input, so without this check a missing or
    # truncated read would pass silently.
    assert len(records) == len(expected)
    for exp, rec in zip(expected, records):
        assert exp["name"] == rec["name"]
        assert np.array_equal(exp["vec"], rec["vec"])
示例#13
0
def test_spark_ml_matrix(spark: SparkSession, tmp_path: Path):
    """DenseMatrix values round-trip through the rikai format.

    DenseMatrix stores values column-major, so the expected arrays are the
    row-major reshape transposed.
    """
    test_dir = str(tmp_path)
    df = spark.createDataFrame([
        {
            "name": 1,
            "mat": DenseMatrix(2, 2, range(4))
        },
        {
            "name": 2,
            "mat": DenseMatrix(3, 3, range(9))
        },
    ])
    df.write.mode("overwrite").format("rikai").save(test_dir)
    df.show()

    records = sorted(_read_parquets(test_dir), key=lambda x: x["name"])

    expected = [
        {
            "name": 1,
            "mat": np.array(range(4), dtype=np.float64).reshape(2, 2).T,
        },
        {
            "name": 2,
            "mat": np.array(range(9), dtype=np.float64).reshape(3, 3).T,
        },
    ]
    # Guard against zip() silently skipping missing records.
    assert len(records) == len(expected)
    for exp, rec in zip(expected, records):
        assert exp["name"] == rec["name"]
        assert np.array_equal(exp["mat"], rec["mat"])
示例#14
0
def test_can_keep_columns(spark_session: SparkSession) -> None:
    """Selecting keep_columns=["Column2"] narrows my_view to that column."""
    # Arrange
    clean_spark_session(spark_session)

    here: Path = Path(__file__).parent.joinpath("./")
    csv_path: str = f"{here.joinpath('test.csv')}"

    empty_df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), StructType([]))

    # Act
    loader = FrameworkCsvLoader(view="my_view", filepath=csv_path,
                                delimiter=",")
    loader.transform(empty_df)

    selector = FrameworkSelectColumnsTransformer(view="my_view",
                                                 keep_columns=["Column2"])
    selector.transform(empty_df)

    # noinspection SqlDialectInspection
    view_df: DataFrame = spark_session.sql("SELECT * FROM my_view")

    view_df.show()

    # Assert
    assert len(view_df.columns) == 1

    assert view_df.count() == 3

    assert view_df.collect()[1][0] == "bar"
示例#15
0
def test_can_load_non_standard_delimited_csv(
        spark_session: SparkSession) -> None:
    """A pipe-delimited file loads correctly when delimiter="|" is given."""
    # Arrange
    clean_spark_session(spark_session)

    here: Path = Path(__file__).parent.joinpath("./")
    psv_path: str = f"{here.joinpath('test.psv')}"

    empty_df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), StructType([]))

    # Act
    psv_loader = FrameworkCsvLoader(view="my_view",
                                    filepath=psv_path,
                                    delimiter="|")
    psv_loader.transform(empty_df)

    # noinspection SqlDialectInspection
    view_df: DataFrame = spark_session.sql("SELECT * FROM my_view")

    view_df.show()

    # Assert
    assert psv_loader.getDelimiter() == "|"
    assert_results(view_df)
def test_fail_fast_validated_framework_pipeline_writes_results(
    spark_session: SparkSession, ) -> None:
    """A failing fail-fast validation raises and still writes its results CSV."""
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
    output_path: str = f"file://{data_dir.joinpath('temp').joinpath('validation.csv')}"

    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    # Act
    parameters = {
        "flights_path": flights_path,
        "validation_source_path": str(data_dir),
        "validation_output_path": output_path,
    }

    # pytest.raises fails the test if validation does NOT raise; a bare
    # try/except would silently skip the assertions below in that case.
    with pytest.raises(AssertionError):
        with ProgressLogger() as progress_logger:
            pipeline: MyFailFastValidatedPipeline = MyFailFastValidatedPipeline(
                parameters=parameters, progress_logger=progress_logger)
            transformer = pipeline.fit(df)
            transformer.transform(df)

    # Assert
    validation_df = spark_session.read.csv(output_path, header=True)
    validation_df.show(truncate=False)
    assert validation_df.count() == 1
def test_can_run_validated_framework_pipeline(
        spark_session: SparkSession) -> None:
    """MyValidatedPipeline raises AssertionError when its validation fails."""
    # Arrange
    clean_spark_session(spark_session)
    data_dir: Path = Path(__file__).parent.joinpath("./")
    flights_path: str = f"file://{data_dir.joinpath('flights.csv')}"
    output_path: str = f"file://{data_dir.joinpath('temp').joinpath('validation.csv')}"

    if path.isdir(data_dir.joinpath("temp")):
        shutil.rmtree(data_dir.joinpath("temp"))

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    spark_session.sql("DROP TABLE IF EXISTS default.flights")

    parameters = {
        "flights_path": flights_path,
        "validation_source_path": str(data_dir),
        "validation_output_path": output_path,
    }

    # Act / Assert
    # Only the pipeline run sits inside pytest.raises, so an AssertionError
    # raised during setup is reported as a failure instead of a pass.
    with pytest.raises(AssertionError):
        with ProgressLogger() as progress_logger:
            pipeline: MyValidatedPipeline = MyValidatedPipeline(
                parameters=parameters, progress_logger=progress_logger)
            transformer = pipeline.fit(df)
            transformer.transform(df)
示例#18
0
def test_load_stream_processor(spark_session: SparkSession):
    """LoadStreamProcessor streams the CSV fixture; a one-shot sink writes it out."""

    def _field(name, field_type):
        # Non-nullable struct field in Spark's JSON schema representation.
        return {
            'name': name,
            'type': field_type,
            'nullable': False,
            'metadata': {}
        }

    schema = {
        'type': 'struct',
        'fields': [_field('name', 'string'), _field('contact', 'integer')],
    }

    load_options = {
        'header': 'true',
        'inferSchema': 'true',
        'checkpointLocation': f'{TEST_DIR}/checkpoint'
    }

    default_props = PropertyGroup()
    for prop_key, prop_value in (
            (LoadStreamProcessor.PATH, f'{FIXTURE_DIR}/sample_load.csv'),
            (LoadStreamProcessor.FORMAT, 'csv'),
            (LoadStreamProcessor.SCHEMA, schema),
    ):
        default_props.set_property(prop_key, prop_value)

    groups = PropertyGroups()
    groups.set_property_group(LoadStreamProcessor.LOAD_OPTIONS_GROUP,
                              load_options)
    groups.set_property_group(LoadStreamProcessor.DEFAULT_PROPS_GROUP,
                              default_props)

    context = ProcessorContext(spark_session, property_groups=groups)

    stream_processor = LoadStreamProcessor()
    output = stream_processor.run(context)
    output_dir = f'{TEST_DIR}/stream_output'
    output.df.createOrReplaceTempView('input')
    query = output.df.writeStream.trigger(once=True).start(
        path=output_dir,
        format='csv',
        outputMode='append',
        **load_options)
    query.awaitTermination()

    reader = spark_session.read.options(**load_options)
    actual = reader.csv(output_dir).collect()
    expected_df = spark_session.createDataFrame([{'name': 'xyz', 'contact': 123}])
    expected = expected_df.select('name', 'contact').collect()
    assert actual == expected
示例#19
0
def test_bbox(spark: SparkSession, tmp_path: Path):
    """A Box2d column survives a write/read round-trip in rikai format."""
    box = Box2d(1, 2, 3, 4)
    out_dir = str(tmp_path)

    frame = spark.createDataFrame([Row(b=box)])
    frame.write.mode("overwrite").format("rikai").save(out_dir)

    assert_count_equal([{"b": Box2d(1, 2, 3, 4)}], _read_parquets(out_dir))
示例#20
0
def test_ac(spark: SparkSession) -> None:
    """createDataFrame from tuples + column names keeps names, order and values.

    The original version only called show() and asserted nothing, so it
    could never fail on wrong data.
    """
    data = [
        ('James', 'Smith', 'M', 3000),
        ('Anna', 'Rose', 'F', 4100),
        ('Robert', 'Williams', 'M', 6200),
    ]

    columns = ["firstname", "lastname", "gender", "salary"]
    df = spark.createDataFrame(data=data, schema=columns)
    df.show()

    assert df.columns == columns
    # Row is a tuple subclass, so tuple(row) yields the plain values.
    assert [tuple(row) for row in df.collect()] == data
示例#21
0
def test_can_load_fixed_width(spark_session: SparkSession) -> None:
    """FrameworkFixedWidthLoader slices test.txt by column spec and casts types."""
    # Arrange
    clean_spark_session(spark_session)

    data_dir: Path = Path(__file__).parent.joinpath("./")
    test_file_path: str = f"{data_dir.joinpath('test.txt')}"

    schema = StructType([])

    df: DataFrame = spark_session.createDataFrame(
        spark_session.sparkContext.emptyRDD(), schema)

    # Act
    FrameworkFixedWidthLoader(
        view="my_view",
        filepath=test_file_path,
        columns=[
            ColumnSpec(column_name="id",
                       start_pos=1,
                       length=3,
                       data_type=StringType()),
            ColumnSpec(column_name="some_date",
                       start_pos=4,
                       length=8,
                       data_type=StringType()),
            ColumnSpec(
                column_name="some_string",
                start_pos=12,
                length=3,
                data_type=StringType(),
            ),
            ColumnSpec(
                column_name="some_integer",
                start_pos=15,
                length=4,
                data_type=IntegerType(),
            ),
        ],
    ).transform(df)

    # noinspection SqlDialectInspection
    result: DataFrame = spark_session.sql("SELECT * FROM my_view")

    result.show()

    # Assert
    # Collect once; each collect()/count() would otherwise re-run the job.
    rows = result.collect()
    assert len(rows) == 2
    first, second = rows
    assert first[0] == "001"
    assert second[0] == "002"
    assert first[1] == "01292017"
    assert second[1] == "01302017"
    assert first[2] == "you"
    assert second[2] == "me"
    assert first[3] == 1234
    assert second[3] == 5678
示例#22
0
def main(argv):
    """Cluster a directory of text documents with TF-IDF + KMeans.

    Args:
        argv: Command-line arguments: ``argv[1]`` number of clusters ``k``;
            ``argv[2]`` input directory (e.g.
            ``hdfs:///user/dhoyoso/datasets/dataset/``); ``argv[3]`` output
            path; ``argv[4]`` stop-word language (e.g. ``spanish``).

    Writes, as JSON (overwriting ``argv[3]``), one record per cluster holding
    the ``prediction`` id and ``cluster_docs_list`` with its document names.
    """
    from pyspark.sql.functions import substring_index

    # Create the Spark context and session.
    sc = SparkContext(appName="KMeans-Clustering-dhoyoso-dsernae")
    spark = SparkSession(sc)
    # Language whose default stop-word list is removed (e.g. "spanish").
    language = argv[4]
    # Output path for the clusters.
    pathout = argv[3]
    # Input directory to read documents from.
    path = argv[2]
    # Number of clusters.
    k = int(argv[1])
    # One (file path, full file contents) pair per input file.
    files = sc.wholeTextFiles(path)
    # Two string columns: the file path and its text.
    schema = StructType([
        StructField("path", StringType(), True),
        StructField("text", StringType(), True)
    ])
    df = spark.createDataFrame(files, schema)
    tokenizer = Tokenizer(inputCol="text", outputCol="tokens")
    # loadDefaultStopWords only *returns* a list; it must be passed to the
    # remover, otherwise the English default list is used and `language`
    # has no effect.
    stopWords = StopWordsRemover(
        inputCol="tokens",
        outputCol="stopWordsRemovedTokens",
        stopWords=StopWordsRemover.loadDefaultStopWords(language))
    # Term frequencies of the remaining tokens, hashed into 2000 buckets.
    hashingTF = HashingTF(inputCol="stopWordsRemovedTokens",
                          outputCol="rawFeatures",
                          numFeatures=2000)
    # Re-weight the term frequencies by inverse document frequency.
    idf = IDF(inputCol="rawFeatures", outputCol="features", minDocFreq=1)
    kmeans = KMeans(k=k)
    pipeline = Pipeline(stages=[tokenizer, stopWords, hashingTF, idf, kmeans])
    model = pipeline.fit(df)
    results = model.transform(df)
    results.cache()
    # Document name = last path segment, independent of directory depth.
    results = results.withColumn(
        'docname', substring_index(results['path'], '/', -1))
    df = results.select("docname", "prediction")

    # Group document names per cluster and write them as a single JSON part.
    grouped = df.groupBy(['prediction']).agg(
        collect_list("docname").alias('cluster_docs_list'))
    grouped.coalesce(1).write.json(path=pathout, mode="overwrite")
示例#23
0
def df_group_status(spark: SparkSession) -> DataFrame:
    """Return a (grp, dt, status) test DataFrame with null dates and statuses.

    ``status`` is always null and one ``dt`` in group A is null.

    Args:
        spark (SparkSession): [Spark session fixture]

    Returns:
        DataFrame: [Test spark dataframe]

        +---+----------+------+
        |grp|        dt|status|
        +---+----------+------+
        |  A|2020-03-21|  null|
        |  A|      null|  null|
        |  A|2020-03-22|  null|
        |  A|2020-03-25|  null|
        |  B|2020-02-21|  null|
        |  B|2020-02-22|  null|
        |  B|2020-02-25|  null|
        +---+----------+------+
    """
    from datetime import datetime

    schema = StructType([
        StructField("grp", StringType(), True),
        StructField("dt", DateType(), True),
        StructField("status", BooleanType(), True),
    ])

    group_a_rows = [
        ("A", datetime(2020, 3, 21), None),
        ("A", None, None),
        ("A", datetime(2020, 3, 22), None),
        ("A", datetime(2020, 3, 25), None),
    ]
    group_b_rows = [
        ("B", datetime(2020, 2, 21), None),
        ("B", datetime(2020, 2, 22), None),
        ("B", datetime(2020, 2, 25), None),
    ]

    group_a = spark.createDataFrame(group_a_rows, schema)
    return group_a.union(spark.createDataFrame(group_b_rows, schema))
示例#24
0
def test_numpy(spark: SparkSession, tmp_path, data_type):
    """An NDArrayType column written in rikai format reads back equal values."""
    import rikai

    out_dir = str(tmp_path)
    ndarray_schema = StructType([StructField("n", NDArrayType(), False)])
    rows = [{"n": rikai.array(range(4), dtype=data_type)}]

    spark.createDataFrame(rows, schema=ndarray_schema) \
        .write.mode("overwrite").format("rikai").save(out_dir)

    first_record = _read_parquets(out_dir)[0]
    assert np.array_equal(np.array(range(4), dtype=data_type),
                          first_record["n"])
示例#25
0
def test_with_select(spark_session: SparkSession, input_data):
    """
    Add 100 null columns in a single select().

    This tends to be faster than chained withColumn calls because the plan
    is parsed and analysed once instead of once per added column.
    :param spark_session: Spark session fixture.
    :param input_data: rows for the source DataFrame.
    :return: None
    """
    source_df = spark_session.createDataFrame(input_data)

    null_columns = [lit(None).alias(f"test_{i}") for i in range(100)]
    widened_df = source_df.select([col("*"), *null_columns])

    widened_df.explain(extended=True)

    widened_df.show()
示例#26
0
def df8(spark: SparkSession) -> DataFrame:
    """Generate a two-column test DataFrame (letters2, numbers).

    Rows: ("a", 1), ("o1", 2), ("b", 3), ("o2", 4). No values are blank.

    Args:
        spark (SparkSession): [Spark session fixture]

    Returns:
        DataFrame: [Test spark dataframe]
    """
    rows = [("a", 1), ("o1", 2), ("b", 3), ("o2", 4)]
    return spark.createDataFrame(rows, ("letters2", "numbers"))
示例#27
0
def df2(spark: SparkSession) -> DataFrame:
    """Generate a two-column test DataFrame (letters, numbers).

    Rows: ("a", 1), ("b", 2), ("c", 3).

    Args:
        spark (SparkSession): [Spark session fixture]

    Returns:
        DataFrame: [Test spark dataframe]
    """
    rows = [("a", 1), ("b", 2), ("c", 3)]
    return spark.createDataFrame(rows, ("letters", "numbers"))
示例#28
0
def test_images(spark: SparkSession, tmp_path):
    """Image columns round-trip through a parquet write/read."""
    out_dir = str(tmp_path)
    expected = [
        {"id": 1, "image": Image(uri="s3://123")},
        {"id": 2, "image": Image(uri="s3://abc")},
    ]
    spark.createDataFrame(expected).write.mode("overwrite").parquet(out_dir)

    loaded = sorted(_read_parquets(out_dir), key=lambda rec: rec["id"])
    assert_count_equal(expected, loaded)
示例#29
0
def df5(spark: SparkSession) -> DataFrame:
    """Generate a two-column test DataFrame (letters, numbers) with nulls.

    Rows: ("z", 1), (None, 2), ("x", None).

    Args:
        spark (SparkSession): [Spark session fixture]

    Returns:
        DataFrame: [Test spark dataframe]
    """
    rows = [("z", 1), (None, 2), ("x", None)]
    return spark.createDataFrame(rows, ("letters", "numbers"))
示例#30
0
def test_with_column(spark_session: SparkSession, input_data):
    """
    Add 100 null columns with one withColumn call per column.

    This tends to be slower, and can cause problems at run time on large
    datasets, because the plan is parsed and analysed on every added column.
    :param spark_session: Spark session fixture.
    :param input_data: rows for the source DataFrame.
    :return: None
    """
    widened_df = spark_session.createDataFrame(input_data)

    column_index = 0
    while column_index < 100:
        widened_df = widened_df.withColumn(f"test_{column_index}", lit(None))
        column_index += 1

    widened_df.explain(extended=True)

    widened_df.show()