def test_wf_container_task_multiple():
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    sum = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum(x=square(val=val1), y=square(val=val2))

    with task_mock(square) as square_mock, task_mock(sum) as sum_mock:
        square_mock.side_effect = lambda val: val * val
        assert square(val=10) == 100

        sum_mock.side_effect = lambda x, y: x + y
        assert sum(x=10, y=10) == 20

        assert raw_container_wf(val1=10, val2=10) == 200

def test_wf1_with_sql():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()

    assert context_manager.FlyteContextManager.size() == 1

def test_to_python_value_with_incoming_columns():
    # make a literal with a type that has two columns
    original_type = Annotated[pd.DataFrame, kwtypes(name=str, age=int)]
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(original_type)
    df = generate_pandas()
    fdt = StructuredDatasetTransformerEngine()
    lit = fdt.to_literal(ctx, df, python_type=original_type, expected=lt)
    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2

    # declare a new type that only has one column
    # get the dataframe, make sure it has the column that was asked for.
    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
    sd = fdt.to_python_value(ctx, lit, subset_sd_type)
    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 1

    # check when columns are not specified, should pull both and add column information.
    sd = fdt.to_python_value(ctx, lit, StructuredDataset)
    assert len(sd.metadata.structured_dataset_type.columns) == 2

    # should also work if subset type is just an annotated pd.DataFrame
    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
    sub_df = fdt.to_python_value(ctx, lit, subset_pd_type)
    assert sub_df.shape[1] == 1

def test_to_python_value_without_incoming_columns():
    # make a literal with a type with no columns
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(pd.DataFrame)
    df = generate_pandas()
    fdt = StructuredDatasetTransformerEngine()
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0

    # declare a new type that only has one column
    # get the dataframe, make sure it has the column that was asked for.
    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
    sd = fdt.to_python_value(ctx, lit, subset_sd_type)
    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 1

    # check when columns are not specified, should pull both and add column information.
    # todo: see the todos in the open_as, and iter_as functions in StructuredDatasetTransformerEngine
    # we have to recreate the literal because the test case above filled in the metadata
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    sd = fdt.to_python_value(ctx, lit, StructuredDataset)
    assert sd.metadata.structured_dataset_type.columns == []
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 2

    # should also work if subset type is just an annotated pd.DataFrame
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
    sub_df = fdt.to_python_value(ctx, lit, subset_pd_type)
    assert sub_df.shape[1] == 1

def test_wf1_with_sql_with_patch():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    @patch(sql)
    def test_user_demo_test(mock_sql):
        mock_sql.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()

    # Have to call because tests inside tests don't run
    test_user_demo_test()

def test_task_serialization():
    sql_task = SQLite3Task(
        "test",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(
            uri=EXAMPLE_DB,
            compressed=True,
        ),
    )
    tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)
    assert tt.container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekit.extras.sqlite3.task.SQLite3TaskExecutor",
    ]
    assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}"
    assert tt.container.image != ""

def test_ge_flytefile_multiple_args():
    task_object_one = GreatExpectationsTask(
        name="test13",
        datasource_name="data",
        inputs=kwtypes(dataset=FlyteFile),
        expectation_suite_name="test.demo",
        data_connector_name="data_flytetype_data_connector",
        local_file_path="/tmp",
    )
    task_object_two = GreatExpectationsTask(
        name="test14",
        datasource_name="data",
        inputs=kwtypes(dataset=FlyteFile),
        expectation_suite_name="test1.demo",
        data_connector_name="data_flytetype_data_connector",
        local_file_path="/tmp",
    )

    @task
    def get_file_name(dataset_one: FlyteFile, dataset_two: FlyteFile) -> typing.Tuple[int, int]:
        df_one = pd.read_csv(os.path.join("data", dataset_one))
        df_two = pd.read_csv(os.path.join("data", dataset_two))
        return len(df_one), len(df_two)

    @workflow
    def wf(
        dataset_one: FlyteFile = "https://raw.githubusercontent.com/superconductive/ge_tutorials/main/data/yellow_tripdata_sample_2019-01.csv",
        dataset_two: FlyteFile = "https://raw.githubusercontent.com/superconductive/ge_tutorials/main/data/yellow_tripdata_sample_2019-02.csv",
    ) -> typing.Tuple[int, int]:
        task_object_one(dataset=dataset_one)
        task_object_two(dataset=dataset_two)
        return get_file_name(dataset_one=dataset_one, dataset_two=dataset_two)

    assert wf() == (10000, 10000)

