Example #1
def test_serialization_settings_transport():
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"hello": "blah"},
        image_config=ImageConfig(
            default_image=default_img,
            images=[default_img],
        ),
        flytekit_virtualenv_root="/opt/venv/blah",
        python_interpreter="/opt/venv/bin/python3",
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/opt/blah/blah/blah",
            distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz",
        ),
    )

    tp = serialization_settings.serialized_context
    with_serialized = serialization_settings.with_serialized_context()
    assert serialization_settings.env == {"hello": "blah"}
    assert with_serialized.env
    assert with_serialized.env[SERIALIZED_CONTEXT_ENV_VAR] == tp
    ss = SerializationSettings.from_transport(tp)
    assert ss is not None
    assert ss == serialization_settings
    assert len(tp) == 376
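
The transported string above is exactly what flytekit later injects into the task container's environment under SERIALIZED_CONTEXT_ENV_VAR (see Example #30). A minimal sketch of rehydrating the settings inside a running container, assuming SERIALIZED_CONTEXT_ENV_VAR and SerializationSettings are importable from flytekit.configuration as in this test:

import os
from typing import Optional

from flytekit.configuration import SERIALIZED_CONTEXT_ENV_VAR, SerializationSettings


def rehydrate_settings() -> Optional[SerializationSettings]:
    # from_transport() reverses with_serialized_context(): it decodes and
    # decompresses the transported string back into a SerializationSettings.
    compressed = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")
    if not compressed:
        return None
    return SerializationSettings.from_transport(compressed)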
Example #2
def test_pod_task_serialized():
    pod = Pod(
        pod_spec=get_pod_spec(),
        primary_container_name="an undefined container",
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )

    @task(task_config=pod,
          requests=Resources(cpu="10"),
          limits=Resources(gpu="2"),
          environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    ssettings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    serialized = get_serializable(OrderedDict(), ssettings, simple_pod_task)
    assert serialized.template.task_type_version == 2
    assert serialized.template.config[
        "primary_container_name"] == "an undefined container"
    assert serialized.template.k8s_pod.metadata.labels == {"label": "foo"}
    assert serialized.template.k8s_pod.metadata.annotations == {"anno": "bar"}
    assert serialized.template.k8s_pod.pod_spec is not None
Example #3
def test_tensorflow_task():
    @task(
        task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
        cache=True,
        requests=Resources(cpu="1"),
        cache_version="1",
    )
    def my_tensorflow_task(x: int, y: str) -> int:
        return x

    assert my_tensorflow_task(x=10, y="hello") == 10

    assert my_tensorflow_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    assert my_tensorflow_task.get_custom(settings) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1}
    assert my_tensorflow_task.resources.limits == Resources()
    assert my_tensorflow_task.resources.requests == Resources(cpu="1")
    assert my_tensorflow_task.task_type == "tensorflow"
Example #4
def test_mpi_task():
    @task(
        task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1),
        requests=Resources(cpu="1"),
        cache=True,
        cache_version="1",
    )
    def my_mpi_task(x: int, y: str) -> int:
        return x

    assert my_mpi_task(x=10, y="hello") == 10

    assert my_mpi_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    assert my_mpi_task.get_custom(settings) == {
        "numLauncherReplicas": 10,
        "numWorkers": 10,
        "slots": 1
    }
    assert my_mpi_task.task_type == "mpi"
Example #5
def test_query_no_inputs_or_outputs():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs={},
        task_config=HiveConfig(cluster_label="flyte"),
        query_template="""
            insert into extant_table (1, 'two')
        """,
        output_schema_type=None,
    )

    @workflow
    def my_wf():
        hive_task()

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings,
                                 hive_task)
    assert len(task_spec.template.interface.inputs) == 0
    assert len(task_spec.template.interface.outputs) == 0

    get_serializable(OrderedDict(), serialization_settings, my_wf)
Example #6
def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter):
    """
    This command produces a Flyte backend registrable package of all Flyte entities.
    For tasks, one pb file is produced for each task, representing one TaskTemplate object.
    For workflows, one pb file is produced for each workflow, representing a WorkflowClosure object. The closure
    object contains the WorkflowTemplate, along with the relevant tasks for that workflow.
    This serialization step will set the name of the tasks to the fully qualified name of the task function.
    """
    if os.path.exists(output) and not force:
        raise click.BadParameter(click.style(f"Output file {output} already exists, specify -f to override.", fg="red"))

    serialization_settings = SerializationSettings(
        image_config=image_config,
        fast_serialization_settings=FastSerializationSettings(
            enabled=fast,
            destination_dir=in_container_source_path,
        ),
        python_interpreter=python_interpreter,
    )

    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

    try:
        serialize_and_package(pkgs, serialization_settings, source, output, fast)
    except NoSerializableEntitiesError:
        click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow")
Example #7
def test_two(two_sample_inputs):
    my_input = two_sample_inputs[0]
    my_input_2 = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        x = []
        for aa in a:
            x.append(aa.main_product)
        return x

    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [my_input, my_input_2]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
Example #8
def test_pytorch_task():
    @task(
        task_config=PyTorch(num_workers=10),
        cache=True,
        cache_version="1",
        requests=Resources(cpu="1"),
    )
    def my_pytorch_task(x: int, y: str) -> int:
        return x

    assert my_pytorch_task(x=10, y="hello") == 10

    assert my_pytorch_task.task_config is not None

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    assert my_pytorch_task.get_custom(settings) == {"workers": 10}
    assert my_pytorch_task.resources.limits == Resources()
    assert my_pytorch_task.resources.requests == Resources(cpu="1")
    assert my_pytorch_task.task_type == "pytorch"
Example #9
def test_register_a_hello_world_wf():
    version = get_version("1")
    ss = SerializationSettings(image_config, project="flytesnacks", domain="development", version=version)
    rr.register_workflow(hello_wf, serialization_settings=ss)

    fetched_wf = rr.fetch_workflow(name=hello_wf.name, version=version)

    rr.execute(fetched_wf, inputs={"a": 5})
Example #10
def _get_reg_settings():
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    return settings
Example #11
def test_dc_dyn_directory(folders_and_files_setup):
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    my_input_gcs = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/one"),
            external_data_dir=FlyteDirectory("gs://my-bucket/two"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    my_input_gcs_2 = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/three"),
            external_data_dir=FlyteDirectory("gs://my-bucket/four"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        x = []
        for aa in a:
            x.append(aa.apriori_config.external_data_dir)

        return x

    ctx = FlyteContextManager.current_context()
    cb = (
        ctx.new_builder()
        .with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
Example #12
def test_pod_task_undefined_primary():
    pod = Pod(pod_spec=get_pod_spec(),
              primary_container_name="an undefined container")

    @task(task_config=pod,
          requests=Resources(cpu="10"),
          limits=Resources(gpu="2"),
          environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    assert isinstance(simple_pod_task, PodFunctionTask)
    assert simple_pod_task.task_config == pod

    default_img = Image(name="default", fqn="test", tag="tag")
    pod_spec = simple_pod_task.get_k8s_pod(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img,
                                     images=[default_img]),
        )).pod_spec

    assert len(pod_spec["containers"]) == 3

    primary_container = pod_spec["containers"][2]
    assert primary_container["name"] == "an undefined container"

    config = simple_pod_task.get_config(
        SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env={"FOO": "baz"},
            image_config=ImageConfig(default_image=default_img,
                                     images=[default_img]),
        ))
    assert config["primary_container_name"] == "an undefined container"
Example #13
def test_spark_template_with_remote():
    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        return 10

    @task
    def my_python_task(a: str) -> int:
        return 10

    remote = FlyteRemote(config=Config.for_endpoint(endpoint="localhost",
                                                    insecure=True),
                         default_project="p1",
                         default_domain="d1")

    mock_client = MagicMock()
    remote._client = mock_client

    remote.register_task(
        my_spark,
        serialization_settings=SerializationSettings(
            image_config=MagicMock(), ),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]

    print(serialized_spec)
    # Check that the serialized Spark task has the mainApplicationFile field set.
    assert serialized_spec.template.custom["mainApplicationFile"]
    assert serialized_spec.template.custom["sparkConf"]

    remote.register_task(
        my_python_task,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1")
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]

    # Check that the serialized Python task does not have the mainApplicationFile field set by default.
    assert serialized_spec.template.custom is None
Example #14
def test_fast_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure",
                           containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), serialization_settings,
                                  simple_pod_task)

    assert serialized.template.k8s_pod.pod_spec["containers"][0]["args"] == [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
Example #15
def test_py_func_task_get_container():
    def foo(i: int):
        pass

    default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
    other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
    cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])

    settings = SerializationSettings(project="p", domain="d", version="v", image_config=cfg, env={"FOO": "bar"})

    pytask = PythonFunctionTask(None, foo, None, environment={"BAZ": "baz"})
    c = pytask.get_container(settings)
    assert c.image == "xyz.com/abc:tag1"
    assert c.env == {"FOO": "bar", "BAZ": "baz"}
Example #16
def test_serialization():
    snowflake_task = SnowflakeTask(
        name="flytekit.demo.snowflake_task.query",
        inputs=kwtypes(ds=str),
        task_config=SnowflakeConfig(account="snowflake",
                                    warehouse="my_warehouse",
                                    schema="my_schema",
                                    database="my_database"),
        query_template=query_template,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return snowflake_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), serialization_settings,
                                 snowflake_task)

    assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
    assert "insert overwrite directory" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    assert "snowflake" == task_spec.template.config["account"]
    assert "my_warehouse" == task_spec.template.config["warehouse"]
    assert "my_schema" == task_spec.template.config["schema"]
    assert "my_database" == task_spec.template.config["database"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(),
                                           serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs[
        "o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[
        0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[
        0].binding.promise.var == "results"
Example #17
def test_map_pod_task_serialization():
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure",
                           containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))
    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )

    # Test that target is correctly serialized with an updated command
    pod_spec = mapped_task.get_k8s_pod(serialization_settings).pod_spec

    assert len(pod_spec["containers"]) == 1
    assert pod_spec["containers"][0]["args"] == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    assert {
        "primary_container_name": "primary"
    } == mapped_task.get_config(serialization_settings)
Example #18
def test_serialization():
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings,
                                 hive_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["query"][
        "query"]
    assert "insert overwrite directory" in task_spec.template.custom["query"][
        "query"]
    assert len(task_spec.template.interface.inputs) == 2
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(),
                                           serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs[
        "o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[
        0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[
        0].binding.promise.var == "results"
Example #19
def serialize_all(
    pkgs: typing.List[str] = None,
    local_source_root: typing.Optional[str] = None,
    folder: typing.Optional[str] = None,
    mode: typing.Optional[SerializationMode] = None,
    image: typing.Optional[str] = None,
    flytekit_virtualenv_root: typing.Optional[str] = None,
    python_interpreter: typing.Optional[str] = None,
    config_file: typing.Optional[str] = None,
):
    """
    This function will write the following protobuf types to the specified folder ::
        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::
        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast
    :param image: The fully qualified and versioned default image to use
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: The path to the Python interpreter to use inside the container.
    :param config_file: The config file from which the default image configuration is derived.
    """

    if not (mode == SerializationMode.DEFAULT
            or mode == SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")

    serialization_settings = SerializationSettings(
        image_config=ImageConfig.auto(config_file, img_name=image),
        fast_serialization_settings=FastSerializationSettings(
            enabled=mode == SerializationMode.FAST,
            # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
        ),
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
    )

    serialize_to_folder(pkgs, serialization_settings, local_source_root,
                        folder)
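
A hypothetical invocation of serialize_all(); the package name, output folder, and image reference below are illustrative values only:

serialize_all(
    pkgs=["myapp.workflows"],
    local_source_root=".",
    folder="_pb_output",
    mode=SerializationMode.DEFAULT,
    image="ghcr.io/example/flyte-app:v1",
)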
Example #20
def test_serialization():
    athena_task = AthenaTask(
        name="flytekit.demo.athena_task.query",
        inputs=kwtypes(ds=str),
        task_config=AthenaConfig(database="mnist", catalog="my_catalog", workgroup="my_wg"),
        query_template="""
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet
            select *
            from blah
            where ds = '{{ .Inputs.ds }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(ds: str) -> FlyteSchema:
        return athena_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )
    task_spec = get_serializable(OrderedDict(), serialization_settings, athena_task)
    assert "{{ .rawOutputDataPrefix" in task_spec.template.custom["statement"]
    assert "insert overwrite directory" in task_spec.template.custom["statement"]
    assert "mnist" == task_spec.template.custom["schema"]
    assert "my_catalog" == task_spec.template.custom["catalog"]
    assert "my_wg" == task_spec.template.custom["routingGroup"]
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.schema is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
Example #21
def test_serialization():
    bigquery_task = BigQueryTask(
        name="flytekit.demo.bigquery_task.query",
        inputs=kwtypes(ds=str),
        task_config=BigQueryConfig(
            ProjectID="Flyte", Location="Asia", QueryJobConfig=QueryJobConfig(allow_large_results=True)
        ),
        query_template=query_template,
        output_structured_dataset_type=StructuredDataset,
    )

    @workflow
    def my_wf(ds: str) -> StructuredDataset:
        return bigquery_task(ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )

    task_spec = get_serializable(OrderedDict(), serialization_settings, bigquery_task)

    assert "SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions`" in task_spec.template.sql.statement
    assert "@version" in task_spec.template.sql.statement
    assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
    s = Struct()
    s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
    assert task_spec.template.custom == json_format.MessageToDict(s)
    assert len(task_spec.template.interface.inputs) == 1
    assert len(task_spec.template.interface.outputs) == 1

    admin_workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    assert admin_workflow_spec.template.interface.outputs["o0"].type.structured_dataset_type is not None
    assert admin_workflow_spec.template.outputs[0].var == "o0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.node_id == "n0"
    assert admin_workflow_spec.template.outputs[0].binding.promise.var == "results"
Example #22
def test_aws_batch_task():
    @task(task_config=config)
    def t1(a: int) -> str:
        inc = a + 2
        return str(inc)

    assert t1.task_config is not None
    assert t1.task_config == config
    assert t1.task_type == "aws-batch"
    assert isinstance(t1, PythonFunctionTask)

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    assert t1.get_custom(settings) == config.to_dict()
    assert t1.get_command(settings) == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}/0",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_aws_batch",
        "task-name",
        "t1",
    ]
Example #23

class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]):
    """
    Please take a look at the comments for :py:class:`flytekit.extend.ExecutableTemplateShimTask` as well. When
    building a new external-container flytekit-only plugin, this class should be subclassed and a custom Executor
    provided as a default to this parent class's constructor.

    This class provides authors of new task types the basic scaffolding to create task-template based tasks. In order
    to write such a task, authors need to

    * subclass the ``ShimTaskExecutor`` class and override the ``execute_from_model`` function. This function is
      where all the business logic should go. Keep in mind, though, that you, the plugin author, will not have access
      to anything that's not serialized within the ``TaskTemplate``, which is why you'll also need to
    * subclass this class, and override the ``get_custom`` function to include all the information the executor
      will need to run.
    * Also pass the executor you created as the ``executor_type`` argument of this class's constructor.

    Keep in mind that the total size of the ``TaskTemplate`` still needs to be small, since these will be accessed
    frequently by the Flyte engine.
    """

    SERIALIZE_SETTINGS = SerializationSettings(
        project="PLACEHOLDER_PROJECT",
        domain="LOCAL",
        version="PLACEHOLDER_VERSION",
        env=None,
        image_config=ImageConfig(
            default_image=Image(name="custom_container_task",
                                fqn="flyteorg.io/placeholder",
                                tag="image")),
    )

    def __init__(
        self,
        name: str,
        task_config: TC,
        container_image: str,
        executor_type: Type[ShimTaskExecutor],
        task_resolver: Optional[TaskTemplateResolver] = None,
        task_type="python-task",
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        environment: Optional[Dict[str, str]] = None,
        secret_requests: Optional[List[Secret]] = None,
        **kwargs,
    ):
        """
        :param name: unique name for the task, usually the function's module and name.
        :param task_config: Configuration object for Task. Should be a unique type for that specific Task
        :param container_image: This is the external container image the task should run at platform-run-time.
        :param executor_type: The executor class that will actually provide the business logic.
        :param task_resolver: Custom resolver - if you don't make one, use the default task template resolver.
        :param task_type: String task type to be associated with this Task.
        :param requests: custom resource request settings.
        :param limits: custom resource limit settings.
        :param environment: Environment variables you want the task to have when run.
        :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets will
           be mounted based on the configuration in the Secret and available through
           the SecretManager, using the name of the secret as the group.
           Ideally the secret keys should also be semi-descriptive.
           The key values will be available at runtime, if the backend is configured to provide secrets and
           if secrets are available in the configured secrets store. Possible options for secret stores are

           - `Vault <https://www.vaultproject.io/>`__
           - `Confidant <https://lyft.github.io/confidant/>`__
           - `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
           - `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
        """
        sec_ctx = None
        if secret_requests:
            for s in secret_requests:
                if not isinstance(s, Secret):
                    raise AssertionError(
                        f"Secret {s} should be of type flytekit.Secret, received {type(s)}"
                    )
            sec_ctx = SecurityContext(secrets=secret_requests)
        super().__init__(
            tt=None,
            executor_type=executor_type,
            task_type=task_type,
            name=name,
            task_config=task_config,
            security_ctx=sec_ctx,
            **kwargs,
        )
        self._resources = ResourceSpec(
            requests=requests if requests else Resources(),
            limits=limits if limits else Resources())
        self._environment = environment
        self._container_image = container_image
        self._task_resolver = task_resolver or default_task_template_resolver

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Overriding the base implementation to raise an error, forcing the task author to implement this.
        raise NotImplementedError

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        # Overriding base implementation but not doing anything. Technically this should be the task config,
        # but the IDL limitation that the value also has to be a string is very limiting.
        # Recommend putting the information you need in the config into custom instead, because when serializing
        # the custom field, we jsonify custom and then place it into a protobuf struct. This config field
        # just gets put into a Dict[str, str].
        return {}

    @property
    def resources(self) -> ResourceSpec:
        return self._resources

    @property
    def task_resolver(self) -> TaskTemplateResolver:
        return self._task_resolver

    @property
    def task_template(self) -> Optional[_task_model.TaskTemplate]:
        """
        Override the base class implementation to serialize on first call.
        """
        return self._task_template or self.serialize_to_model(
            settings=PythonCustomizedContainerTask.SERIALIZE_SETTINGS)

    @property
    def container_image(self) -> str:
        return self._container_image

    def get_command(self, settings: SerializationSettings) -> List[str]:
        container_args = [
            "pyflyte-execute",
            "--inputs",
            "{{.input}}",
            "--output-prefix",
            "{{.outputPrefix}}",
            "--raw-output-data-prefix",
            "{{.rawOutputDataPrefix}}",
            "--resolver",
            self.task_resolver.location,
            "--",
            *self.task_resolver.loader_args(settings, self),
        ]

        return container_args

    def get_container(
            self, settings: SerializationSettings) -> _task_model.Container:
        env = {
            **settings.env,
            **self.environment
        } if self.environment else settings.env
        return _get_container_definition(
            image=self.container_image,
            command=[],
            args=self.get_command(settings=settings),
            data_loading_config=None,
            environment=env,
            storage_request=self.resources.requests.storage,
            cpu_request=self.resources.requests.cpu,
            gpu_request=self.resources.requests.gpu,
            memory_request=self.resources.requests.mem,
            storage_limit=self.resources.limits.storage,
            cpu_limit=self.resources.limits.cpu,
            gpu_limit=self.resources.limits.gpu,
            memory_limit=self.resources.limits.mem,
        )

    def serialize_to_model(
            self, settings: SerializationSettings) -> _task_model.TaskTemplate:
        # This doesn't get called from translator unfortunately. Will need to move the translator to use the model
        # objects directly first.
        # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
        # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
        #  customized-container tasks have a familiar thing to work with.
        obj = _task_model.TaskTemplate(
            identifier_models.Identifier(identifier_models.ResourceType.TASK,
                                         settings.project, settings.domain,
                                         self.name, settings.version),
            self.task_type,
            self.metadata.to_taskmetadata_model(),
            self.interface,
            self.get_custom(settings),
            container=self.get_container(settings),
            config=self.get_config(settings),
        )
        self._task_template = obj
        return obj
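
To make the docstring's three steps concrete, here is a minimal, hypothetical plugin sketch. EchoExecutor and EchoTask are invented names, and PythonCustomizedContainerTask and SerializationSettings are assumed to be in scope as defined above:

from typing import Any, Dict

from flytekit.extend import ShimTaskExecutor
from flytekit.models import task as _task_model


class EchoExecutor(ShimTaskExecutor):
    def execute_from_model(self, tt: _task_model.TaskTemplate, **kwargs) -> Any:
        # Only what was serialized into the TaskTemplate is available here.
        return tt.custom["message"]


class EchoTask(PythonCustomizedContainerTask):
    def __init__(self, name: str, message: str, container_image: str, **kwargs):
        super().__init__(
            name=name,
            task_config=None,
            container_image=container_image,
            executor_type=EchoExecutor,
            task_type="echo",
            **kwargs,
        )
        self._message = message

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Everything the executor needs at run time must be serialized here.
        return {"message": self._message}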
Example #24
def test_fast():
    REQUESTS_GPU = Resources(cpu="123m",
                             mem="234Mi",
                             ephemeral_storage="123M",
                             gpu="1")
    LIMITS_GPU = Resources(cpu="124M",
                           mem="235Mi",
                           ephemeral_storage="124M",
                           gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        primary_container = V1Container(name="flytetask")
        pod_spec = V1PodSpec(containers=[primary_container])
        return Pod(pod_spec=pod_spec, primary_container_name="flytetask")

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=REQUESTS_GPU,
        limits=LIMITS_GPU,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=REQUESTS_GPU, limits=LIMITS_GPU)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    with context_manager.FlyteContextManager.with_context(
            context_manager.FlyteContextManager.current_context(
            ).with_serialization_settings(serialization_settings)) as ctx:
        with context_manager.FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        mode=ExecutionState.Mode.TASK_EXECUTION, ))) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(
                ctx, input_literal_map)
            # print(dynamic_job_spec)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(
                dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]
                ["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows")
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0][
                "resources"]["limits"]["cpu"] == "124M"
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0][
                "resources"]["requests"]["gpu"] == "1"

    assert context_manager.FlyteContextManager.size() == 1
Example #25

def default_serialization_settings(default_image_config):
    return SerializationSettings(
        project="p", domain="d", version="v", image_config=default_image_config, env={"FOO": "bar"}
    )
Example #26

def minimal_serialization_settings(default_image_config):
    return SerializationSettings(project="p", domain="d", version="v", image_config=default_image_config)
Example #27
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
) -> task_models.TaskSpec:
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )

    if isinstance(
            entity, PythonFunctionTask
    ) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
        # In the case of dynamic tasks, we want to pass the serialization context so that they can reconstruct their
        # state from it. This is passed through an environment variable that is read during dynamic serialization.
        settings = settings.with_serialized_context()

    container = entity.get_container(settings)
    # This pod will be incorrect when doing fast serialize
    pod = entity.get_k8s_pod(settings)

    if settings.should_fast_serialize():
        # This handles container tasks.
        if container and isinstance(entity,
                                    (PythonAutoContainerTask, MapPythonTask)):
            # For fast registration, we'll need to muck with the command, but
            # only for certain kinds of tasks. Specifically,
            # tasks that rely on user code defined in the container. This should be encapsulated by the auto container
            # parent class
            container._args = prefix_with_fast_execute(settings,
                                                       container.args)

        # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
        # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
        # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
        elif pod:
            if isinstance(entity, MapPythonTask):
                entity.set_command_prefix(
                    get_command_prefix_for_fast_execute(settings))
                pod = entity.get_k8s_pod(settings)
            else:
                entity.set_command_fn(
                    _fast_serialize_command_fn(settings, entity))
                pod = entity.get_k8s_pod(settings)
                entity.reset_command_fn()

    tt = task_models.TaskTemplate(
        id=task_id,
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=container,
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=pod,
        sql=entity.get_sql(settings),
    )
    if settings.should_fast_serialize() and isinstance(
            entity, PythonAutoContainerTask):
        entity.reset_command_fn()
    return task_models.TaskSpec(template=tt)
Example #28
def register(
    ctx: click.Context,
    project: str,
    domain: str,
    image_config: ImageConfig,
    output: str,
    destination_dir: str,
    service_account: str,
    raw_data_prefix: str,
    version: typing.Optional[str],
    package_or_module: typing.Tuple[str],
):
    """
    see help
    """
    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if not pkgs:
        cli_logger.debug("No pkgs")
    if pkgs:
        raise ValueError(
            "Unimplemented, just specify pkgs like folder/files as args at the end of the command"
        )

    if len(package_or_module) == 0:
        display_help_with_error(
            ctx,
            "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
        )

    cli_logger.debug(
        f"Running pyflyte register from {os.getcwd()} "
        f"with images {image_config} "
        f"and image destinationfolder {destination_dir} "
        f"on {len(package_or_module)} package(s) {package_or_module}")

    # Create and save FlyteRemote.
    remote = get_and_save_remote_with_click_context(ctx, project, domain)

    # TODO: add a switch for non-fast registration - skip the zipping and uploading and omit FastSerializationSettings
    # Create a zip file containing all the entries.
    detected_root = find_common_root(package_or_module)
    cli_logger.debug(f"Using {detected_root} as root folder for project")
    zip_file = fast_package(detected_root, output)

    # Upload zip file to Admin using FlyteRemote.
    md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
    cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}")

    # Create serialization settings
    # TODO: Rely on the default Python interpreter for now; this will break custom Spark containers
    serialization_settings = SerializationSettings(
        project=project,
        domain=domain,
        image_config=image_config,
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir=destination_dir,
            distribution_location=native_url,
        ),
    )

    options = Options.default_from(k8s_service_account=service_account,
                                   raw_data_prefix=raw_data_prefix)

    # Load all the entities
    registerable_entities = load_packages_and_modules(serialization_settings,
                                                      detected_root,
                                                      list(package_or_module),
                                                      options)
    if len(registerable_entities) == 0:
        display_help_with_error(ctx,
                                "No Flyte entities were detected. Aborting!")
    cli_logger.info(
        f"Found and serialized {len(registerable_entities)} entities")

    if not version:
        version = remote._version_from_hash(md5_bytes, serialization_settings,
                                            service_account,
                                            raw_data_prefix)  # noqa
        cli_logger.info(f"Computed version is {version}")

    # Register using repo code
    repo_register(registerable_entities, project, domain, version,
                  remote.client)
Example #29
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.workflow import WorkflowTemplate
from flytekit.models.task import TaskTemplate
from flytekit.remote import FlyteLaunchPlan, FlyteTask
from flytekit.remote.interface import TypedInterface
from flytekit.remote.workflow import FlyteWorkflow
from flytekit.tools.translator import gather_dependent_entities, get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


@task
def t1(a: int) -> int:
    return a + 2


@task
def t2(a: int, b: str) -> str:
    return b + str(a)

Example #30
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """

    :param raw_output_data_prefix:
    :param checkpoint_path:
    :param prev_checkpoint:
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
      task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return:
    """
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()
    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())

    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx
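
Because setup_execution() yields the fully built context, callers drive it as a generator-based context manager. A hedged usage sketch (the @contextmanager decoration is an assumption, since it is not shown in the excerpt above; the prefix URI is illustrative):

from contextlib import contextmanager

with contextmanager(setup_execution)(
    raw_output_data_prefix="s3://my-bucket/raw",
) as ctx:
    pass  # dispatch the task with the fully configured FlyteContext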