示例#1
0
def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter):
    """
    This command produces a Flyte backend registrable package of all entities in Flyte.
    For tasks, one pb file is produced for each task, representing one TaskTemplate object.
    For workflows, one pb file is produced for each workflow, representing a WorkflowClosure object.  The closure
        object contains the WorkflowTemplate, along with the relevant tasks for that workflow.
        This serialization step will set the name of the tasks to the fully qualified name of the task function.
    """
    # Refuse to clobber an existing output file unless the caller explicitly asked for it.
    output_already_present = os.path.exists(output)
    if output_already_present and not force:
        styled_error = click.style(f"Output file {output} already exists, specify -f to override.", fg="red")
        raise click.BadParameter(styled_error)

    fast_settings = FastSerializationSettings(enabled=fast, destination_dir=in_container_source_path)
    serialization_settings = SerializationSettings(
        image_config=image_config,
        fast_serialization_settings=fast_settings,
        python_interpreter=python_interpreter,
    )

    # The packages to scan are stored on the click context object under CTX_PACKAGES.
    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

    try:
        serialize_and_package(pkgs, serialization_settings, source, output, fast)
    except NoSerializableEntitiesError:
        # Finding nothing to serialize is reported as a warning rather than propagated.
        click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow")
示例#2
0
def test_tensorflow_task():
    """Check local execution, custom config, resources and task type of a TfJob-configured task."""
    tf_job = TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1)

    @task(task_config=tf_job, cache=True, cache_version="1", requests=Resources(cpu="1"))
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    # Calling the task locally simply returns its integer argument.
    assert my_tensorflow_task(x=10, y="hello") == 10
    assert my_tensorflow_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_custom = {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.get_custom(serialization_settings) == expected_custom

    # Only requests were supplied to the decorator, so limits are expected to equal an empty Resources().
    task_resources = my_tensorflow_task.resources
    assert task_resources.limits == Resources()
    assert task_resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
示例#3
0
def test_query_no_inputs_or_outputs():
    """A Hive task declared with no inputs and no output schema serializes to an empty interface."""
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    image = Image(name="default", fqn="test", tag="tag")
    img_cfg = ImageConfig(default_image=image, images=[image])
    settings = SerializationSettings(project="proj", domain="dom", version="123", image_config=img_cfg, env={})

    task_spec = get_serializable(OrderedDict(), settings, hive_task)
    interface = task_spec.template.interface
    assert len(interface.inputs) == 0
    assert len(interface.outputs) == 0

    # Serializing a workflow that calls the task is exercised only for "does not raise".
    get_serializable(OrderedDict(), settings, my_wf)
示例#4
0
def test_two(two_sample_inputs):
    """Dispatch a dynamic task over two MyInput values and expect two literals in the ``o0`` output."""
    first_input = two_sample_inputs[0]
    second_input = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        return [aa.main_product for aa in a]

    settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    settings_builder = FlyteContextManager.current_context().with_serialization_settings(settings)

    # Two nested contexts: the outer one carries the serialization settings, the inner one switches the
    # execution state to TASK_EXECUTION mode.
    with FlyteContextManager.with_context(settings_builder) as outer_ctx:
        task_exec_state = outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with FlyteContextManager.with_context(outer_ctx.with_execution_state(task_exec_state)) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx,
                d={"a": [first_input, second_input]},
                type_hints={"a": List[MyInput]},
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            produced = dynamic_job_spec.literals["o0"].collection.literals
            assert len(produced) == 2
示例#5
0
def test_serialization_settings_transport():
    """Round-trip SerializationSettings through its serialized-context transport string."""
    image = Image(name="default", fqn="test", tag="tag")
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/opt/blah/blah/blah",
        distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz",
    )
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"hello": "blah"},
        image_config=ImageConfig(default_image=image, images=[image]),
        flytekit_virtualenv_root="/opt/venv/blah",
        python_interpreter="/opt/venv/bin/python3",
        fast_serialization_settings=fast_settings,
    )

    transport = serialization_settings.serialized_context
    with_serialized = serialization_settings.with_serialized_context()

    # The source settings keep their original env; the derived copy carries the serialized context.
    assert serialization_settings.env == {"hello": "blah"}
    assert with_serialized.env
    assert with_serialized.env[SERIALIZED_CONTEXT_ENV_VAR] == transport

    restored = SerializationSettings.from_transport(transport)
    assert restored is not None
    assert restored == serialization_settings

    # Pins the exact length of the transport string produced for this particular configuration.
    assert len(transport) == 376