def test_resolver_load_task():
    # any task is fine, just copied one
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    resolver = TaskTemplateResolver()
    ts = get_serializable(OrderedDict(), serialization_settings, square)
    file = tempfile.NamedTemporaryFile().name
    # load_task should create an instance of the path to the object given, doesn't need to be a real executor
    write_proto_to_file(ts.template.to_flyte_idl(), file)
    shim_task = resolver.load_task([file, f"{Placeholder.__module__}.Placeholder"])
    assert isinstance(shim_task.executor, Placeholder)
    assert shim_task.task_template.id.name == "square"
    assert shim_task.task_template.interface.inputs["val"] is not None
    assert shim_task.task_template.interface.outputs["out"] is not None

def test_execute_sqlite3_task(flyteclient, flyte_workflows_register, flyte_remote_env):
    remote = FlyteRemote(Config.auto(), PROJECT, "development")

    example_db = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip"
    interactive_sql_task = SQLite3Task(
        "basic_querying",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(
            uri=example_db,
            compressed=True,
        ),
    )
    registered_sql_task = remote.register(interactive_sql_task)
    execution = remote.execute(registered_sql_task, inputs={"limit": 10}, wait=True)
    output = execution.outputs["results"]
    result = output.open().all()
    assert result.__class__.__name__ == "DataFrame"
    assert "TrackId" in result
    assert "Name" in result

def test_task_serialization_deserialization_with_secret(sql_server):
    secret_group = "foo"
    secret_name = "bar"
    sec = SecretsManager()
    os.environ[sec.get_secrets_env_var(secret_group, secret_name)] = "IMMEDIATE"

    sql_task = SQLAlchemyTask(
        "test",
        query_template="select 1;",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLAlchemyConfig(
            uri=sql_server,
            # As sqlite3 doesn't really support passwords, we pass another connect_arg as a secret
            secret_connect_args={"isolation_level": Secret(group=secret_group, key=secret_name)},
        ),
    )
    tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)
    assert tt.container is not None
    assert tt.container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekitplugins.sqlalchemy.task.SQLAlchemyTaskExecutor",
    ]
    assert tt.custom["query_template"] == "select 1;"
    assert tt.container.image != ""
    assert "secret_connect_args" in tt.custom
    assert "isolation_level" in tt.custom["secret_connect_args"]
    assert tt.custom["secret_connect_args"]["isolation_level"]["group"] == secret_group
    assert tt.custom["secret_connect_args"]["isolation_level"]["key"] == secret_name
    assert tt.custom["secret_connect_args"]["isolation_level"]["group_version"] is None
    assert tt.custom["secret_connect_args"]["isolation_level"]["mount_requirement"] == 0

    executor = SQLAlchemyTaskExecutor()
    r = executor.execute_from_model(tt)
    assert r.iat[0, 0] == 1

def query_wf() -> int:
    df = SQLite3Task(
        name="cookbook.sqlite3.sample_inline",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True),
    )(limit=100)
    return print_and_count_columns(df=df)

def test_task_schema():
    sql_task = SQLite3Task(
        "test",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True),
    )
    assert sql_task.output_columns is not None

    df = sql_task(limit=1)
    assert df is not None

def test_task_schema(sql_server):
    sql_task = SQLAlchemyTask(
        "test",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLAlchemyConfig(uri=sql_server),
    )
    assert sql_task.output_columns is not None

    df = sql_task(limit=1)
    assert df is not None

def test_serialization():
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    sum = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3
    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum)
    assert sumn_spec.template.container.image == "alpine"

def test_wf_typed_schema():
    schema1 = FlyteSchema[kwtypes(x=int, y=str)]

    @task
    def t1() -> schema1:
        s = schema1()
        s.open().write(pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}))
        return s

    @task
    def t2(s: FlyteSchema[kwtypes(x=int, y=str)]) -> FlyteSchema[kwtypes(x=int)]:
        df = s.open().all()
        return df[s.column_names()[:-1]]

    @workflow
    def wf() -> FlyteSchema[kwtypes(x=int)]:
        return t2(s=t1())

    w = t1()
    assert w is not None
    df = w.open(override_mode=SchemaOpenMode.READ).all()
    result_df = df.reset_index(drop=True) == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}).reset_index(drop=True)
    assert result_df.all().all()

    df = t2(s=w.as_readonly())
    assert df is not None
    result_df = df.reset_index(drop=True) == pandas.DataFrame(data={"x": [1, 2]}).reset_index(drop=True)
    assert result_df.all().all()

    x = wf()
    df = x.open().all()
    result_df = df.reset_index(drop=True) == pandas.DataFrame(data={"x": [1, 2]}).reset_index(drop=True)
    assert result_df.all().all()

def test_wf_container_task():
    @task
    def t1(a: int) -> (int, str):
        return a + 2, str(a) + "-HELLO"

    t2 = ContainerTask(
        "raw",
        image="alpine",
        inputs=kwtypes(a=int, b=str),
        input_data_dir="/tmp",
        output_data_dir="/tmp",
        command=["cat"],
        arguments=["/tmp/a"],
    )

    @workflow
    def wf(a: int):
        x, y = t1(a=a)
        t2(a=x, b=y)

    with task_mock(t2) as mock:
        mock.side_effect = lambda a, b: None
        assert t2(a=10, b="hello") is None
        wf(a=10)

