def test_serialization_settings_transport():
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"hello": "blah"},
        image_config=ImageConfig(
            default_image=default_img,
            images=[default_img],
        ),
        flytekit_virtualenv_root="/opt/venv/blah",
        python_interpreter="/opt/venv/bin/python3",
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/opt/blah/blah/blah",
            distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz",
        ),
    )

    tp = serialization_settings.serialized_context
    with_serialized = serialization_settings.with_serialized_context()
    assert serialization_settings.env == {"hello": "blah"}
    assert with_serialized.env
    assert with_serialized.env[SERIALIZED_CONTEXT_ENV_VAR] == tp
    ss = SerializationSettings.from_transport(tp)
    assert ss is not None
    assert ss == serialization_settings
    assert len(tp) == 376
def test_pod_task_serialized():
    pod = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.template.task_type_version == 2
    assert serialized.template.config["primary_container_name"] == "an undefined container"
    assert serialized.template.k8s_pod.metadata.labels == {"label": "foo"}
    assert serialized.template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert serialized.template.k8s_pod.pod_spec is not None
def test_tensorflow_task():
    @task(
        task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
        cache=True,
        requests=Resources(cpu="1"),
        cache_version="1",
    )
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    assert my_tensorflow_task(x=10, y="hello") == 10
    assert my_tensorflow_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_tensorflow_task.get_custom(settings) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.resources.limits == Resources()
    assert my_tensorflow_task.resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
def test_mpi_task():
    @task(
        task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1),
        requests=Resources(cpu="1"),
        cache=True,
        cache_version="1",
    )
    def my_mpi_task(x: int, y: str) -> int:
        return x

    assert my_mpi_task(x=10, y="hello") == 10
    assert my_mpi_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
    assert my_mpi_task.task_type == "mpi"
def test_query_no_inputs_or_outputs():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert len(task_spec.template.interface.inputs) == 0
    assert len(task_spec.template.interface.outputs) == 0

    get_serializable(OrderedDict(), serialization_settings, my_wf)
def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter):
    """
    This command produces a Flyte backend registrable package of all entities in Flyte.
    For tasks, one pb file is produced for each task, representing one TaskTemplate object.
    For workflows, one pb file is produced for each workflow, representing a WorkflowClosure object. The closure
    object contains the WorkflowTemplate, along with the relevant tasks for that workflow.
    This serialization step will set the name of the tasks to the fully qualified name of the task function.
    """
    if os.path.exists(output) and not force:
        raise click.BadParameter(click.style(f"Output file {output} already exists, specify -f to override.", fg="red"))

    serialization_settings = SerializationSettings(
        image_config=image_config,
        fast_serialization_settings=FastSerializationSettings(
            enabled=fast,
            destination_dir=in_container_source_path,
        ),
        python_interpreter=python_interpreter,
    )

    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

    try:
        serialize_and_package(pkgs, serialization_settings, source, output, fast)
    except NoSerializableEntitiesError:
        click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow")
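# A minimal programmatic sketch of the settings object the `package` command above assembles. This is a hedged
# illustration, not part of the CLI: the image coordinates, package list, and output path are hypothetical, and the
# serialize_and_package call is left commented out because it would scan real packages.
img = Image(name="default", fqn="myrepo/myimage", tag="v1")  # hypothetical image
illustrative_settings = SerializationSettings(
    image_config=ImageConfig(default_image=img, images=[img]),
    fast_serialization_settings=FastSerializationSettings(
        enabled=True,
        destination_dir="/root",  # in-container path where fast-registered code would be unpacked
    ),
    python_interpreter="/usr/bin/python3",
)
# serialize_and_package(["myapp.workflows"], illustrative_settings, ".", "flyte-package.tgz", True)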
def test_two(two_sample_inputs):
    my_input = two_sample_inputs[0]
    my_input_2 = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        x = []
        for aa in a:
            x.append(aa.main_product)
        return x

    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [my_input, my_input_2]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
def test_pytorch_task():
    @task(
        task_config=PyTorch(num_workers=10),
        cache=True,
        cache_version="1",
        requests=Resources(cpu="1"),
    )
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    assert my_pytorch_task(x=10, y="hello") == 10
    assert my_pytorch_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert my_pytorch_task.get_custom(settings) == {"workers": 10}
    assert my_pytorch_task.resources.limits == Resources()
    assert my_pytorch_task.resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
def test_register_a_hello_world_wf():
    version = get_version("1")
    ss = SerializationSettings(image_config, project="flytesnacks", domain="development", version=version)
    rr.register_workflow(hello_wf, serialization_settings=ss)

    fetched_wf = rr.fetch_workflow(name=hello_wf.name, version=version)
    rr.execute(fetched_wf, inputs={"a": 5})
def _get_reg_settings():
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    return settings
def test_dc_dyn_directory(folders_and_files_setup):
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    my_input_gcs = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/one"),
            external_data_dir=FlyteDirectory("gs://my-bucket/two"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )
    my_input_gcs_2 = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/three"),
            external_data_dir=FlyteDirectory("gs://my-bucket/four"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        x = []
        for aa in a:
            x.append(aa.apriori_config.external_data_dir)
        return x

    ctx = FlyteContextManager.current_context()
    cb = (
        ctx.new_builder()
        .with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
def test_pod_task_undefined_primary():
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name="an undefined container")

    @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    pod_spec = simple_pod_task.get_k8s_pod(settings).pod_spec
    assert len(pod_spec["containers"]) == 3

    primary_container = pod_spec["containers"][2]
    assert primary_container["name"] == "an undefined container"

    config = simple_pod_task.get_config(settings)
    assert config["primary_container_name"] == "an undefined container"
def test_spark_template_with_remote():
    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        return 10

    @task
    def my_python_task(a: str) -> int:
        return 10

    remote = FlyteRemote(
        config=Config.for_endpoint(endpoint="localhost", insecure=True),
        default_project="p1",
        default_domain="d1",
    )

    mock_client = MagicMock()
    remote._client = mock_client

    remote.register_task(
        my_spark,
        serialization_settings=SerializationSettings(
            image_config=MagicMock(),
        ),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]
    print(serialized_spec)

    # Check that the serialized Spark task has the mainApplicationFile field set.
    assert serialized_spec.template.custom["mainApplicationFile"]
    assert serialized_spec.template.custom["sparkConf"]

    remote.register_task(
        my_python_task,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]

    # Check that the serialized Python task has no mainApplicationFile field set by default.
    assert serialized_spec.template.custom is None
def test_fast_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), serialization_settings, simple_pod_task)
    assert serialized.template.k8s_pod.pod_spec["containers"][0]["args"] == [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
def test_py_func_task_get_container():
    def foo(i: int):
        pass

    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])
    settings = SerializationSettings(project="p", domain="d", version="v", image_config=cfg, env={"FOO": "bar"})

    pytask = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"})
    c = pytask.get_container(settings)
    assert c.image == "xyz.com/abc:tag1"
    assert c.env == {"FOO": "bar", "BAZ": "baz"}
def test_serialization():
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=SnowflakeConfig(
            account="snowflake", warehouse="my_warehouse", schema="my_schema", database="my_database"
        ),
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, snowflake_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
    assert "insert overwrite directory" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    assert "snowflake" == task_spec.template.config["account"]
    assert "my_warehouse" == task_spec.template.config["warehouse"]
    assert "my_schema" == task_spec.template.config["schema"]
    assert "my_database" == task_spec.template.config["database"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_map_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    # Test that the target is correctly serialized with an updated command
    pod_spec = mapped_task.get_k8s_pod(serialization_settings).pod_spec

    assert len(pod_spec["containers"]) == 1
    assert pod_spec["containers"][0]["args"] == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert {"primary_container_name": "primary"} == mapped_task.get_config(serialization_settings)
def test_serialization():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, hive_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"]["query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"]["query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def serialize_all(
    pkgs: typing.List[str] = None,
    local_source_root: typing.Optional[str] = None,
    folder: typing.Optional[str] = None,
    mode: typing.Optional[SerializationMode] = None,
    image: typing.Optional[str] = None,
    flytekit_virtualenv_root: typing.Optional[str] = None,
    python_interpreter: typing.Optional[str] = None,
    config_file: typing.Optional[str] = None,
):
    """
    This function will write the following protobuf types to the specified folder::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case)::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with
    the entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files.
    :param mode: Regular vs fast.
    :param image: The fully qualified and versioned default image to use.
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: The Python interpreter to use within the container.
    :param config_file: The config file used to resolve the image configuration.
    """
    if not (mode == SerializationMode.DEFAULT or mode == SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")

    serialization_settings = SerializationSettings(
        image_config=ImageConfig.auto(config_file, img_name=image),
        fast_serialization_settings=FastSerializationSettings(
            enabled=mode == SerializationMode.FAST,
            # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
        ),
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
    )
    serialize_to_folder(pkgs, serialization_settings, local_source_root, folder)
def test_serialization():
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
    assert "insert overwrite directory" in task_spec.template.custom["statement"]
    assert "mnist" == task_spec.template.custom["schema"]
    assert "my_catalog" == task_spec.template.custom["catalog"]
    assert "my_wg" == task_spec.template.custom["routingGroup"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_serialization():
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=BigQueryConfig(
            ProjectID="Flyte", Location="Asia", QueryJobConfig=QueryJobConfig(allow_large_results=True)
        ),
        query_template=query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, bigquery_task)

    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in task_spec.template.sql.statement
    assert "@version" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    s = Struct()
    s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_spec.template.custom == json_format.MessageToDict(s)
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.structured_dataset_type is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
def test_aws_batch_task():
    @task(task_config=config)
    def t1(a: int) -> str:
        inc = a + 2
        return str(inc)

    assert t1.task_config is not None
    assert t1.task_config == config
    assert t1.task_type == "aws-batch"
    assert isinstance(t1, PythonFunctionTask)

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    assert t1.get_custom(settings) == config.to_dict()
    assert t1.get_command(settings) == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}/0",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_aws_batch",
        "task-name",
        "t1",
    ]
class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]):
    """
    Please take a look at the comments for :py:class:`flytekit.extend.ExecutableTemplateShimTask` as well.
    This class should be subclassed and a custom Executor provided as a default to this parent class constructor
    when building a new external-container flytekit-only plugin.

    This class provides authors of new task types the basic scaffolding to create task-template based tasks. In order
    to write such a task, authors need to

    * subclass the ``ShimTaskExecutor`` class and override the ``execute_from_model`` function. This function is
      where all the business logic should go. Keep in mind though that you, the plugin author, will not have access
      to anything that's not serialized within the ``TaskTemplate``, which is why you'll also need to
    * subclass this class, and override the ``get_custom`` function to include all the information the executor
      will need to run.
    * Also pass the executor you created as the ``executor_type`` argument of this class's constructor.

    Keep in mind that the total size of the ``TaskTemplate`` still needs to be small, since these will be accessed
    frequently by the Flyte engine.
    """

    SERIALIZE_SETTINGS = SerializationSettings(
        project="PLACEHOLDER_PROJECT",
        domain="LOCAL",
        version="PLACEHOLDER_VERSION",
        env=None,
        image_config=ImageConfig(
            default_image=Image(name="custom_container_task", fqn="flyteorg.io/placeholder", tag="image")
        ),
    )

    def __init__(
        self,
        name: str,
        task_config: TC,
        container_image: str,
        executor_type: Type[ShimTaskExecutor],
        task_resolver: Optional[TaskTemplateResolver] = None,
        task_type="python-task",
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        environment: Optional[Dict[str, str]] = None,
        secret_requests: Optional[List[Secret]] = None,
        **kwargs,
    ):
        """
        :param name: unique name for the task, usually the function's module and name.
        :param task_config: Configuration object for Task. Should be a unique type for that specific Task.
        :param container_image: This is the external container image the task should run at platform-run-time.
        :param executor_type: This is the executor class that will actually provide the business logic.
        :param task_resolver: Custom resolver - if you don't make one, use the default task template resolver.
        :param task_type: String task type to be associated with this Task.
        :param requests: custom resource request settings.
        :param limits: custom resource limit settings.
        :param environment: Environment variables you want the task to have when run.
        :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets
            will be mounted based on the configuration in the Secret and available through the SecretManager, using
            the name of the secret as the group. Ideally the secret keys should also be semi-descriptive. The key
            values will be available at runtime, if the backend is configured to provide secrets and if secrets are
            available in the configured secrets store. Possible options for secret stores are

            - `Vault <https://www.vaultproject.io/>`__
            - `Confidant <https://lyft.github.io/confidant/>`__
            - `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
            - `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
        """
        sec_ctx = None
        if secret_requests:
            for s in secret_requests:
                if not isinstance(s, Secret):
                    raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
            sec_ctx = SecurityContext(secrets=secret_requests)
        super().__init__(
            tt=None,
            executor_type=executor_type,
            task_type=task_type,
            name=name,
            task_config=task_config,
            security_ctx=sec_ctx,
            **kwargs,
        )
        self._resources = ResourceSpec(
            requests=requests if requests else Resources(), limits=limits if limits else Resources()
        )
        self._environment = environment
        self._container_image = container_image
        self._task_resolver = task_resolver or default_task_template_resolver

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Overriding base implementation to raise an error, force task author to implement
        raise NotImplementedError

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        # Overriding base implementation but not doing anything. Technically this should be the task config,
        # but the IDL limitation that the value also has to be a string is very limiting.
        # Recommend putting information you need in the config into custom instead, because when serializing
        # the custom field, we jsonify custom and then place it into a protobuf struct. This config field
        # just gets put into a Dict[str, str].
        return {}

    @property
    def resources(self) -> ResourceSpec:
        return self._resources

    @property
    def task_resolver(self) -> TaskTemplateResolver:
        return self._task_resolver

    @property
    def task_template(self) -> Optional[_task_model.TaskTemplate]:
        """
        Override the base class implementation to serialize on first call.
        """
        return self._task_template or self.serialize_to_model(
            settings=PythonCustomizedContainerTask.SERIALIZE_SETTINGS
        )

    @property
    def container_image(self) -> str:
        return self._container_image

    def get_command(self, settings: SerializationSettings) -> List[str]:
        container_args = [
            "pyflyte-execute",
            "--inputs",
            "{{.input}}",
            "--output-prefix",
            "{{.outputPrefix}}",
            "--raw-output-data-prefix",
            "{{.rawOutputDataPrefix}}",
            "--resolver",
            self.task_resolver.location,
            "--",
            *self.task_resolver.loader_args(settings, self),
        ]
        return container_args

    def get_container(self, settings: SerializationSettings) -> _task_model.Container:
        env = {**settings.env, **self.environment} if self.environment else settings.env
        return _get_container_definition(
            image=self.container_image,
            command=[],
            args=self.get_command(settings=settings),
            data_loading_config=None,
            environment=env,
            storage_request=self.resources.requests.storage,
            cpu_request=self.resources.requests.cpu,
            gpu_request=self.resources.requests.gpu,
            memory_request=self.resources.requests.mem,
            storage_limit=self.resources.limits.storage,
            cpu_limit=self.resources.limits.cpu,
            gpu_limit=self.resources.limits.gpu,
            memory_limit=self.resources.limits.mem,
        )

    def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate:
        # This doesn't get called from the translator unfortunately. Will need to move the translator to use the
        # model objects directly first.
        # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
        # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
        #   customized-container tasks have a familiar thing to work with.
        obj = _task_model.TaskTemplate(
            identifier_models.Identifier(
                identifier_models.ResourceType.TASK, settings.project, settings.domain, self.name, settings.version
            ),
            self.task_type,
            self.metadata.to_taskmetadata_model(),
            self.interface,
            self.get_custom(settings),
            container=self.get_container(settings),
            config=self.get_config(settings),
        )
        self._task_template = obj
        return obj
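# A hedged sketch of the subclassing pattern described in the class docstring above, using only names already
# present in this module (ShimTaskExecutor, SerializationSettings, _task_model). QueryConfig, the "query" custom
# field, the container image, and the task type string are all hypothetical.
from dataclasses import dataclass


@dataclass
class QueryConfig(object):
    query: str


class QueryExecutor(ShimTaskExecutor):
    def execute_from_model(self, tt: _task_model.TaskTemplate, **kwargs) -> Any:
        # All business logic goes here; only what was serialized into the
        # TaskTemplate (here, tt.custom) is available at run time.
        return tt.custom["query"]


class QueryShimTask(PythonCustomizedContainerTask[QueryConfig]):
    def __init__(self, name: str, task_config: QueryConfig, **kwargs):
        super().__init__(
            name=name,
            task_config=task_config,
            container_image="ghcr.io/example/query-runner:latest",  # hypothetical external image
            executor_type=QueryExecutor,
            task_type="query-shim-task",  # hypothetical task type
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Serialize everything the executor will need; keep the template small.
        return {"query": self.task_config.query}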
def test_fast():
    REQUESTS_GPU = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    LIMITS_GPU = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        primary_container = V1Container(name="flytetask")
        pod_spec = V1PodSpec(containers=[primary_container])
        return Pod(pod_spec=pod_spec, primary_container_name="flytetask")

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=REQUESTS_GPU,
        limits=LIMITS_GPU,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=REQUESTS_GPU, limits=LIMITS_GPU)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(ctx, input_literal_map)
            # print(dynamic_job_spec)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["limits"]["cpu"] == "124M"
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["requests"]["gpu"] == "1"

    assert context_manager.FlyteContextManager.size() == 1
def default_serialization_settings(default_image_config):
    return SerializationSettings(
        project="p", domain="d", version="v", image_config=default_image_config, env={"FOO": "bar"}
    )
def minimal_serialization_settings(default_image_config):
    return SerializationSettings(project="p", domain="d", version="v", image_config=default_image_config)
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
) -> task_models.TaskSpec:
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )
    if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
        # In the case of dynamic tasks, we want to pass the serialization context, so that they can reconstruct
        # the state from the serialization context. This is passed through an environment variable that is read
        # during dynamic serialization.
        settings = settings.with_serialized_context()

    container = entity.get_container(settings)
    # This pod will be incorrect when doing fast serialize
    pod = entity.get_k8s_pod(settings)

    if settings.should_fast_serialize():
        # This handles container tasks.
        if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask)):
            # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks.
            # Specifically, tasks that rely on user code defined in the container. This should be encapsulated
            # by the auto container parent class.
            container._args = prefix_with_fast_execute(settings, container.args)

        # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
        # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
        # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
        elif pod:
            if isinstance(entity, MapPythonTask):
                entity.set_command_prefix(get_command_prefix_for_fast_execute(settings))
                pod = entity.get_k8s_pod(settings)
            else:
                entity.set_command_fn(_fast_serialize_command_fn(settings, entity))
                pod = entity.get_k8s_pod(settings)
                entity.reset_command_fn()

    tt = task_models.TaskTemplate(
        id=task_id,
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=container,
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=pod,
        sql=entity.get_sql(settings),
    )
    if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
        entity.reset_command_fn()
    return task_models.TaskSpec(template=tt)
def register(
    ctx: click.Context,
    project: str,
    domain: str,
    image_config: ImageConfig,
    output: str,
    destination_dir: str,
    service_account: str,
    raw_data_prefix: str,
    version: typing.Optional[str],
    package_or_module: typing.Tuple[str],
):
    """
    see help
    """
    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        cli_logger.debug("No pkgs")
    if pkgs:
        raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")

    if len(package_or_module) == 0:
        display_help_with_error(
            ctx,
            "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
        )

    cli_logger.debug(
        f"Running pyflyte register from {os.getcwd()} "
        f"with images {image_config} "
        f"and image destination folder {destination_dir} "
        f"on {len(package_or_module)} package(s) {package_or_module}"
    )

    # Create and save FlyteRemote.
    remote = get_and_save_remote_with_click_context(ctx, project, domain)

    # Todo: add switch for non-fast - skip the zipping and uploading and no fast serialization settings
    # Create a zip file containing all the entries.
    detected_root = find_common_root(package_or_module)
    cli_logger.debug(f"Using {detected_root} as root folder for project")
    zip_file = fast_package(detected_root, output)

    # Upload zip file to Admin using FlyteRemote.
    md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
    cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}")

    # Create serialization settings
    # Todo: Rely on default Python interpreter for now, this will break custom Spark containers
    serialization_settings = SerializationSettings(
        project=project,
        domain=domain,
        image_config=image_config,
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir=destination_dir,
            distribution_location=native_url,
        ),
    )

    options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix)

    # Load all the entities
    registerable_entities = load_packages_and_modules(
        serialization_settings, detected_root, list(package_or_module), options
    )
    if len(registerable_entities) == 0:
        display_help_with_error(ctx, "No Flyte entities were detected. Aborting!")
    cli_logger.info(f"Found and serialized {len(registerable_entities)} entities")

    if not version:
        version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix)  # noqa
        cli_logger.info(f"Computed version is {version}")

    # Register using repo code
    repo_register(registerable_entities, project, domain, version, remote.client)
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.workflow import WorkflowTemplate
from flytekit.models.task import TaskTemplate
from flytekit.remote import FlyteLaunchPlan, FlyteTask
from flytekit.remote.interface import TypedInterface
from flytekit.remote.workflow import FlyteWorkflow
from flytekit.tools.translator import gather_dependent_entities, get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


@task
def t1(a: int) -> int:
    return a + 2


@task
def t2(a: int, b: str) -> str:
    return b + str(a)
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    :param raw_output_data_prefix: Prefix under which raw offloaded data for this execution is written.
    :param checkpoint_path: Destination the checkpointer should write to.
    :param prev_checkpoint: Source checkpoint from a previous attempt, if any.
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
        task were to run, it should set fast serialize to true and use these values in FastSerializationSettings.
    :param dynamic_dest_dir: See above.
    :return:
    """
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")
    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())

    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx
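# A hedged sketch of the serialized-context round trip that setup_execution relies on above, mirroring the
# transport test at the top of this section; the concrete field values are illustrative.
_img = Image(name="default", fqn="test", tag="tag")
_ss = SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    image_config=ImageConfig(default_image=_img, images=[_img]),
)
_transport = _ss.serialized_context  # the string carried into the container via SERIALIZED_CONTEXT_ENV_VAR
_rebuilt = SerializationSettings.from_transport(_transport)
assert _rebuilt == _ss  # dynamic tasks and setup_execution reconstruct settings exactly this way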