示例#6
0
def test_pod_task_serialized():
    """Serialize a pod task and check the task type version, config and k8s pod metadata."""
    primary_name = "an undefined container"
    pod = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name=primary_name,
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    image = Image(name="default", fqn="test", tag="tag")
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    template = get_serializable(OrderedDict(), ssettings, simple_pod_task).template
    assert template.task_type_version == 2
    assert template.config["primary_container_name"] == primary_name

    # Labels and annotations given to Pod(...) are expected on the serialized pod metadata.
    pod_metadata = template.k8s_pod.metadata
    assert pod_metadata.labels == {"label": "foo"}
    assert pod_metadata.annotations == {"anno": "bar"}
    assert template.k8s_pod.pod_spec is not None
示例#7
0
def test_pytorch_task():
    """Check local execution, custom config, resources and task type of a PyTorch-configured task."""
    requested = Resources(cpu="1")

    @task(task_config=PyTorch(num_workers=10), cache=True, cache_version="1", requests=requested)
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    # Calling the task locally simply returns its integer argument.
    assert my_pytorch_task(x=10, y="hello") == 10
    assert my_pytorch_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    assert my_pytorch_task.get_custom(serialization_settings) == {"workers": 10}

    task_resources = my_pytorch_task.resources
    assert task_resources.limits == Resources()
    assert task_resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
示例#8
0
def test_mpi_task():
    """Check local execution, custom config and task type of an MPIJob-configured task."""
    mpi_job = MPIJob(num_workers=10, num_launcher_replicas=10, slots=1)

    @task(task_config=mpi_job, requests=Resources(cpu="1"), cache=True, cache_version="1")
    def my_mpi_task(x: int, y: str) -> int:
        return x

    # Calling the task locally simply returns its integer argument.
    assert my_mpi_task(x=10, y="hello") == 10
    assert my_mpi_task.task_config is not None

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    expected_custom = {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
    assert my_mpi_task.get_custom(serialization_settings) == expected_custom
    assert my_mpi_task.task_type == "mpi"
示例#9
0
def test_register_a_hello_world_wf():
    """Register ``hello_wf`` under a fresh version, fetch it back by name/version and execute it."""
    version = get_version("1")
    settings = SerializationSettings(
        image_config,
        project="flytesnacks",
        domain="development",
        version=version,
    )
    rr.register_workflow(hello_wf, serialization_settings=settings)

    # Execute the fetched copy rather than the local object; the result is not inspected here.
    remote_wf = rr.fetch_workflow(name=hello_wf.name, version=version)
    rr.execute(remote_wf, inputs={"a": 5})
示例#10
0
def _get_reg_settings():
    """Build the SerializationSettings fixture: a single "default" image plus a FOO=baz environment."""
    image = Image(name="default", fqn="test", tag="tag")
    img_cfg = ImageConfig(default_image=image, images=[image])
    return SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=img_cfg,
    )