def test_input_output_substitution_files():
    script = "cat {inputs.f} > {outputs.y}"
    if os.name == "nt":
        script = script.replace("cat", "type")
    t = ShellTask(
        name="test",
        debug=True,
        script=script,
        inputs=kwtypes(f=CSVFile),
        output_locs=[
            OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.mod"),
        ],
    )

    assert t.script == script

    contents = "1,2,3,4\n"
    with tempfile.TemporaryDirectory() as tmp:
        csv = os.path.join(tmp, "abc.csv")
        print(csv)
        with open(csv, "w") as f:
            f.write(contents)
        y = t(f=csv)
        assert y.path[-4:] == ".mod"
        assert os.path.exists(y.path)
        with open(y.path) as f:
            s = f.read()
        assert s == contents

def test_wf_schema_to_df():
    schema1 = FlyteSchema[kwtypes(x=int, y=str)]

    @task(cache=True, cache_version="v0")
    def t1() -> schema1:
        global n_cached_task_calls
        n_cached_task_calls += 1
        s = schema1()
        s.open().write(pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}))
        return s

    @task(cache=True, cache_version="v1")
    def t2(df: pandas.DataFrame) -> int:
        global n_cached_task_calls
        n_cached_task_calls += 1
        return len(df.columns.values)

    @workflow
    def wf() -> int:
        return t2(df=t1())

    assert n_cached_task_calls == 0
    x = wf()
    assert x == 2
    assert n_cached_task_calls == 2
    # Second call does not bump the counter
    x = wf()
    assert x == 2
    assert n_cached_task_calls == 2

def test_fs_sd_compatibility():
    my_schema = FlyteSchema[kwtypes(name=str, age=int)]

    @task
    def my_dataset() -> pd.DataFrame:
        return pd.DataFrame(data={"name": ["Alice"], "age": [5]})

    @task(task_config=Spark())
    def my_spark(df: pyspark.sql.DataFrame) -> my_schema:
        session = flytekit.current_context().spark_session
        new_df = session.createDataFrame([("Bob", 10)], my_schema.column_names())
        return df.union(new_df)

    @task(task_config=Spark())
    def read_spark_df(df: pyspark.sql.DataFrame) -> int:
        return df.count()

    @workflow
    def my_wf() -> int:
        df = my_dataset()
        fs = my_spark(df=df)
        return read_spark_df(df=fs)

    res = my_wf()
    assert res == 2

def test_workflow(sql_server):
    @task
    def my_task(df: pandas.DataFrame) -> int:
        return len(df[df.columns[0]])

    insert_task = SQLAlchemyTask(
        "test",
        query_template="insert into tracks values (5, 'flyte')",
        output_schema_type=None,
        task_config=SQLAlchemyConfig(uri=sql_server),
    )
    sql_task = SQLAlchemyTask(
        "test",
        query_template="select * from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        task_config=SQLAlchemyConfig(uri=sql_server),
    )

    @workflow
    def wf(limit: int) -> int:
        insert_task()
        return my_task(df=sql_task(limit=limit))

    assert wf(limit=10) == 6

def test_ge_simple_task():
    task_object = GreatExpectationsTask(
        name="test1",
        datasource_name="data",
        inputs=kwtypes(dataset=str),
        expectation_suite_name="test.demo",
        data_connector_name="data_example_data_connector",
    )

    # valid data
    result = task_object(dataset="yellow_tripdata_sample_2019-01.csv")
    assert result["success"] is True
    assert result["statistics"]["evaluated_expectations"] == result["statistics"]["successful_expectations"]

    # invalid data
    with pytest.raises(ValidationError):
        invalid_result = task_object(dataset="yellow_tripdata_sample_2019-02.csv")
        assert invalid_result["success"] is False
        assert (
            invalid_result["statistics"]["evaluated_expectations"]
            != invalid_result["statistics"]["successful_expectations"]
        )

    assert task_object.python_interface.inputs == {"dataset": str}

def test_notebook_task_simple():
    nb_name = "nb-spark"
    nb = NotebookTask(
        name="test",
        notebook_path=_get_nb_path(nb_name, abs=False),
        outputs=kwtypes(df=FlyteSchema[kwtypes(name=str, age=int)]),
        task_config=Spark(spark_conf={"x": "y"}),
    )
    n, out, render = nb.execute()
    assert nb.python_interface.outputs.keys() == {"df", "out_nb", "out_rendered_nb"}
    assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
    assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")

