from collections import OrderedDict

from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import kwtypes, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.tools.translator import get_serializable
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

# Plugin imports: these assume a recent flytekit with the flytekitplugins-athena,
# flytekitplugins-bigquery, flytekitplugins-hive, and flytekitplugins-snowflake
# packages (plus google-cloud-bigquery) installed.
from flytekitplugins.athena import AthenaConfig, AthenaTask
from flytekitplugins.bigquery import BigQueryConfig, BigQueryTask
from flytekitplugins.hive import HiveConfig, HiveTask
from flytekitplugins.snowflake import SnowflakeConfig, SnowflakeTask
from google.cloud.bigquery import QueryJobConfig


def test_query_no_inputs_or_outputs():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table values (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert len(task_spec.template.interface.inputs) == 0
    assert len(task_spec.template.interface.outputs) == 0

    get_serializable(OrderedDict(), serialization_settings, my_wf)


def test_athena_serialization():
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
    assert "insert overwrite directory" in task_spec.template.custom["statement"]
    assert task_spec.template.custom["schema"] == "mnist"
    assert task_spec.template.custom["catalog"] == "my_catalog"
    assert task_spec.template.custom["routingGroup"] == "my_wg"
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
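

# The Snowflake test below references a module-level query template that is not
# shown in this excerpt. The definition here is a minimal sketch under an assumed
# (hypothetical) name, grounded only in what the assertions require: the template
# must contain the '{{ .rawOutputDataPrefix }}' marker and "insert overwrite directory".
snowflake_query_template = """
    insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
    select *
    from my_table
    where ds = '{{ .Inputs.ds }}'
"""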


def test_snowflake_serialization():
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=SnowflakeConfig(
            account="snowflake",
            warehouse="my_warehouse",
            schema="my_schema",
            database="my_database",
        ),
        query_template=snowflake_query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, snowflake_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
    assert "insert overwrite directory" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    assert task_spec.template.config["account"] == "snowflake"
    assert task_spec.template.config["warehouse"] == "my_warehouse"
    assert task_spec.template.config["schema"] == "my_schema"
    assert task_spec.template.config["database"] == "my_database"
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"


def test_hive_serialization():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"]["query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"]["query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
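

# The BigQuery test below also references a module-level query template that is
# not shown in this excerpt. This is a minimal sketch under an assumed
# (hypothetical) name, grounded only in what the assertions require: the statement
# must select from the dogecoin public dataset and contain an "@version" parameter.
bigquery_query_template = "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions` WHERE version = @version"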


def test_bigquery_serialization():
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=BigQueryConfig(
            ProjectID="Flyte",
            Location="Asia",
            QueryJobConfig=QueryJobConfig(allow_large_results=True),
        ),
        query_template=bigquery_query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, bigquery_task)
    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in task_spec.template.sql.statement
    assert "@version" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI

    s = Struct()
    s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_spec.template.custom == json_format.MessageToDict(s)

    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.structured_dataset_type is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"