示例#11
0
def test_dc_dyn_directory(folders_and_files_setup):
    """A dynamic task returning FlyteDirectory fields of its inputs must keep their remote URIs."""
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    def _make_input(static_uri: str, external_uri: str):
        # Both inputs share the proxy objects and main product path; only the directory URIs vary.
        return MyInput(
            main_product=FlyteFile(folders_and_files_setup[0]),
            apriori_config=MyAprioriConfiguration(
                static_data_dir=FlyteDirectory(static_uri),
                external_data_dir=FlyteDirectory(external_uri),
            ),
            proxy_config=proxy_c,
            proxy_params=proxy_p,
        )

    my_input_gcs = _make_input("gs://my-bucket/one", "gs://my-bucket/two")
    my_input_gcs_2 = _make_input("gs://my-bucket/three", "gs://my-bucket/four")

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        return [aa.apriori_config.external_data_dir for aa in a]

    ctx = FlyteContextManager.current_context()
    settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    # One context builder carrying both the serialization settings and the TASK_EXECUTION state.
    cb = ctx.new_builder().with_serialization_settings(settings)
    cb = cb.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))

    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx,
            d={"a": [my_input_gcs, my_input_gcs_2]},
            type_hints={"a": List[MyInput]},
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
示例#12
0
def test_pod_task_undefined_primary():
    """A primary container name missing from the pod spec ends up as a third container in the pod."""
    primary_name = "an undefined container"
    pod = Pod(pod_spec=get_pod_spec(), primary_container_name=primary_name)

    @task(
        task_config=pod,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")

    def _fresh_settings():
        # A new settings object is built for every call, as separate instances are used per query.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )

    pod_spec = simple_pod_task.get_k8s_pod(_fresh_settings()).pod_spec
    containers = pod_spec["containers"]
    assert len(containers) == 3

    # The container at index 2 is expected to carry the primary container name.
    assert containers[2]["name"] == primary_name

    config = simple_pod_task.get_config(_fresh_settings())
    assert config["primary_container_name"] == primary_name
示例#13
0
def test_spark_template_with_remote():
    """Register a Spark task and a plain task through a mocked client and compare their ``custom``."""

    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        return 10

    @task
    def my_python_task(a: str) -> int:
        return 10

    endpoint_config = Config.for_endpoint(endpoint="localhost", insecure=True)
    remote = FlyteRemote(config=endpoint_config, default_project="p1", default_domain="d1")

    # Swap in a mock client so the spec handed to create_task can be inspected afterwards.
    mock_client = MagicMock()
    remote._client = mock_client

    remote.register_task(
        my_spark,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]

    print(serialized_spec)
    # The serialized Spark task must have the mainApplicationFile and sparkConf fields set.
    spark_custom = serialized_spec.template.custom
    assert spark_custom["mainApplicationFile"]
    assert spark_custom["sparkConf"]

    remote.register_task(
        my_python_task,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]

    # A plain Python task has no custom section at all, hence no mainApplicationFile by default.
    assert serialized_spec.template.custom is None
def test_py_func_task_get_container():
    """The container of a PythonFunctionTask uses the default image and merges both env mappings."""

    def foo(i: int):
        pass

    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    settings = SerializationSettings(
        project="p",
        domain="d",
        version="v",
        image_config=ImageConfig(default_image=default_img, images=[default_img, other_img]),
        env={"FOO": "bar"},
    )

    pytask = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"})
    container = pytask.get_container(settings)

    # Image string is "<fqn>:<tag>" of the default image; env holds settings env plus task environment.
    assert container.image == "xyz.com/abc:tag1"
    assert container.env == {"FOO": "bar", "BAZ": "baz"}
示例#15
0
def test_fast_pod_task_serialization():
    """With fast serialization enabled the primary container args start with pyflyte-fast-execute."""
    primary = V1Container(name="primary")
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[primary]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), settings, simple_pod_task)

    # Fast-execute wrapper first, then the regular pyflyte-execute invocation after the first "--".
    expected_args = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    first_container = serialized.template.k8s_pod.pod_spec["containers"][0]
    assert first_container["args"] == expected_args
示例#16
0
def test_serialization():
    """Serialize a Snowflake task and a workflow wrapping it; check SQL, config, interface, bindings."""
    snowflake_config = SnowflakeConfig(
        account="snowflake",
        warehouse="my_warehouse",
        schema="my_schema",
        database="my_database",
    )
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=snowflake_config,
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), settings, snowflake_task)

    assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
    assert "insert overwrite directory" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI

    # Every SnowflakeConfig field is expected verbatim in the serialized task config.
    expected_config = {
        "account": "snowflake",
        "warehouse": "my_warehouse",
        "schema": "my_schema",
        "database": "my_database",
    }
    for key, value in expected_config.items():
        assert value == task_spec.template.config[key]

    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), settings, my_wf)
    wf_template = admin_workflow_spec.template
    assert wf_template.interface.outputs["o0"].type.schema is not None

    # The single workflow output "o0" is bound to the "results" variable of node "n0".
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
示例#17
0
def test_map_pod_task_serialization():
    """A map task built over a pod task serializes its container with the pyflyte-map-execute command."""
    primary = V1Container(name="primary")
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[primary]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    # Test that target is correctly serialized with an updated command
    containers = mapped_task.get_k8s_pod(settings).pod_spec["containers"]
    expected_args = [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]

    assert len(containers) == 1
    assert containers[0]["args"] == expected_args

    expected_config = {"primary_container_name": "primary"}
    assert expected_config == mapped_task.get_config(settings)