def test_raw_shell_task_instantiation(capfd):
    if script_sh_2 is None:
        return
    pst = RawShellTask(
        name="test",
        debug=True,
        inputs=flytekit.kwtypes(env=typing.Dict[str, str], script_args=str, script_file=str),
        output_locs=[
            OutputLocation(
                var="out",
                var_type=FlyteDirectory,
                location="{ctx.working_directory}",
            )
        ],
        script="""
#!/bin/bash
set -uex
cd {ctx.working_directory}
{inputs.export_env}
bash {inputs.script_file} {inputs.script_args}
""",
    )
    pst(script_file=script_sh_2, script_args="first_arg second_arg", env={})
    cap = capfd.readouterr()
    assert "first_arg" in cap.out
    assert "second_arg" in cap.out

def get_raw_shell_task(name: str) -> RawShellTask:
    return RawShellTask(
        name=name,
        debug=True,
        inputs=flytekit.kwtypes(env=typing.Dict[str, str], script_args=str, script_file=str),
        output_locs=[
            OutputLocation(
                var="out",
                var_type=FlyteDirectory,
                location="{ctx.working_directory}",
            )
        ],
        script="""
#!/bin/bash
set -uex
cd {ctx.working_directory}
{inputs.export_env}
bash {inputs.script_file} {inputs.script_args}
""",
    )

def test_ge_with_task():
    task_object = GreatExpectationsTask(
        name="test6",
        datasource_name="data",
        inputs=kwtypes(dataset=str),
        expectation_suite_name="test.demo",
        data_connector_name="data_example_data_connector",
    )

    @task
    def my_task(csv_file: str) -> int:
        df = pd.read_csv(os.path.join("data", csv_file))
        return df.shape[0]

    @workflow
    def valid_wf(dataset: str = "yellow_tripdata_sample_2019-01.csv") -> int:
        task_object(dataset=dataset)
        return my_task(csv_file=dataset)

    @workflow
    def invalid_wf(dataset: str = "yellow_tripdata_sample_2019-02.csv") -> int:
        task_object(dataset=dataset)
        return my_task(csv_file=dataset)

    valid_result = valid_wf()
    assert valid_result == 10000

    with pytest.raises(ValidationError, match=r".*passenger_count -> expect_column_min_to_be_between.*"):
        invalid_wf()

def test_bad_conversion():
    orig = FlyteSchema[kwtypes(my_custom=bool)]
    lt = TypeEngine.to_literal_type(orig)

    # Make a not real column type
    lt.schema.columns[0]._type = 15

    with pytest.raises(ValueError):
        TypeEngine.guess_python_type(lt)

def __init__(
    self,
    name: str,
    query_template: str,
    inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
    task_config: typing.Optional[SQLite3Config] = None,
    output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None,
    **kwargs,
):
    if task_config is None or task_config.uri is None:
        raise ValueError("SQLite DB uri is required.")
    outputs = kwtypes(results=output_schema_type if output_schema_type else FlyteSchema)
    super().__init__(
        name=name,
        task_config=task_config,
        # If you make changes to this task itself, you'll have to bump this image to what the release _will_ be.
        container_image="ghcr.io/flyteorg/flytekit:py38-v0.19.0b7",
        executor_type=SQLite3TaskExecutor,
        task_type=self._SQLITE_TASK_TYPE,
        query_template=query_template,
        inputs=inputs,
        outputs=outputs,
        **kwargs,
    )

def __init__(
    self,
    name: str,
    query_template: str,
    task_config: SQLAlchemyConfig,
    inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
    output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = FlyteSchema,
    container_image: str = SQLAlchemyDefaultImages.default_image(),
    **kwargs,
):
    if output_schema_type:
        outputs = kwtypes(results=output_schema_type)
    else:
        outputs = None
    super().__init__(
        name=name,
        task_config=task_config,
        executor_type=SQLAlchemyTaskExecutor,
        task_type=self._SQLALCHEMY_TASK_TYPE,
        query_template=query_template,
        container_image=container_image,
        inputs=inputs,
        outputs=outputs,
        **kwargs,
    )

def test_notebook_task_simple():
    nb_name = "nb-simple"
    nb = NotebookTask(
        name="test",
        notebook_path=_get_nb_path(nb_name, abs=False),
        inputs=kwtypes(pi=float),
        outputs=kwtypes(square=float),
    )
    sqr, out, render = nb.execute(pi=4)
    assert sqr == 16.0
    assert nb.python_interface.inputs == {"pi": float}
    assert nb.python_interface.outputs.keys() == {"square", "out_nb", "out_rendered_nb"}
    assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
    assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")

def test_sql_task():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        assert n_cached_task_calls == 0
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
        assert n_cached_task_calls == 1
        # The second and third calls hit the cache
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
        assert n_cached_task_calls == 1
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
        assert n_cached_task_calls == 1