def test_resource_limits_override():
    """Check that ``limits`` overrides on a map-task node are serialized and requests stay empty."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        limits = Resources(cpu="2", mem="200", ephemeral_storage="1Gi")
        return map_task(t1)(a=a).with_overrides(limits=limits)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    nodes = get_serializable(OrderedDict(), settings, my_wf).template.nodes
    assert len(nodes) == 1

    overridden = nodes[0].task_node.overrides.resources
    assert overridden.requests == []

    resource_name = _resources_models.ResourceName
    expected_limits = [
        _resources_models.ResourceEntry(resource_name.CPU, "2"),
        _resources_models.ResourceEntry(resource_name.MEMORY, "200"),
        _resources_models.ResourceEntry(resource_name.EPHEMERAL_STORAGE, "1Gi"),
    ]
    assert overridden.limits == expected_limits
def test_pod_task_serialized():
    """Serialize a pod-configured task and check its version, config and pod metadata."""
    pod_config = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    template = get_serializable(OrderedDict(), settings, simple_pod_task).template

    assert template.task_type_version == 2
    assert template.config["primary_container_name"] == "an undefined container"

    k8s_pod = template.k8s_pod
    assert k8s_pod.metadata.labels == {"label": "foo"}
    assert k8s_pod.metadata.annotations == {"anno": "bar"}
    assert k8s_pod.pod_spec is not None
def test_mpi_task():
    """Run an MPI-configured task locally and check its custom payload and task type."""
    mpi_config = MPIJob(num_workers=10, num_launcher_replicas=10, slots=1)

    @task(
        task_config=mpi_config,
        requests=Resources(cpu="1"),
        cache=True,
        cache_version="1",
    )
    def my_mpi_task(x: int, y: str) -> int:
        return x

    # Local execution simply returns the first input.
    assert my_mpi_task(x=10, y="hello") == 10
    assert my_mpi_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_custom = {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
    assert my_mpi_task.get_custom(settings) == expected_custom
    assert my_mpi_task.task_type == "mpi"
def test_dont_convert_remotes():
    """Dispatch a dynamic task with a remote directory input and check the bound blob URI is unchanged."""

    @task
    def t1(in1: FlyteDirectory):
        print(in1)

    @dynamic
    def dyn(in1: FlyteDirectory):
        t1(in1=in1)

    remote_dir = FlyteDirectory("s3://anything")

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    base_ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(
        base_ctx.with_serialization_settings(settings)
    ) as ser_ctx:
        task_exec_state = ser_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with context_manager.FlyteContextManager.with_context(
            ser_ctx.with_execution_state(task_exec_state)
        ) as exec_ctx:
            blob_type = BlobType("", dimensionality=BlobType.BlobDimensionality.MULTIPART)
            literal = TypeEngine.to_literal(exec_ctx, remote_dir, FlyteDirectory, blob_type)
            inputs = LiteralMap(literals={"in1": literal})

            spec = dyn.dispatch_execute(exec_ctx, inputs)
            assert spec.nodes[0].inputs[0].binding.scalar.blob.uri == "s3://anything"
def serialization_settings():
    """Build serialization settings with a single default image and no environment."""
    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    return flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )
def _get_reg_settings():
    """Return serialization settings with a single default image and ``FOO=baz`` in the environment."""
    image = Image(name="default", fqn="test", tag="tag")
    return SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
def test_dc_dyn_directory(folders_and_files_setup):
    """Dispatch a dynamic task over dataclass inputs and check the returned directory URIs."""
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    def build_input(static_uri, external_uri):
        # Both inputs share the same product file and proxy settings; only the directories differ.
        return MyInput(
            main_product=FlyteFile(folders_and_files_setup[0]),
            apriori_config=MyAprioriConfiguration(
                static_data_dir=FlyteDirectory(static_uri),
                external_data_dir=FlyteDirectory(external_uri),
            ),
            proxy_config=proxy_c,
            proxy_params=proxy_p,
        )

    my_input_gcs = build_input("gs://my-bucket/one", "gs://my-bucket/two")
    my_input_gcs_2 = build_input("gs://my-bucket/three", "gs://my-bucket/four")

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        return [item.apriori_config.external_data_dir for item in a]

    current = FlyteContextManager.current_context()
    settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    builder = current.new_builder().with_serialization_settings(settings)
    builder = builder.with_execution_state(
        current.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
    )

    with FlyteContextManager.with_context(builder) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx,
            d={"a": [my_input_gcs, my_input_gcs_2]},
            type_hints={"a": List[MyInput]},
        )
        result = dt1.dispatch_execute(ctx, input_literal_map)
        returned = result.literals["o0"].collection.literals
        assert returned[0].scalar.blob.uri == "gs://my-bucket/two"
        assert returned[1].scalar.blob.uri == "gs://my-bucket/four"
def test_pod_task_undefined_primary():
    """Check the pod spec and config when the primary container name is not in the pod spec."""
    pod_config = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    image = Image(name="default", fqn="test", tag="tag")

    def fresh_settings():
        # A new settings object is built for each call, as separate instances are passed below.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=image, images=[image]),
        )

    containers = simple_pod_task.get_k8s_pod(fresh_settings()).pod_spec["containers"]
    assert len(containers) == 3
    assert containers[2]["name"] == "an undefined container"

    task_config = simple_pod_task.get_config(fresh_settings())
    assert task_config["primary_container_name"] == "an undefined container"
def test_lps(resource_type):
    """Exercise a reference entity locally, in a compilation context, and through serialization."""
    ref_entity = get_reference_entity(
        resource_type,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    # Calling the reference outside of any compilation context is an error.
    with pytest.raises(Exception) as local_err:
        ref_entity()
    assert "You must mock this out" in f"{local_err}"

    outer_ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(outer_ctx.with_new_compilation_state()):
        with pytest.raises(Exception) as missing_err:
            ref_entity()
        assert "Input was not specified" in f"{missing_err}"

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    template = get_serializable(OrderedDict(), settings, wf1).template
    assert len(template.interface.inputs) == 2
    assert len(template.interface.outputs) == 0
    assert len(template.nodes) == 1

    node = template.nodes[0]
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        assert node.workflow_node.launchplan_ref.project == "proj"
        assert node.workflow_node.launchplan_ref.name == "app.other.flyte_entity"
    else:
        assert node.task_node.reference_id.project == "proj"
        assert node.task_node.reference_id.name == "app.other.flyte_entity"
def test_module_loading(mock_entities, mock_entities_2):
    """Build two package trees on disk, check common-root detection, then load one package."""
    shared_entities = []
    mock_entities.entities = shared_entities
    mock_entities_2.entities = shared_entities

    with tempfile.TemporaryDirectory() as tmp_dir:
        # First tree: top/middle/bottom
        top_level = os.path.join(tmp_dir, "top")
        middle_level = os.path.join(top_level, "middle")
        bottom_level = os.path.join(middle_level, "bottom")
        os.makedirs(bottom_level)

        # Second tree: top2/middle
        top_level_2 = os.path.join(tmp_dir, "top2")
        middle_level_2 = os.path.join(top_level_2, "middle")
        os.makedirs(middle_level_2)

        # Every level of the first tree gets an __init__.py and an a.py
        for package_dir in (top_level, middle_level, bottom_level):
            for file_name in ("__init__.py", "a.py"):
                pathlib.Path(package_dir, file_name).touch()
        pathlib.Path(bottom_level, "a.py").write_text(task_text)

        pathlib.Path(middle_level_2, "__init__.py").touch()

        # top2 has no __init__.py yet, so the two paths have different roots
        with pytest.raises(ValueError):
            find_common_root([middle_level_2, bottom_level])

        # Adding the missing init file makes the roots agree
        pathlib.Path(top_level_2, "__init__.py").touch()

        root = find_common_root([middle_level_2, bottom_level])
        assert pathlib.Path(root).resolve() == pathlib.Path(tmp_dir).resolve()

        # Now load them
        settings = flytekit.configuration.SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
        )

        loaded = load_packages_and_modules(settings, pathlib.Path(root), [bottom_level])
        assert len(loaded) == 1
def test_fast_pod_task_serialization():
    """Check the primary container args of a pod task serialized with fast serialization enabled."""
    pod_config = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )

    expected_args = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]

    serialized = get_serializable(OrderedDict(), settings, simple_pod_task)
    first_container = serialized.template.k8s_pod.pod_spec["containers"][0]
    assert first_container["args"] == expected_args
def test_py_func_task_get_container():
    """Check the image and merged environment of the container built for a Python function task."""

    def foo(i: int):
        pass

    default_image = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    extra_image = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    settings = SerializationSettings(
        project="p",
        domain="d",
        version="v",
        image_config=ImageConfig(default_image=default_image, images=[default_image, extra_image]),
        env={"FOO": "bar"},
    )

    container = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"}).get_container(settings)
    assert container.image == "xyz.com/abc:tag1"
    assert container.env == {"FOO": "bar", "BAZ": "baz"}
def test_serialization():
    """Serialize a workflow built from two raw container tasks and check node count and images."""
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    # Bound to ``add`` locally so the builtin ``sum`` is not shadowed; the task name is still "sum".
    add = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return add(x=square(val=val1), y=square(val=val2))

    image = Image(name="default", fqn="test", tag="tag")
    settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    wf_spec = get_serializable(OrderedDict(), settings, raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
    assert len(wf_spec.template.nodes) == 3

    square_spec = get_serializable(OrderedDict(), settings, square)
    assert square_spec.template.container.image == "alpine"

    add_spec = get_serializable(OrderedDict(), settings, add)
    assert add_spec.template.container.image == "alpine"
def test_map_pod_task_serialization():
    """Check the pod spec args and config produced for a map task wrapping a pod task."""
    pod_config = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod_config, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]

    # Test that target is correctly serialized with an updated command
    containers = mapped_task.get_k8s_pod(settings).pod_spec["containers"]
    assert len(containers) == 1
    assert containers[0]["args"] == expected_args
    assert {"primary_container_name": "primary"} == mapped_task.get_config(settings)
def test_wf1_with_fast_dynamic():
    """Dispatch a dynamic task under fast serialization and check the generated task command."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        v = my_subwf(a=a)
        return v

    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/User/flyte/workflows",
        distribution_location="s3://my-s3-bucket/fast/123",
    )
    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=fast_settings,
    )

    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as ser_ctx:
        task_exec_state = ser_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ser_ctx.with_execution_state(task_exec_state)) as exec_ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(exec_ctx, {"a": 5})
            dynamic_job_spec = my_subwf.dispatch_execute(exec_ctx, input_literal_map)

            assert len(dynamic_job_spec._nodes) == 5
            assert len(dynamic_job_spec.tasks) == 1

            command_line = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert command_line.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )

    # Both pushed contexts must have been popped again by now.
    assert context_manager.FlyteContextManager.size() == 1