示例#18
0
def test_serialization():
    """Serialize a Hive task and a workflow wrapping it; check the query, interface and bindings."""
    # NOTE(review): the keyword used below is `config`, whereas the HiveTask built in
    # test_query_no_inputs_or_outputs is given `task_config` - confirm which one HiveTask consumes.
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), settings, hive_task)

    # The query text is read from the nested custom["query"]["query"] entry of the template.
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"]["query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"]["query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), settings, my_wf)
    wf_template = admin_workflow_spec.template
    assert wf_template.interface.outputs["o0"].type.schema is not None

    # The single workflow output "o0" is bound to the "results" variable of node "n0".
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
示例#19
0
def serialize_all(
    pkgs: typing.Optional[typing.List[str]] = None,
    local_source_root: typing.Optional[str] = None,
    folder: typing.Optional[str] = None,
    mode: typing.Optional[SerializationMode] = None,
    image: typing.Optional[str] = None,
    flytekit_virtualenv_root: typing.Optional[str] = None,
    python_interpreter: typing.Optional[str] = None,
    config_file: typing.Optional[str] = None,
):
    """
    This function will write to the folder specified the following protobuf types ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast. Although the default is ``None``, one of ``SerializationMode.DEFAULT`` or
        ``SerializationMode.FAST`` must be passed explicitly.
    :param image: The fully qualified and versioned default image to use
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Forwarded unchanged to ``SerializationSettings`` as ``python_interpreter``.
    :param config_file: Handed to ``ImageConfig.auto`` together with ``image`` to build the image configuration.
    :raises AssertionError: If ``mode`` is neither ``SerializationMode.DEFAULT`` nor ``SerializationMode.FAST``.
    """

    # Reject anything that is not a known mode up front (this includes the default of None).
    if mode not in (SerializationMode.DEFAULT, SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")

    serialization_settings = SerializationSettings(
        image_config=ImageConfig.auto(config_file, img_name=image),
        fast_serialization_settings=FastSerializationSettings(
            enabled=mode == SerializationMode.FAST,
            # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
        ),
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
    )

    serialize_to_folder(pkgs, serialization_settings, local_source_root, folder)
示例#20
0
def test_serialization():
    """Serialize an Athena task and a workflow wrapping it; check custom fields, interface, bindings."""
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), settings, athena_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
    assert "insert overwrite directory" in task_spec.template.custom["statement"]

    # database / catalog / workgroup surface in custom under "schema" / "catalog" / "routingGroup".
    expected_custom_fields = {"schema": "mnist", "catalog": "my_catalog", "routingGroup": "my_wg"}
    for key, value in expected_custom_fields.items():
        assert value == task_spec.template.custom[key]

    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), settings, my_wf)
    wf_template = admin_workflow_spec.template
    assert wf_template.interface.outputs["o0"].type.schema is not None

    # The single workflow output "o0" is bound to the "results" variable of node "n0".
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
示例#21
0
def test_serialization():
    """Serialize a BigQuery task and a workflow wrapping it; check SQL, custom, interface, bindings."""
    bigquery_config = BigQueryConfig(
        ProjectID="Flyte",
        Location="Asia",
        QueryJobConfig=QueryJobConfig(allow_large_results=True),
    )
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=bigquery_config,
        query_template=query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    image = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=image, images=[image]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), settings, bigquery_task)

    sql = task_spec.template.sql
    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in sql.statement
    assert "@version" in sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI

    # Build the expected custom dict by passing the same values through a protobuf Struct.
    expected_struct = Struct()
    expected_struct.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_spec.template.custom == json_format.MessageToDict(expected_struct)

    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), settings, my_wf)
    wf_template = admin_workflow_spec.template
    assert wf_template.interface.outputs["o0"].type.structured_dataset_type is not None

    # The single workflow output "o0" is bound to the "results" variable of node "n0".
    first_output = wf_template.outputs[0]
    assert first_output.var == "o0"
    assert first_output.binding.promise.node_id == "n0"
    assert first_output.binding.promise.var == "results"
示例#22
0
def test_aws_batch_task():
    """A task using the module-level AWS Batch ``config`` exposes the expected type, custom, command."""

    @task(task_config=config)
    def t1(a: int) -> str:
        inc = a + 2
        return str(inc)

    assert t1.task_config is not None
    assert t1.task_config == config
    assert t1.task_type == "aws-batch"
    assert isinstance(t1, PythonFunctionTask)

    image = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=image, images=[image]),
    )
    assert t1.get_custom(serialization_settings) == config.to_dict()

    # Note the "/0" suffix on the output prefix in the expected command.
    expected_command = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}/0",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_aws_batch",
        "task-name",
        "t1",
    ]
    assert t1.get_command(serialization_settings) == expected_command
示例#23
0
def test_dynamic_pod_task():
    """A @dynamic task with a Pod config yields a pod spec, a config, and compiles into five nodes."""
    dynamic_pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container")

    @task
    def t1(a: int) -> int:
        return a + 10

    @dynamic(
        task_config=dynamic_pod,
        requests=Resources(cpu="10"),
        limits=Resources(ephemeral_storage="1Gi", gpu="2"),
        environment={"FOO": "bar"},
    )
    def dynamic_pod_task(a: int) -> List[int]:
        return [t1(a=i) for i in range(a)]

    assert isinstance(dynamic_pod_task, PodFunctionTask)
    default_img = Image(name="default", fqn="test", tag="tag")

    def _project_settings():
        # A new settings object is built for every call, as separate instances are used per query.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )

    pod_spec = dynamic_pod_task.get_k8s_pod(_project_settings()).pod_spec
    assert len(pod_spec["containers"]) == 2
    primary_container = pod_spec["containers"][0]
    assert isinstance(dynamic_pod_task.task_config, Pod)

    # Requests/limits passed to @dynamic are expected on the first container's resources.
    expected_resources = {
        "requests": {"cpu": "10"},
        "limits": {"ephemeral-storage": "1Gi", "gpu": "2"},
    }
    assert primary_container["resources"] == expected_resources

    config = dynamic_pod_task.get_config(_project_settings())
    assert config["primary_container_name"] == "a container"

    compile_settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    settings_builder = context_manager.FlyteContext.current_context().with_serialization_settings(compile_settings)

    # Two nested contexts: serialization settings outside, TASK_EXECUTION execution state inside.
    with context_manager.FlyteContextManager.with_context(settings_builder) as outer_ctx:
        task_exec_state = outer_ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with context_manager.FlyteContextManager.with_context(
            outer_ctx.with_execution_state(task_exec_state)
        ) as ctx:
            dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5)
            # One node per t1 call made by the loop over range(5).
            assert len(dynamic_job_spec._nodes) == 5
