def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter):
    """
    This command produces a Flyte backend registrable package of all entities in Flyte.
    For tasks, one pb file is produced for each task, representing one TaskTemplate object.
    For workflows, one pb file is produced for each workflow, representing a WorkflowClosure object. The closure
    object contains the WorkflowTemplate, along with the relevant tasks for that workflow.
    This serialization step will set the name of the tasks to the fully qualified name of the task function.
    """
    if os.path.exists(output) and not force:
        raise click.BadParameter(click.style(f"Output file {output} already exists, specify -f to override.", fg="red"))

    serialization_settings = SerializationSettings(
        image_config=image_config,
        fast_serialization_settings=FastSerializationSettings(
            enabled=fast,
            destination_dir=in_container_source_path,
        ),
        python_interpreter=python_interpreter,
    )

    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

    try:
        serialize_and_package(pkgs, serialization_settings, source, output, fast)
    except NoSerializableEntitiesError:
        click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow")
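# Example invocation of the command above, shown as a comment since this is a click
# entry point. The flags correspond to this function's parameters; the package name,
# image reference, and exact flag spellings are illustrative assumptions, not taken
# from the source above:
#
#   pyflyte --pkgs myapp.workflows package --image ghcr.io/example/myapp:v1 --fast --force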
def test_tensorflow_task():
    @task(
        task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
        cache=True,
        requests=Resources(cpu="1"),
        cache_version="1",
    )
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    assert my_tensorflow_task(x=10, y="hello") == 10
    assert my_tensorflow_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_tensorflow_task.get_custom(settings) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.resources.limits == Resources()
    assert my_tensorflow_task.resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
def test_query_no_inputs_or_outputs():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert len(task_spec.template.interface.inputs) == 0
    assert len(task_spec.template.interface.outputs) == 0

    get_serializable(OrderedDict(), serialization_settings, my_wf)
def test_two(two_sample_inputs):
    my_input = two_sample_inputs[0]
    my_input_2 = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        x = []
        for aa in a:
            x.append(aa.main_product)
        return x

    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [my_input, my_input_2]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
def test_serialization_settings_transport():
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"hello": "blah"},
        image_config=ImageConfig(
            default_image=default_img,
            images=[default_img],
        ),
        flytekit_virtualenv_root="/opt/venv/blah",
        python_interpreter="/opt/venv/bin/python3",
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/opt/blah/blah/blah",
            distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz",
        ),
    )

    tp = serialization_settings.serialized_context
    with_serialized = serialization_settings.with_serialized_context()
    assert serialization_settings.env == {"hello": "blah"}
    assert with_serialized.env
    assert with_serialized.env[SERIALIZED_CONTEXT_ENV_VAR] == tp
    ss = SerializationSettings.from_transport(tp)
    assert ss is not None
    assert ss == serialization_settings
    assert len(tp) == 376
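# A minimal sketch of the runtime half of the transport round-trip exercised above:
# with_serialized_context() packs the settings into a single env var, and a process
# inside the container can rebuild them with from_transport(). Reading os.environ
# directly is an illustrative assumption (flytekit performs this lookup internally),
# as is the helper name recover_settings.
import os


def recover_settings() -> SerializationSettings:
    # from_transport() reverses serialized_context: it decodes the compressed,
    # base64-encoded settings string stored under SERIALIZED_CONTEXT_ENV_VAR.
    return SerializationSettings.from_transport(os.environ[SERIALIZED_CONTEXT_ENV_VAR])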
def test_pod_task_serialized():
    pod = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.template.task_type_version == 2
    assert serialized.template.config["primary_container_name"] == "an undefined container"
    assert serialized.template.k8s_pod.metadata.labels == {"label": "foo"}
    assert serialized.template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert serialized.template.k8s_pod.pod_spec is not None
def test_pytorch_task():
    @task(
        task_config=PyTorch(num_workers=10),
        cache=True,
        cache_version="1",
        requests=Resources(cpu="1"),
    )
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    assert my_pytorch_task(x=10, y="hello") == 10
    assert my_pytorch_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_pytorch_task.get_custom(settings) == {"workers": 10}
    assert my_pytorch_task.resources.limits == Resources()
    assert my_pytorch_task.resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
def test_mpi_task():
    @task(
        task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1),
        requests=Resources(cpu="1"),
        cache=True,
        cache_version="1",
    )
    def my_mpi_task(x: int, y: str) -> int:
        return x

    assert my_mpi_task(x=10, y="hello") == 10
    assert my_mpi_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
    assert my_mpi_task.task_type == "mpi"
def test_register_a_hello_world_wf():
    version = get_version("1")
    ss = SerializationSettings(image_config, project="flytesnacks", domain="development", version=version)
    rr.register_workflow(hello_wf, serialization_settings=ss)

    fetched_wf = rr.fetch_workflow(name=hello_wf.name, version=version)
    rr.execute(fetched_wf, inputs={"a": 5})
def _get_reg_settings():
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    return settings
def test_dc_dyn_directory(folders_and_files_setup):
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    my_input_gcs = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/one"),
            external_data_dir=FlyteDirectory("gs://my-bucket/two"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )
    my_input_gcs_2 = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/three"),
            external_data_dir=FlyteDirectory("gs://my-bucket/four"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        x = []
        for aa in a:
            x.append(aa.apriori_config.external_data_dir)
        return x

    ctx = FlyteContextManager.current_context()
    cb = (
        ctx.new_builder()
        .with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
def test_pod_task_undefined_primary():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    pod_spec = simple_pod_task.get_k8s_pod(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    ).pod_spec

    assert len(pod_spec["containers"]) == 3
    primary_container = pod_spec["containers"][2]
    assert primary_container["name"] == "an undefined container"

    config = simple_pod_task.get_config(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert config["primary_container_name"] == "an undefined container"
def test_spark_template_with_remote():
    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        return 10

    @task
    def my_python_task(a: str) -> int:
        return 10

    remote = FlyteRemote(
        config=Config.for_endpoint(endpoint="localhost", insecure=True),
        default_project="p1",
        default_domain="d1",
    )

    mock_client = MagicMock()
    remote._client = mock_client

    remote.register_task(
        my_spark,
        serialization_settings=SerializationSettings(
            image_config=MagicMock(),
        ),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]
    print(serialized_spec)

    # Check that the serialized spark task has the mainApplicationFile field set.
    assert serialized_spec.template.custom["mainApplicationFile"]
    assert serialized_spec.template.custom["sparkConf"]

    remote.register_task(
        my_python_task,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]

    # Check that the serialized python task has no mainApplicationFile field set by default.
    assert serialized_spec.template.custom is None
def test_py_func_task_get_container():
    def foo(i: int):
        pass

    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])

    settings = SerializationSettings(project="p", domain="d", version="v", image_config=cfg, env={"FOO": "bar"})
    pytask = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"})
    c = pytask.get_container(settings)
    assert c.image == "xyz.com/abc:tag1"
    assert c.env == {"FOO": "bar", "BAZ": "baz"}
def test_fast_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), serialization_settings, simple_pod_task)
    assert serialized.template.k8s_pod.pod_spec["containers"][0]["args"] == [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
def test_serialization():
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=SnowflakeConfig(
            account="snowflake", warehouse="my_warehouse", schema="my_schema", database="my_database"
        ),
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, snowflake_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
    assert "insert overwrite directory" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    assert "snowflake" == task_spec.template.config["account"]
    assert "my_warehouse" == task_spec.template.config["warehouse"]
    assert "my_schema" == task_spec.template.config["schema"]
    assert "my_database" == task_spec.template.config["database"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_map_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    # Test that target is correctly serialized with an updated command
    pod_spec = mapped_task.get_k8s_pod(serialization_settings).pod_spec

    assert len(pod_spec["containers"]) == 1
    assert pod_spec["containers"][0]["args"] == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert {"primary_container_name": "primary"} == mapped_task.get_config(serialization_settings)
def test_serialization():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"]["query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"]["query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def serialize_all(
    pkgs: typing.List[str] = None,
    local_source_root: typing.Optional[str] = None,
    folder: typing.Optional[str] = None,
    mode: typing.Optional[SerializationMode] = None,
    image: typing.Optional[str] = None,
    flytekit_virtualenv_root: typing.Optional[str] = None,
    python_interpreter: typing.Optional[str] = None,
    config_file: typing.Optional[str] = None,
):
    """
    This function will write to the folder specified the following protobuf types ::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with
    the entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files.
    :param mode: Regular vs fast.
    :param image: The fully qualified and versioned default image to use.
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: The Python interpreter to use inside the container.
    :param config_file: The config file to use when resolving the image configuration.
    """
    if not (mode == SerializationMode.DEFAULT or mode == SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")

    serialization_settings = SerializationSettings(
        image_config=ImageConfig.auto(config_file, img_name=image),
        fast_serialization_settings=FastSerializationSettings(
            enabled=mode == SerializationMode.FAST,
            # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
        ),
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
    )

    serialize_to_folder(pkgs, serialization_settings, local_source_root, folder)
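# A minimal sketch of driving serialize_all() directly, matching the signature above.
# The package name, output folder, and image reference are hypothetical values, not
# taken from the source:
serialize_all(
    pkgs=["myapp.workflows"],
    local_source_root=".",
    folder="_pb_output",
    mode=SerializationMode.DEFAULT,
    image="ghcr.io/example/myapp:v1",
)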
def test_serialization():
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
    assert "insert overwrite directory" in task_spec.template.custom["statement"]
    assert "mnist" == task_spec.template.custom["schema"]
    assert "my_catalog" == task_spec.template.custom["catalog"]
    assert "my_wg" == task_spec.template.custom["routingGroup"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_serialization():
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=BigQueryConfig(
            ProjectID="Flyte", Location="Asia", QueryJobConfig=QueryJobConfig(allow_large_results=True)
        ),
        query_template=query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, bigquery_task)
    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in task_spec.template.sql.statement
    assert "@version" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    s = Struct()
    s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_spec.template.custom == json_format.MessageToDict(s)
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.structured_dataset_type is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_aws_batch_task():
    @task(task_config=config)
    def t1(a: int) -> str:
        inc = a + 2
        return str(inc)

    assert t1.task_config is not None
    assert t1.task_config == config
    assert t1.task_type == "aws-batch"
    assert isinstance(t1, PythonFunctionTask)

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert t1.get_custom(settings) == config.to_dict()
    assert t1.get_command(settings) == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}/0",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_aws_batch",
        "task-name",
        "t1",
    ]
def test_dynamic_pod_task():
    dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=dynamic_pod,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    assert isinstance(dynamic_pod_task, PodFunctionTask)
    default_img = Image(name="default", fqn="test", tag="tag")
    pod_spec = dynamic_pod_task.get_k8s_pod(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    ).pod_spec

    assert len(pod_spec["containers"]) == 2
    primary_container = pod_spec["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)
    assert primary_container["resources"] == {
        "requests": {"cpu": "10"},
        "limits": {"ephemeral-storage": "1Gi", "gpu": "2"},
    }

    config = dynamic_pod_task.get_config(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert config["primary_container_name"] == "a container"

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContext.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
        ) as ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5)
            assert len(dynamic_job_spec._nodes) == 5
def test_pod_task_deserialization():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    target = simple_pod_task.get_k8s_pod(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )

    # Test that custom is correctly serialized by deserializing it with the python API client
    response = MagicMock()
    response.data = json.dumps(target.pod_spec)
    deserialized_pod_spec = ApiClient().deserialize(response, V1PodSpec)

    assert deserialized_pod_spec.restart_policy == "OnFailure"
    assert len(deserialized_pod_spec.containers) == 2
    primary_container = deserialized_pod_spec.containers[0]
    assert primary_container.name == "a container"
    assert primary_container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary_container.volume_mounts[0].mount_path == "some/where"
    assert primary_container.volume_mounts[0].name == "volume mount"
    assert primary_container.resources == V1ResourceRequirements(limits={"gpu": "2"}, requests={"cpu": "10"})
    assert primary_container.env == [V1EnvVar(name="FOO", value="bar")]
    assert deserialized_pod_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    )
    assert config["primary_container_name"] == "a container"
def test_fast():
    REQUESTS_GPU = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    LIMITS_GPU = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        primary_container = V1Container(name="flytetask")
        pod_spec = V1PodSpec(containers=[primary_container])
        return Pod(pod_spec=pod_spec, primary_container_name="flytetask")

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=REQUESTS_GPU,
        limits=LIMITS_GPU,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=REQUESTS_GPU, limits=LIMITS_GPU)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(ctx, input_literal_map)
            # print(dynamic_job_spec)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["limits"]["cpu"] == "124M"
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["requests"]["gpu"] == "1"

    assert context_manager.FlyteContextManager.size() == 1
def default_serialization_settings(default_image_config):
    return SerializationSettings(
        project="p", domain="d", version="v", image_config=default_image_config, env={"FOO": "bar"}
    )
def minimal_serialization_settings(default_image_config):
    return SerializationSettings(project="p", domain="d", version="v", image_config=default_image_config)
def register(
    ctx: click.Context,
    project: str,
    domain: str,
    image_config: ImageConfig,
    output: str,
    destination_dir: str,
    service_account: str,
    raw_data_prefix: str,
    version: typing.Optional[str],
    package_or_module: typing.Tuple[str],
):
    """
    see help
    """
    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        cli_logger.debug("No pkgs")
    if pkgs:
        raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")

    if len(package_or_module) == 0:
        display_help_with_error(
            ctx,
            "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
        )

    cli_logger.debug(
        f"Running pyflyte register from {os.getcwd()} "
        f"with images {image_config} "
        f"and image destination folder {destination_dir} "
        f"on {len(package_or_module)} package(s) {package_or_module}"
    )

    # Create and save FlyteRemote
    remote = get_and_save_remote_with_click_context(ctx, project, domain)

    # Todo: add switch for non-fast - skip the zipping and uploading and no FastSerializationSettings
    # Create a zip file containing all the entries.
    detected_root = find_common_root(package_or_module)
    cli_logger.debug(f"Using {detected_root} as root folder for project")
    zip_file = fast_package(detected_root, output)

    # Upload zip file to Admin using FlyteRemote.
    md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
    cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}")

    # Create serialization settings
    # Todo: Rely on default Python interpreter for now, this will break custom Spark containers
    serialization_settings = SerializationSettings(
        project=project,
        domain=domain,
        image_config=image_config,
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir=destination_dir,
            distribution_location=native_url,
        ),
    )

    options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix)

    # Load all the entities
    registerable_entities = load_packages_and_modules(
        serialization_settings, detected_root, list(package_or_module), options
    )
    if len(registerable_entities) == 0:
        display_help_with_error(ctx, "No Flyte entities were detected. Aborting!")
    cli_logger.info(f"Found and serialized {len(registerable_entities)} entities")

    if not version:
        version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix)  # noqa
        cli_logger.info(f"Computed version is {version}")

    # Register using repo code
    repo_register(registerable_entities, project, domain, version, remote.client)
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.workflow import WorkflowTemplate
from flytekit.models.task import TaskTemplate
from flytekit.remote import FlyteLaunchPlan, FlyteTask
from flytekit.remote.interface import TypedInterface
from flytekit.remote.workflow import FlyteWorkflow
from flytekit.tools.translator import gather_dependent_entities, get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


@task
def t1(a: int) -> int:
    return a + 2


@task
def t2(a: int, b: str) -> str:
    return b + str(a)
class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]):
    """
    Please take a look at the comments for :py:class:`flytekit.extend.ExecutableTemplateShimTask` as well. This class
    should be subclassed and a custom Executor provided as a default to this parent class constructor when building a
    new external-container flytekit-only plugin.

    This class provides authors of new task types the basic scaffolding to create task-template based tasks. In order
    to write such a task, authors need to

    * subclass the ``ShimTaskExecutor`` class and override the ``execute_from_model`` function. This function is
      where all the business logic should go. Keep in mind though that you, the plugin author, will not have access
      to anything that's not serialized within the ``TaskTemplate``, which is why you'll also need to
    * subclass this class, and override the ``get_custom`` function to include all the information the executor
      will need to run.
    * Also pass the executor you created as the ``executor_type`` argument of this class's constructor.

    Keep in mind that the total size of the ``TaskTemplate`` still needs to be small, since these will be accessed
    frequently by the Flyte engine.
    """

    SERIALIZE_SETTINGS = SerializationSettings(
        project="PLACEHOLDER_PROJECT",
        domain="LOCAL",
        version="PLACEHOLDER_VERSION",
        env=None,
        image_config=ImageConfig(
            default_image=Image(name="custom_container_task", fqn="flyteorg.io/placeholder", tag="image")
        ),
    )

    def __init__(
        self,
        name: str,
        task_config: TC,
        container_image: str,
        executor_type: Type[ShimTaskExecutor],
        task_resolver: Optional[TaskTemplateResolver] = None,
        task_type="python-task",
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        environment: Optional[Dict[str, str]] = None,
        secret_requests: Optional[List[Secret]] = None,
        **kwargs,
    ):
        """
        :param name: unique name for the task, usually the function's module and name.
        :param task_config: Configuration object for Task. Should be a unique type for that specific Task.
        :param container_image: This is the external container image the task should run at platform-run-time.
        :param executor_type: This is an executor which will actually provide the business logic.
        :param task_resolver: Custom resolver - if you don't make one, use the default task template resolver.
        :param task_type: String task type to be associated with this Task.
        :param requests: custom resource request settings.
        :param limits: custom resource limit settings.
        :param environment: Environment variables you want the task to have when run.
        :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets
            will be mounted based on the configuration in the Secret and available through the SecretManager using
            the name of the secret as the group. Ideally the secret keys should also be semi-descriptive. The key
            values will be available from runtime, if the backend is configured to provide secrets and if secrets
            are available in the configured secrets store.
            Possible options for secret stores are

            - `Vault <https://www.vaultproject.io/>`__
            - `Confidant <https://lyft.github.io/confidant/>`__
            - `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
            - `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
        """
        sec_ctx = None
        if secret_requests:
            for s in secret_requests:
                if not isinstance(s, Secret):
                    raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
            sec_ctx = SecurityContext(secrets=secret_requests)
        super().__init__(
            tt=None,
            executor_type=executor_type,
            task_type=task_type,
            name=name,
            task_config=task_config,
            security_ctx=sec_ctx,
            **kwargs,
        )
        self._resources = ResourceSpec(
            requests=requests if requests else Resources(), limits=limits if limits else Resources()
        )
        self._environment = environment
        self._container_image = container_image
        self._task_resolver = task_resolver or default_task_template_resolver

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Overriding base implementation to raise an error, force task author to implement
        raise NotImplementedError

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        # Overriding base implementation but not doing anything. Technically this should be the task config,
        # but the IDL limitation that the value also has to be a string is very limiting.
        # Recommend putting information you need in the config into custom instead, because when serializing
        # the custom field, we jsonify custom and then place it into a protobuf struct. This config field
        # just gets put into a Dict[str, str]
        return {}

    @property
    def resources(self) -> ResourceSpec:
        return self._resources

    @property
    def task_resolver(self) -> TaskTemplateResolver:
        return self._task_resolver

    @property
    def task_template(self) -> Optional[_task_model.TaskTemplate]:
        """
        Override the base class implementation to serialize on first call.
        """
        return self._task_template or self.serialize_to_model(
            settings=PythonCustomizedContainerTask.SERIALIZE_SETTINGS
        )

    @property
    def container_image(self) -> str:
        return self._container_image

    def get_command(self, settings: SerializationSettings) -> List[str]:
        container_args = [
            "pyflyte-execute",
            "--inputs",
            "{{.input}}",
            "--output-prefix",
            "{{.outputPrefix}}",
            "--raw-output-data-prefix",
            "{{.rawOutputDataPrefix}}",
            "--resolver",
            self.task_resolver.location,
            "--",
            *self.task_resolver.loader_args(settings, self),
        ]

        return container_args

    def get_container(self, settings: SerializationSettings) -> _task_model.Container:
        env = {**settings.env, **self.environment} if self.environment else settings.env
        return _get_container_definition(
            image=self.container_image,
            command=[],
            args=self.get_command(settings=settings),
            data_loading_config=None,
            environment=env,
            storage_request=self.resources.requests.storage,
            cpu_request=self.resources.requests.cpu,
            gpu_request=self.resources.requests.gpu,
            memory_request=self.resources.requests.mem,
            storage_limit=self.resources.limits.storage,
            cpu_limit=self.resources.limits.cpu,
            gpu_limit=self.resources.limits.gpu,
            memory_limit=self.resources.limits.mem,
        )

    def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate:
        # This doesn't get called from translator unfortunately. Will need to move the translator to use the model
        # objects directly first.
        # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
        # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
        #  customized-container tasks have a familiar thing to work with.
        obj = _task_model.TaskTemplate(
            identifier_models.Identifier(
                identifier_models.ResourceType.TASK, settings.project, settings.domain, self.name, settings.version
            ),
            self.task_type,
            self.metadata.to_taskmetadata_model(),
            self.interface,
            self.get_custom(settings),
            container=self.get_container(settings),
            config=self.get_config(settings),
        )
        self._task_template = obj
        return obj
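# A minimal sketch of the subclassing workflow described in the class docstring above.
# MyConfig, MyExecutor, MyTask, the image reference, and the "table_name" field are all
# hypothetical; the executor and constructor APIs are used as documented above, but
# treat this as an illustration rather than a drop-in plugin.
from dataclasses import dataclass


@dataclass
class MyConfig(object):
    table_name: str  # hypothetical task configuration


class MyExecutor(ShimTaskExecutor["MyTask"]):
    def execute_from_model(self, tt: _task_model.TaskTemplate, **kwargs) -> Any:
        # All business logic lives here, and it may rely only on what was
        # serialized into the TaskTemplate's custom field by get_custom().
        return f"processed {tt.custom['table_name']}"


class MyTask(PythonCustomizedContainerTask[MyConfig]):
    def __init__(self, name: str, task_config: MyConfig, **kwargs):
        super().__init__(
            name=name,
            task_config=task_config,
            container_image="flyteorg.io/placeholder:latest",  # hypothetical image
            executor_type=MyExecutor,
            task_type="my-task",
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Everything the executor needs at run time must be returned here, and it
        # should stay small since the TaskTemplate is fetched frequently by the engine.
        return {"table_name": self.task_config.table_name}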