def test_serialization():
    """Serialize a Snowflake task and a workflow returning its result; check SQL, config and bindings."""
    snowflake_config = SnowflakeConfig(
        account="snowflake",
        warehouse="my_warehouse",
        schema="my_schema",
        database="my_database",
    )
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=snowflake_config,
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_template = get_serializable(OrderedDict(), settings, snowflake_task).template

    assert "{{ .rawOutputDataPrefix" in task_template.sql.statement
    assert "insert overwrite directory" in task_template.sql.statement
    assert task_template.sql.dialect == task_template.sql.Dialect.ANSI
    assert "snowflake" == task_template.config["account"]
    assert "my_warehouse" == task_template.config["warehouse"]
    assert "my_schema" == task_template.config["schema"]
    assert "my_database" == task_template.config["database"]
    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None

    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
def test_serialization():
    """Serialize a Hive task and a workflow returning its result; check the query and bindings."""
    hive_query = """ set engine=tez; insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet -- will be unique per retry select * from blah where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}' """
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template=hive_query,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_template = get_serializable(OrderedDict(), settings, hive_task).template
    assert "{{ .rawOutputDataPrefix" in task_template.custom["query"]["query"]
    assert "insert overwrite directory" in task_template.custom["query"]["query"]
    assert len(task_template.interface.inputs) == 2
    assert len(task_template.interface.outputs) == 1

    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None

    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