示例#24
0
def test_pod_task_deserialization():
    """Check that a pod task's serialized pod spec survives a kubernetes-client round trip.

    The pod spec produced by ``get_k8s_pod`` is dumped to JSON, fed back through
    ``ApiClient().deserialize`` as a ``V1PodSpec``, and the resulting primary
    container is inspected (name, args, volume mount, resources, env).
    """
    pod_config = Pod(pod_spec=get_pod_spec(),
                     primary_container_name="a container")

    # Deliberately no docstring on the task function, so nothing about the
    # serialized task can be influenced by documentation text.
    @task(
        task_config=pod_config,
        requests=Resources(cpu="10"),
        limits=Resources(gpu="2"),
        environment={"FOO": "bar"},
    )
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod_config

    default_img = Image(name="default", fqn="test", tag="tag")

    def make_settings():
        # A fresh settings instance per consumer, so neither call below can
        # observe state left behind by the other.
        return SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img,
                                     images=[default_img]),
        )

    k8s_pod = simple_pod_task.get_k8s_pod(make_settings())

    # Test that custom is correctly serialized by deserializing it with the python API client.
    # ApiClient.deserialize reads the JSON payload from ``response.data``.
    fake_response = MagicMock()
    fake_response.data = json.dumps(k8s_pod.pod_spec)
    restored_spec = ApiClient().deserialize(fake_response, V1PodSpec)

    assert restored_spec.restart_policy == "OnFailure"
    assert len(restored_spec.containers) == 2

    primary = restored_spec.containers[0]
    assert primary.name == "a container"

    expected_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert primary.args == expected_args

    first_mount = primary.volume_mounts[0]
    assert first_mount.mount_path == "some/where"
    assert first_mount.name == "volume mount"

    expected_resources = V1ResourceRequirements(limits={"gpu": "2"},
                                                requests={"cpu": "10"})
    assert primary.resources == expected_resources
    # The task-level value ("bar") is expected here, not the settings value ("baz").
    assert primary.env == [V1EnvVar(name="FOO", value="bar")]
    assert restored_spec.containers[1].name == "another container"

    config = simple_pod_task.get_config(make_settings())
    assert config["primary_container_name"] == "a container"
