# Example #1
0
def test_query_no_inputs_or_outputs():
    """Serialize a Hive task that declares no inputs and no output schema.

    Asserts that the serialized task interface has zero inputs and zero
    outputs, then serializes a workflow that calls the task (the workflow
    result itself is not inspected).
    """
    statement = """
            insert into extant_table (1, 'two')
        """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template=statement,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    interface = get_serializable(OrderedDict(), settings, hive_task).template.interface
    assert len(interface.inputs) == 0
    assert len(interface.outputs) == 0

    # Serializing the workflow only has to succeed; nothing is asserted on it.
    get_serializable(OrderedDict(), settings, my_wf)
# Example #2
0
def test_serialization():
    """Serialize an Athena task and a workflow that returns its output.

    Task checks: the query statement and the Athena config values appear in
    the serialized ``custom`` mapping, and the interface has one input and
    one output. Workflow checks: output ``o0`` has a schema type and is bound
    to the ``results`` variable of node ``n0``.
    """
    statement = """
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(
            database="mnist",
            catalog="my_catalog",
            workgroup="my_wg",
        ),
        query_template=statement,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    # --- task-level checks ---
    task_template = get_serializable(OrderedDict(), settings, athena_task).template
    custom = task_template.custom
    assert "{{ .rawOutputDataPrefix" in custom["statement"]
    assert "insert overwrite directory" in custom["statement"]
    assert custom["schema"] == "mnist"
    assert custom["catalog"] == "my_catalog"
    assert custom["routingGroup"] == "my_wg"
    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    # --- workflow-level checks ---
    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    promise = first_output.binding.promise
    assert promise.node_id == "n0"
    assert promise.var == "results"
# Example #3
0
def test_serialization():
    """Serialize a Snowflake task and a workflow that returns its output.

    Task checks: the SQL statement text and ANSI dialect, the four Snowflake
    config values, and an interface with one input and one output. Workflow
    checks: output ``o0`` has a schema type and is bound to the ``results``
    variable of node ``n0``.

    Relies on a ``query_template`` name defined outside this function.
    """
    snowflake_config = SnowflakeConfig(
        account="snowflake",
        warehouse="my_warehouse",
        schema="my_schema",
        database="my_database",
    )
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=snowflake_config,
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    # --- task-level checks ---
    task_template = get_serializable(OrderedDict(), settings, snowflake_task).template

    assert "{{ .rawOutputDataPrefix" in task_template.sql.statement
    assert "insert overwrite directory" in task_template.sql.statement
    assert task_template.sql.dialect == task_template.sql.Dialect.ANSI
    assert task_template.config["account"] == "snowflake"
    assert task_template.config["warehouse"] == "my_warehouse"
    assert task_template.config["schema"] == "my_schema"
    assert task_template.config["database"] == "my_database"
    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    # --- workflow-level checks ---
    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    promise = first_output.binding.promise
    assert promise.node_id == "n0"
    assert promise.var == "results"
# Example #4
0
def test_serialization():
    """Serialize a Hive task with two inputs and a workflow that wraps it.

    Task checks: the query text is found under ``custom["query"]["query"]``
    and the interface has two inputs and one output. Workflow checks: output
    ``o0`` has a schema type and is bound to the ``results`` variable of node
    ``n0``.
    """
    # NOTE(review): the template mixes `.Inputs.ds` and `.inputs.my_schema`;
    # kept exactly as written -- confirm which casing the backend expects.
    statement = """
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template=statement,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    # --- task-level checks ---
    task_template = get_serializable(OrderedDict(), settings, hive_task).template
    assert "{{ .rawOutputDataPrefix" in task_template.custom["query"]["query"]
    assert "insert overwrite directory" in task_template.custom["query"]["query"]
    assert len(task_template.interface.inputs) == 2
    assert len(task_template.interface.outputs) == 1

    # --- workflow-level checks ---
    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    promise = first_output.binding.promise
    assert promise.node_id == "n0"
    assert promise.var == "results"
# Example #5
0
def test_serialization():
    """Serialize a BigQuery task and a workflow that returns its output.

    Task checks: the SQL statement text and ANSI dialect, a ``custom`` mapping
    equal to the dict form of a ``Struct`` built from the expected values, and
    an interface with one input and one output. Workflow checks: output ``o0``
    has a structured-dataset type and is bound to the ``results`` variable of
    node ``n0``.

    Relies on a ``query_template`` name defined outside this function.
    """
    bigquery_config = BigQueryConfig(
        ProjectID="Flyte",
        Location="Asia",
        QueryJobConfig=QueryJobConfig(allow_large_results=True),
    )
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=bigquery_config,
        query_template=query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    # --- task-level checks ---
    task_template = get_serializable(OrderedDict(), settings, bigquery_task).template

    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in task_template.sql.statement
    assert "@version" in task_template.sql.statement
    assert task_template.sql.dialect == task_template.sql.Dialect.ANSI

    # The expected custom payload is compared in its protobuf-Struct dict form.
    expected_custom = Struct()
    expected_custom.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_template.custom == json_format.MessageToDict(expected_custom)

    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    # --- workflow-level checks ---
    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.structured_dataset_type is not None
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    promise = first_output.binding.promise
    assert promise.node_id == "n0"
    assert promise.var == "results"