def test_sql_command():
    """Check the trailing resolver arguments of the serialized ``not_tk`` task's container."""
    image = Image(name="default", fqn="test", tag="tag")
    settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_tail = [
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekitplugins.sqlalchemy.task.SQLAlchemyTaskExecutor",
    ]
    container_args = get_serializable(OrderedDict(), settings, not_tk).template.container.args
    assert container_args[-5:] == expected_tail
def test_lp_with_output():
    """Run a workflow against a patched reference launch plan, then check its serialized node."""
    ref_lp = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = ref_lp(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    @patch(ref_lp)
    def inner_test(ref_mock):
        # The reference launch plan is replaced by a mock for local execution.
        ref_mock.return_value = (False, 30)
        result = wf1()
        assert result == "q: False r: 30"
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    nodes = get_serializable(OrderedDict(), settings, wf1).template.nodes
    launchplan_ref = nodes[1].workflow_node.launchplan_ref
    assert launchplan_ref.project == "proj"
    assert launchplan_ref.name == "app.other.flyte_entity"
def test_serialization_images():
    """
    Check that image placeholders in ``container_image`` are resolved during serialization.

    The image configuration is loaded from ``configs/images.config`` next to this file, with
    ``FLYTE_INTERNAL_IMAGE`` set in the environment. The variable is set through
    ``mock.patch.dict`` so that it is restored when the test finishes instead of leaking into
    every test that runs afterwards in the same process.
    """
    from unittest import mock

    @task(container_image="{{.image.xyz.fqn}}:{{.image.xyz.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.abc.fqn}}:{{.image.xyz.version}}")
    def t2():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.xyz.version}}")
    def t5(a: int) -> int:
        return a

    @task(container_image="{{.image.xyz_123.fqn}}:{{.image.xyz_123.version}}")
    def t6(a: int) -> int:
        return a

    config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")

    # Everything that might read the variable stays inside the patched scope.
    with mock.patch.dict(os.environ, {"FLYTE_INTERNAL_IMAGE": "docker.io/default:version"}):
        imgs = ImageConfig.auto(config_file=config_path)
        rs = flytekit.configuration.SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=imgs,
        )

        t1_spec = get_serializable(OrderedDict(), rs, t1)
        assert t1_spec.template.container.image == "docker.io/xyz:latest"
        t1_spec.to_flyte_idl()

        t2_spec = get_serializable(OrderedDict(), rs, t2)
        assert t2_spec.template.container.image == "docker.io/abc:latest"

        t4_spec = get_serializable(OrderedDict(), rs, t4)
        assert t4_spec.template.container.image == "docker.io/org/myimage:latest"

        t5_spec = get_serializable(OrderedDict(), rs, t5)
        assert t5_spec.template.container.image == "docker.io/org/myimage:latest"

        t6_spec = get_serializable(OrderedDict(), rs, t6)
        assert t6_spec.template.container.image == "docker.io/xyz_123:v1"