示例#25
0
def test_fast():
    """Check fast-registration args and resources on a pod subtask of a dynamic task.

    A ``@dynamic`` task that calls a pod task is dispatched under
    TASK_EXECUTION mode with fast serialization enabled; the resulting dynamic
    job spec must contain one node and one task whose primary container uses
    the ``pyflyte-fast-execute`` entrypoint and carries the declared resources.
    """
    gpu_requests = Resources(cpu="123m",
                             mem="234Mi",
                             ephemeral_storage="123M",
                             gpu="1")
    gpu_limits = Resources(cpu="124M",
                           mem="235Mi",
                           ephemeral_storage="124M",
                           gpu="1")

    # Smallest pod config possible: a single, otherwise empty primary container.
    minimal_pod = Pod(
        pod_spec=V1PodSpec(containers=[V1Container(name="flytetask")]),
        primary_container_name="flytetask",
    )

    @task(task_config=minimal_pod, requests=gpu_requests, limits=gpu_limits)
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=gpu_requests, limits=gpu_limits)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/User/flyte/workflows",
        distribution_location="s3://my-s3-bucket/fast/123",
    )
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        fast_serialization_settings=fast_settings,
    )

    manager = context_manager.FlyteContextManager
    serializing_ctx = manager.current_context().with_serialization_settings(
        serialization_settings)
    with manager.with_context(serializing_ctx) as ctx:
        execution_state = ctx.execution_state.with_params(
            mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(
                ctx.with_execution_state(execution_state)) as task_ctx:
            literal_inputs = TypeEngine.dict_to_literal_map(
                task_ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(
                task_ctx, literal_inputs)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1

            primary = dynamic_job_spec.tasks[0].k8s_pod.pod_spec[
                "containers"][0]
            args = " ".join(primary["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows")
            assert primary["resources"]["limits"]["cpu"] == "124M"
            assert primary["resources"]["requests"]["gpu"] == "1"

    # Both pushed contexts must have been popped again on exit.
    assert manager.size() == 1
def default_serialization_settings(default_image_config):
    """Build placeholder serialization settings that carry a ``FOO=bar`` environment.

    :param default_image_config: image configuration to embed in the settings.
    :return: a ``SerializationSettings`` for project "p", domain "d", version "v".
    """
    settings_kwargs = dict(
        project="p",
        domain="d",
        version="v",
        image_config=default_image_config,
        env={"FOO": "bar"},
    )
    return SerializationSettings(**settings_kwargs)
def minimal_serialization_settings(default_image_config):
    """Build placeholder serialization settings without specifying an environment.

    :param default_image_config: image configuration to embed in the settings.
    :return: a ``SerializationSettings`` for project "p", domain "d", version "v".
    """
    settings_kwargs = dict(
        project="p",
        domain="d",
        version="v",
        image_config=default_image_config,
    )
    return SerializationSettings(**settings_kwargs)
示例#28
0
def register(
    ctx: click.Context,
    project: str,
    domain: str,
    image_config: ImageConfig,
    output: str,
    destination_dir: str,
    service_account: str,
    raw_data_prefix: str,
    version: typing.Optional[str],
    package_or_module: typing.Tuple[str],
):
    """
    see help
    """
    # NOTE(review): the docstring above is kept verbatim in case the CLI
    # framework surfaces it as help text; the decorator is not visible here.
    #
    # Flow: zip the common root of the given packages/modules, upload the
    # archive through FlyteRemote, serialize every entity with fast
    # serialization pointing at the uploaded archive, then register them.

    # Packages configured on the click context are not supported by this command.
    if ctx.obj[constants.CTX_PACKAGES]:
        raise ValueError(
            "Unimplemented, just specify pkgs like folder/files as args at the end of the command"
        )
    cli_logger.debug("No pkgs")

    if not package_or_module:
        display_help_with_error(
            ctx,
            "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
        )

    cli_logger.debug(
        f"Running pyflyte register from {os.getcwd()} "
        f"with images {image_config} "
        f"and image destinationfolder {destination_dir} "
        f"on {len(package_or_module)} package(s) {package_or_module}")

    # Create and save FlyteRemote,
    remote = get_and_save_remote_with_click_context(ctx, project, domain)

    # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings
    # Create a zip file containing all the entries.
    project_root = find_common_root(package_or_module)
    cli_logger.debug(f"Using {project_root} as root folder for project")
    archive_path = fast_package(project_root, output)

    # Upload zip file to Admin using FlyteRemote.
    archive_md5, uploaded_url = remote._upload_file(pathlib.Path(archive_path))
    cli_logger.debug(f"Uploaded zip {archive_path} to {uploaded_url}")

    # Create serialization settings
    # Todo: Rely on default Python interpreter for now, this will break custom Spark containers
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir=destination_dir,
        distribution_location=uploaded_url,
    )
    serialization_settings = SerializationSettings(
        project=project,
        domain=domain,
        image_config=image_config,
        fast_serialization_settings=fast_settings,
    )

    options = Options.default_from(
        k8s_service_account=service_account,
        raw_data_prefix=raw_data_prefix,
    )

    # Load all the entities
    entities = load_packages_and_modules(
        serialization_settings,
        project_root,
        list(package_or_module),
        options,
    )
    if not entities:
        display_help_with_error(ctx,
                                "No Flyte entities were detected. Aborting!")
    cli_logger.info(f"Found and serialized {len(entities)} entities")

    # Without an explicit version, derive one from the archive digest and the
    # settings that shape the serialized output.
    if not version:
        version = remote._version_from_hash(
            archive_md5,
            serialization_settings,
            service_account,
            raw_data_prefix,
        )  # noqa
        cli_logger.info(f"Computed version is {version}")

    # Register using repo code
    repo_register(entities, project, domain, version, remote.client)
示例#29
0
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.workflow import WorkflowTemplate
from flytekit.models.task import TaskTemplate
from flytekit.remote import FlyteLaunchPlan, FlyteTask
from flytekit.remote.interface import TypedInterface
from flytekit.remote.workflow import FlyteWorkflow
from flytekit.tools.translator import gather_dependent_entities, get_serializable

# Placeholder image used as both the default image and the only entry in the
# image list of the settings below.
default_img = Image(name="default", fqn="test", tag="tag")
# Module-level serialization settings with placeholder project/domain/version.
# Note that ``env`` is explicitly None here rather than an empty dict.
serialization_settings = SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


@task
def t1(a: int) -> int:
    # Returns the input offset by two. Documented with a comment rather than a
    # docstring so the decorated task's metadata is left untouched.
    offset_value = a + 2
    return offset_value


@task
def t2(a: int, b: str) -> str:
    # Returns ``b`` with the string form of ``a`` appended. Documented with a
    # comment rather than a docstring so the decorated task's metadata is left
    # untouched.
    suffix = str(a)
    return b + suffix

class PythonCustomizedContainerTask(ExecutableTemplateShimTask,
                                    PythonTask[TC]):
    """
    Please take a look at the comments for :py:class:`flytekit.extend.ExecutableTemplateShimTask` as well. This class
    should be subclassed and a custom Executor provided as a default to this parent class constructor
    when building a new external-container flytekit-only plugin.

    This class provides authors of new task types the basic scaffolding to create task-template based tasks. In order
    to write such a task, authors need to

    * subclass the ``ShimTaskExecutor`` class  and override the ``execute_from_model`` function. This function is
      where all the business logic should go. Keep in mind though that you, the plugin author, will not have access
      to anything that's not serialized within the ``TaskTemplate`` which is why you'll also need to
    * subclass this class, and override the ``get_custom`` function to include all the information the executor
      will need to run.
    * Also pass the executor you created as the ``executor_type`` argument of this class's constructor.

    Keep in mind that the total size of the ``TaskTemplate`` still needs to be small, since these will be accessed
    frequently by the Flyte engine.
    """

    # Placeholder settings used by ``task_template`` when the template has to be
    # built lazily without real settings. Note that ``env`` is None here, so
    # ``get_container`` must cope with a missing settings environment.
    SERIALIZE_SETTINGS = SerializationSettings(
        project="PLACEHOLDER_PROJECT",
        domain="LOCAL",
        version="PLACEHOLDER_VERSION",
        env=None,
        image_config=ImageConfig(
            default_image=Image(name="custom_container_task",
                                fqn="flyteorg.io/placeholder",
                                tag="image")),
    )

    def __init__(
        self,
        name: str,
        task_config: TC,
        container_image: str,
        executor_type: Type[ShimTaskExecutor],
        task_resolver: Optional[TaskTemplateResolver] = None,
        task_type="python-task",
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        environment: Optional[Dict[str, str]] = None,
        secret_requests: Optional[List[Secret]] = None,
        **kwargs,
    ):
        """
        :param name: unique name for the task, usually the function's module and name.
        :param task_config: Configuration object for Task. Should be a unique type for that specific Task
        :param container_image: This is the external container image the task should run at platform-run-time.
        :param executor_type: This is an executor which will actually provide the business logic.
        :param task_resolver: Custom resolver - if you don't make one, use the default task template resolver.
        :param task_type: String task type to be associated with this Task.
        :param requests: custom resource request settings.
        :param limits: custom resource limit settings.
        :param environment: Environment variables you want the task to have when run.
        :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets will
           be mounted based on the configuration in the Secret and available through
           the SecretManager using the name of the secret as the group
           Ideally the secret keys should also be semi-descriptive.
           The key values will be available from runtime, if the backend is configured to provide secrets and
           if secrets are available in the configured secrets store. Possible options for secret stores are

           - `Vault <https://www.vaultproject.io/>`__
           - `Confidant <https://lyft.github.io/confidant/>`__
           - `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
           - `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
        :param kwargs: forwarded unchanged to the parent class constructor.
        :raises AssertionError: if any element of ``secret_requests`` is not a ``Secret``.
        """
        sec_ctx = None
        if secret_requests:
            # Validate eagerly so a bad entry is reported with its value and type.
            for s in secret_requests:
                if not isinstance(s, Secret):
                    raise AssertionError(
                        f"Secret {s} should be of type flytekit.Secret, received {type(s)}"
                    )
            sec_ctx = SecurityContext(secrets=secret_requests)
        # ``tt=None``: the task template does not exist yet; it is produced by
        # ``serialize_to_model`` (see the ``task_template`` property).
        super().__init__(
            tt=None,
            executor_type=executor_type,
            task_type=task_type,
            name=name,
            task_config=task_config,
            security_ctx=sec_ctx,
            **kwargs,
        )
        # Missing requests/limits fall back to an empty ``Resources()``.
        self._resources = ResourceSpec(
            requests=requests if requests else Resources(),
            limits=limits if limits else Resources())
        self._environment = environment
        self._container_image = container_image
        self._task_resolver = task_resolver or default_task_template_resolver

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """
        Must be overridden by subclasses; this implementation always raises.

        :raises NotImplementedError: always.
        """
        # Overriding base implementation to raise an error, force task author to implement
        raise NotImplementedError

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        """
        Return an empty config mapping; ``settings`` is not consulted.
        """
        # Overriding base implementation but not doing anything. Technically this should be the task config,
        # but the IDL limitation that the value also has to be a string is very limiting.
        # Recommend putting information you need in the config into custom instead, because when serializing
        # the custom field, we jsonify custom and the place it into a protobuf struct. This config field
        # just gets put into a Dict[str, str]
        return {}

    @property
    def resources(self) -> ResourceSpec:
        """Resource requests and limits given to the constructor (empty ``Resources()`` when omitted)."""
        return self._resources

    @property
    def task_resolver(self) -> TaskTemplateResolver:
        """Resolver given to the constructor, or ``default_task_template_resolver`` when none was passed."""
        return self._task_resolver

    @property
    def task_template(self) -> Optional[_task_model.TaskTemplate]:
        """
        Override the base class implementation to serialize on first call.

        When no template has been stored yet, one is built with the placeholder
        ``SERIALIZE_SETTINGS``; ``serialize_to_model`` stores the result.
        """
        return self._task_template or self.serialize_to_model(
            settings=PythonCustomizedContainerTask.SERIALIZE_SETTINGS)

    @property
    def container_image(self) -> str:
        """Container image string given to the constructor."""
        return self._container_image

    def get_command(self, settings: SerializationSettings) -> List[str]:
        """
        Build the ``pyflyte-execute`` argument list for this task.

        The ``{{.…}}`` entries are emitted verbatim as template placeholders.
        Everything after ``--`` comes from the task resolver's ``loader_args``.

        :param settings: serialization settings handed to the resolver.
        :return: the container argument list.
        """
        container_args = [
            "pyflyte-execute",
            "--inputs",
            "{{.input}}",
            "--output-prefix",
            "{{.outputPrefix}}",
            "--raw-output-data-prefix",
            "{{.rawOutputDataPrefix}}",
            "--resolver",
            self.task_resolver.location,
            "--",
            *self.task_resolver.loader_args(settings, self),
        ]

        return container_args

    def get_container(
            self, settings: SerializationSettings) -> _task_model.Container:
        """
        Build the container definition for this task.

        The environment is ``settings.env`` overlaid with the task's own
        environment, with task-level values winning on key collisions. When the
        task has no environment, ``settings.env`` is passed through unchanged
        (including ``None``).

        :param settings: serialization settings supplying the base environment.
        :return: the container model built by ``_get_container_definition``.
        """
        # NOTE(review): ``self.environment`` is not defined in this class;
        # presumably an inherited property exposing ``self._environment``.
        if self.environment:
            # ``settings.env`` can be None (SERIALIZE_SETTINGS uses env=None, and
            # ``task_template`` serializes with it); unpacking None with ``**``
            # raises TypeError, so treat a missing base environment as empty.
            env = {**(settings.env or {}), **self.environment}
        else:
            env = settings.env
        return _get_container_definition(
            image=self.container_image,
            command=[],
            args=self.get_command(settings=settings),
            data_loading_config=None,
            environment=env,
            storage_request=self.resources.requests.storage,
            cpu_request=self.resources.requests.cpu,
            gpu_request=self.resources.requests.gpu,
            memory_request=self.resources.requests.mem,
            storage_limit=self.resources.limits.storage,
            cpu_limit=self.resources.limits.cpu,
            gpu_limit=self.resources.limits.gpu,
            memory_limit=self.resources.limits.mem,
        )

    def serialize_to_model(
            self, settings: SerializationSettings) -> _task_model.TaskTemplate:
        """
        Build a ``TaskTemplate`` for this task and store it on the instance.

        :param settings: supplies project, domain and version for the identifier
            and is passed on to ``get_custom``, ``get_container`` and ``get_config``.
        :return: the newly built template (also stored as ``self._task_template``).
        :raises NotImplementedError: if ``get_custom`` has not been overridden.
        """
        # This doesn't get called from translator unfortunately. Will need to move the translator to use the model
        # objects directly first.
        # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
        # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
        #  customized-container tasks have a familiar thing to work with.
        obj = _task_model.TaskTemplate(
            identifier_models.Identifier(identifier_models.ResourceType.TASK,
                                         settings.project, settings.domain,
                                         self.name, settings.version),
            self.task_type,
            self.metadata.to_taskmetadata_model(),
            self.interface,
            self.get_custom(settings),
            container=self.get_container(settings),
            config=self.get_config(settings),
        )
        self._task_template = obj
        return obj