def serialize_all(
    pkgs: typing.Optional[typing.List[str]] = None,
    local_source_root: typing.Optional[str] = None,
    folder: typing.Optional[str] = None,
    mode: typing.Optional[SerializationMode] = None,
    image: typing.Optional[str] = None,
    flytekit_virtualenv_root: typing.Optional[str] = None,
    python_interpreter: typing.Optional[str] = None,
    config_file: typing.Optional[str] = None,
):
    """
    This function will write to the folder specified the following protobuf types ::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast. Must be ``SerializationMode.DEFAULT`` or ``SerializationMode.FAST``; the default
        of ``None`` is rejected.
    :param image: The fully qualified and versioned default image to use
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Passed through to the serialization settings as ``python_interpreter``.
    :param config_file: Configuration file handed to ``ImageConfig.auto`` together with ``image``.
    :raises AssertionError: If ``mode`` is not one of the recognized serialization modes.
    """
    if mode not in (SerializationMode.DEFAULT, SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")

    serialization_settings = SerializationSettings(
        image_config=ImageConfig.auto(config_file, img_name=image),
        fast_serialization_settings=FastSerializationSettings(
            enabled=mode == SerializationMode.FAST,
            # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
        ),
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
    )

    serialize_to_folder(pkgs, serialization_settings, local_source_root, folder)
def test_ref():
    """Check the identifier of the reference task ``ref_t1`` and the spec it serializes to."""
    task_id = ref_t1.id
    assert task_id.project == "flytesnacks"
    assert task_id.domain == "development"
    assert task_id.name == "recipes.aaa.simple.join_strings"
    assert task_id.version == "553018f39e519bdb2597b652639c30ce16b99c79"

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    spec = get_serializable(OrderedDict(), settings, ref_t1)
    assert isinstance(spec, ReferenceSpec)

    template = spec.template
    assert isinstance(template, ReferenceTemplate)
    assert template.id == ref_t1.id
    assert template.resource_type == _identifier_model.ResourceType.TASK
def test_interruptible_override(interruptible):
    """Check that an ``interruptible`` node override ends up in the serialized node metadata."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: str) -> str:
        return t1(a=a).with_overrides(interruptible=interruptible)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    nodes = get_serializable(OrderedDict(), settings, my_wf).template.nodes
    assert len(nodes) == 1
    assert nodes[0].metadata.interruptible == interruptible
def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> str:
    """
    Resolve the container image string to use for registration.

    When ``img`` is given, every placeholder found by ``_IMAGE_REPLACE_REGEX`` is substituted using the
    image looked up by name in ``cfg``. Each match is unpacked as ``(replace_group, name, attr)`` and
    ``attr`` selects what replaces ``replace_group``:

    * ``"version"`` - the named image's tag, falling back to the default image's tag when it has none
    * ``"fqn"`` - the named image's ``fqn``
    * ``""`` - the named image's ``full`` value

    An ``img`` without any placeholder is returned unchanged. When ``img`` is ``None`` or empty, the
    default image from ``cfg`` is returned as ``<fqn>:<tag>``.

    :param img: Configured image
    :param cfg: Registration configuration
    :return: The image string with all placeholders replaced, or the default image
    :raises AssertionError: If a match is malformed, names an image missing from ``cfg``, uses an
        unsupported attribute, or needs the default image's tag when no default image is configured.
    :raises ValueError: If ``img`` is not given and ``cfg`` has no default image.
    """
    if img:
        # ``findall`` always returns a list, so an empty list is the only "no placeholder" case.
        matches = _IMAGE_REPLACE_REGEX.findall(img)
        if not matches:
            return img

        for m in matches:
            if len(m) < 3:
                # Braces are quadrupled in the f-string so the message shows the literal ``{{...}}`` syntax.
                raise AssertionError(
                    "Image specification should be of the form <fqn>:<tag> OR <fqn>:{{.image.default.version}} OR "
                    f"{{{{.image.xyz.fqn}}}}:{{{{.image.xyz.version}}}} OR {{{{.image.xyz}}}} - Received {m}"
                )
            replace_group, name, attr = m
            if name is None or name == "":
                raise AssertionError(f"Image format is incorrect {m}")

            img_cfg = cfg.find_image(name)
            if img_cfg is None:
                raise AssertionError(f"Image Config with name {name} not found in the configuration")

            if attr == "version":
                if img_cfg.tag is not None:
                    tag = img_cfg.tag
                elif cfg.default_image is not None:
                    tag = cfg.default_image.tag
                else:
                    # Previously this path failed with an AttributeError on ``None``.
                    raise AssertionError(
                        f"Image Config with name {name} has no tag and no default image is configured"
                    )
                img = img.replace(replace_group, tag)
            elif attr == "fqn":
                img = img.replace(replace_group, img_cfg.fqn)
            elif attr == "":
                img = img.replace(replace_group, img_cfg.full)
            else:
                raise AssertionError(f"Only fqn and version are supported replacements, {attr} is not supported")
        return img

    if cfg.default_image is None:
        raise ValueError("An image is required for PythonAutoContainer tasks")
    return f"{cfg.default_image.fqn}:{cfg.default_image.tag}"
def test_ref_sub_wf():
    """Check that a reference workflow can be called while compiling but cannot be serialized."""
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    # Calling the reference outside of any compilation context is an error.
    with pytest.raises(Exception) as local_err:
        ref_entity()
    assert "You must mock this out" in f"{local_err}"

    outer_ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(outer_ctx.with_new_compilation_state()):
        with pytest.raises(Exception) as missing_err:
            ref_entity()
        assert "Input was not specified" in f"{missing_err}"

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    # Subworkflow as references don't work (probably ever). The reason is because we'd need to make a network call
    # to admin to get the structure of the subworkflow and the whole point of reference entities is that there
    # is no network call.
    with pytest.raises(Exception, match="currently unsupported"):
        get_serializable(OrderedDict(), settings, wf1)
def test_serialization():
    """Serialize an Athena task and a workflow returning its result; check custom fields and bindings."""
    athena_query = """ insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet select * from blah where ds = '{{ .Inputs.ds }}' """
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template=athena_query,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_template = get_serializable(OrderedDict(), settings, athena_task).template
    assert "{{ .rawOutputDataPrefix" in task_template.custom["statement"]
    assert "insert overwrite directory" in task_template.custom["statement"]
    assert "mnist" == task_template.custom["schema"]
    assert "my_catalog" == task_template.custom["catalog"]
    assert "my_wg" == task_template.custom["routingGroup"]
    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.schema is not None

    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
def test_ref_dynamic_task():
    """Check that compiling a dynamic task which calls a reference task raises 'currently unsupported'."""

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="sample.reference.task",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: int) -> str:
        ...

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(ref_t1(a=i))
        return s

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as ser_ctx:
        task_exec_state = ser_ctx.execution_state.with_params(
            mode=context_manager.ExecutionState.Mode.TASK_EXECUTION
        )
        with manager.with_context(ser_ctx.with_execution_state(task_exec_state)) as exec_ctx:
            with pytest.raises(Exception, match="currently unsupported"):
                my_subwf.compile_into_workflow(exec_ctx, my_subwf._task_function, a=5)
def test_serialization():
    """Serialize a BigQuery task and a workflow returning its result; check SQL, custom and bindings."""
    bigquery_config = BigQueryConfig(
        ProjectID="Flyte",
        Location="Asia",
        QueryJobConfig=QueryJobConfig(allow_large_results=True),
    )
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=bigquery_config,
        query_template=query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_template = get_serializable(OrderedDict(), settings, bigquery_task).template

    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in task_template.sql.statement
    assert "@version" in task_template.sql.statement
    assert task_template.sql.dialect == task_template.sql.Dialect.ANSI

    expected_custom = Struct()
    expected_custom.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_template.custom == json_format.MessageToDict(expected_custom)

    assert len(task_template.interface.inputs) == 1
    assert len(task_template.interface.outputs) == 1

    wf_template = get_serializable(OrderedDict(), settings, my_wf).template
    assert wf_template.interface.outputs["o0"].type.structured_dataset_type is not None

    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
def test_dynamic():
    """Dispatch a dynamic task that calls ``ft`` and check its command is not rewritten for fast execute."""

    @dynamic
    def my_subwf(a: int) -> typing.List[int]:
        s = []
        for i in range(a):
            s.append(ft(a=i))
        return s

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )

    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as ser_ctx:
        task_exec_state = ser_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ser_ctx.with_execution_state(task_exec_state)) as exec_ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(exec_ctx, {"a": 2})

            # Test that it works
            dynamic_job_spec = my_subwf.dispatch_execute(exec_ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 2
            assert len(dynamic_job_spec.tasks) == 1

            only_task = dynamic_job_spec.tasks[0]
            assert only_task.id == ft.id

            # Test that the fast execute stuff does not get applied because the commands of tasks fetched from
            # Admin should never change.
            command_line = " ".join(only_task.container.args)
            assert not command_line.startswith("pyflyte-fast-execute")
def test_more_stuff(mock_client):
    """
    Exercise ``FlyteRemote._upload_file`` and ``FlyteRemote._version_from_hash``.

    Uploading a directory must raise ``ValueError``; uploading a regular file is run against a
    mocked signed URL that points into a temporary directory. The version hash must be non-empty,
    identical for identical inputs, and different once an extra argument is supplied.

    :param mock_client: Mock standing in for the remote's client; its ``get_upload_signed_url``
        return value is configured here.
    """
    r = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain")

    # Can't upload a folder
    with pytest.raises(ValueError):
        with tempfile.TemporaryDirectory() as tmp_dir:
            r._upload_file(pathlib.Path(tmp_dir))

    # Test that this copies the file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        mm = MagicMock()
        mm.signed_url = os.path.join(tmp_dir, "tmp_file")
        mock_client.return_value.get_upload_signed_url.return_value = mm

        r._upload_file(pathlib.Path(__file__))

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
    )

    # gives a thing
    computed_v = r._version_from_hash(b"", serialization_settings)
    assert len(computed_v) > 0

    # gives the same thing; compare against the FIRST result, the old self-comparison could never fail
    computed_v2 = r._version_from_hash(b"", serialization_settings)
    assert computed_v2 == computed_v

    # should give a different thing
    computed_v3 = r._version_from_hash(b"", serialization_settings, "hi")
    assert computed_v2 != computed_v3