Beispiel #1
0
def test_typed_interface(literal_type):
    """Check a TypedInterface's variables before and after an IDL round trip."""
    descriptions = ("description1", "description2", "description3")

    def check(candidate):
        # The same expectations hold for the original and the deserialized copy.
        ins, outs = candidate.inputs, candidate.outputs
        variables = (ins['a'], outs['b'], outs['c'])
        for variable in variables:
            assert variable.type == literal_type
        for variable, text in zip(variables, descriptions):
            assert variable.description == text
        assert len(ins) == 1
        assert len(outs) == 2

    original = interface.TypedInterface(
        {'a': interface.Variable(literal_type, "description1")},
        {
            'b': interface.Variable(literal_type, "description2"),
            'c': interface.Variable(literal_type, "description3"),
        },
    )
    check(original)

    restored = interface.TypedInterface.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    check(restored)
Beispiel #2
0
def test_workflow_template_with_queuing_budget():
    """Round-trip a WorkflowTemplate whose metadata carries a 10 second queuing budget."""
    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node_metadata = _get_sample_node_metadata()
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    budget = timedelta(seconds=10)
    metadata = _workflow.WorkflowMetadata(queuing_budget=budget)
    defaults = _workflow.WorkflowMetadataDefaults()

    inputs = {'a': _interface.Variable(integer, "description1")}
    outputs = {
        'b': _interface.Variable(integer, "description2"),
        'c': _interface.Variable(integer, "description3"),
    }
    wf_interface = _interface.TypedInterface(inputs, outputs)

    node = _workflow.Node(
        id='some:node:id',
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=metadata,
        metadata_defaults=defaults,
        interface=wf_interface,
        nodes=[node],
        outputs=[],
    )

    # Serializing and deserializing must yield an equal template.
    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
def test_basic_unit_test():
    """Unit-test a runnable task that writes its integer input plus one to its output."""

    def add_one(wf_params, value_in, value_out):
        value_out.set(value_in + 1)

    # Nine consecutive positional arguments are left unset (None).
    unset = (None,) * 9
    runnable = sdk_runnable.SdkRunnableTask(
        add_one,
        _common_constants.SdkTaskType.PYTHON_TASK,
        "1",
        1,
        *unset,
        False,
        None,
        {},
        None,
    )

    input_variable = interface.Variable(primitives.Integer.to_flyte_literal_type(), "")
    runnable.add_inputs({'value_in': input_variable})
    output_variable = interface.Variable(primitives.Integer.to_flyte_literal_type(), "")
    runnable.add_outputs({'value_out': output_variable})

    result = runnable.unit_test(value_in=1)
    assert result['value_out'] == 2
Beispiel #4
0
def test_basic_unit_test():
    """Unit-test a runnable task and check the error raised when it is called without inputs."""

    def add_one(wf_params, value_in, value_out):
        value_out.set(value_in + 1)

    # Ten consecutive positional arguments are left unset (None).
    unset = (None,) * 10
    runnable = sdk_runnable.SdkRunnableTask(
        add_one,
        _common_constants.SdkTaskType.PYTHON_TASK,
        "1",
        1,
        *unset,
        False,
        None,
        {},
        None,
    )

    input_variable = interface.Variable(primitives.Integer.to_flyte_literal_type(), "")
    runnable.add_inputs({'value_in': input_variable})
    output_variable = interface.Variable(primitives.Integer.to_flyte_literal_type(), "")
    runnable.add_outputs({'value_out': output_variable})

    result = runnable.unit_test(value_in=1)
    assert result['value_out'] == 2

    # Calling the task with no arguments must fail with a FlyteAssertion.
    with _pytest.raises(_user_exceptions.FlyteAssertion) as e:
        runnable()

    # The error message names the input and its type.
    message = str(e.value)
    assert "value_in" in message
    assert "INTEGER" in message
Beispiel #5
0
def test_typed_interface(literal_type):
    """Verify TypedInterface contents directly and again after serializing to IDL and back."""
    # (section, variable name, expected description)
    expectations = (
        ("inputs", "a", "description1"),
        ("outputs", "b", "description2"),
        ("outputs", "c", "description3"),
    )

    def verify(subject):
        for section, key, _ in expectations:
            assert getattr(subject, section)[key].type == literal_type
        for section, key, text in expectations:
            assert getattr(subject, section)[key].description == text
        assert len(subject.inputs) == 1
        assert len(subject.outputs) == 2

    input_vars = {"a": interface.Variable(literal_type, "description1")}
    output_vars = {
        "b": interface.Variable(literal_type, "description2"),
        "c": interface.Variable(literal_type, "description3"),
    }
    built = interface.TypedInterface(input_vars, output_vars)
    verify(built)

    message = built.to_flyte_idl()
    rebuilt = interface.TypedInterface.from_flyte_idl(message)
    assert built == rebuilt
    verify(rebuilt)
Beispiel #6
0
def test_workflow_template():
    """Round-trip a single-node WorkflowTemplate through its IDL representation."""
    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node_metadata = _get_sample_node_metadata()
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    metadata = _workflow.WorkflowMetadata()
    defaults = _workflow.WorkflowMetadataDefaults()

    wf_inputs = {"a": _interface.Variable(integer, "description1")}
    wf_outputs = {
        "b": _interface.Variable(integer, "description2"),
        "c": _interface.Variable(integer, "description3"),
    }
    wf_interface = _interface.TypedInterface(wf_inputs, wf_outputs)

    node_kwargs = dict(
        id="some:node:id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )
    node = _workflow.Node(**node_kwargs)

    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=metadata,
        metadata_defaults=defaults,
        interface=wf_interface,
        nodes=[node],
        outputs=[],
    )

    # The deserialized template must compare equal to the one we built.
    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
Beispiel #7
0
def test_task_template__k8s_pod_target():
    """Build a TaskTemplate with a K8s pod target and verify its fields and IDL round trip."""
    integer = types.LiteralType(types.SimpleType.INTEGER)
    task_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")
    metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    input_vars = {"a": interface_models.Variable(integer, "description1")}
    output_vars = {
        "b": interface_models.Variable(integer, "description2"),
        "c": interface_models.Variable(integer, "description3"),
    }
    typed_interface = interface_models.TypedInterface(input_vars, output_vars)
    custom = {"a": 1, "b": {"c": 2, "d": 3}}

    template = task.TaskTemplate(
        task_id,
        "python",
        metadata,
        typed_interface,
        custom,
        config={"a": "b"},
        k8s_pod=task.K8sPod(
            metadata=task.K8sObjectMetadata(
                labels={"label": "foo"}, annotations={"anno": "bar"}),
            pod_spec={"str": "val", "int": 1},
        ),
    )

    # Identifier fields are exposed unchanged.
    assert template.id.resource_type == identifier.ResourceType.TASK
    for attribute in ("project", "domain", "name", "version"):
        assert getattr(template.id, attribute) == attribute
    assert template.type == "python"
    assert template.custom == {"a": 1, "b": {"c": 2, "d": 3}}

    # Pod target fields are exposed unchanged.
    expected_pod_metadata = task.K8sObjectMetadata(
        labels={"label": "foo"}, annotations={"anno": "bar"})
    assert template.k8s_pod.metadata == expected_pod_metadata
    assert template.k8s_pod.pod_spec == {"str": "val", "int": 1}

    # The text form of the IDL message is stable across a round trip.
    direct_text = text_format.MessageToString(template.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(
        template.to_flyte_idl()).to_flyte_idl()
    assert direct_text == text_format.MessageToString(round_tripped)
    assert template.config == {"a": "b"}
def test_basic_workflow_promote(mock_task_fetch):
    """Promote a stored workflow template and compare it with the equivalent SDK-defined workflow."""
    # A user-defined sample task and workflow. Their names are kept as-is because
    # the task identifier string below refers to the task by name.
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b,
                                           sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c,
                                           sdk_type=_Types.Integer)

    # Build the TaskTemplate that the mocked fetch hands back during promotion.
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    interface_inputs = {'a': _interface.Variable(integer, "description1")}
    interface_outputs = {
        'b': _interface.Variable(integer, "description2"),
        'c': _interface.Variable(integer, "description3"),
    }
    task_interface = _interface.TypedInterface(interface_inputs, interface_outputs)
    task_id = _identifier.Identifier(
        _identifier.ResourceType.TASK,
        "project",
        "domain",
        "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
        "version",
    )
    template = _task_model.TaskTemplate(
        task_id,
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container(),
    )

    # Promoting the workflow fetches its task, so the mock returns our promoted task.
    mock_task_fetch.return_value = _task.SdkTask.promote_from_model(template)
    promoted = _workflow_common.SdkWorkflow.promote_from_model(get_workflow_template())

    expected_interface = TestPromoteExampleWf.interface
    assert promoted.interface.inputs["wf_input"] == expected_interface.inputs["wf_input"]
    for output_name in ("wf_output_b", "wf_output_c"):
        assert promoted.interface.outputs[output_name] == expected_interface.outputs[output_name]

    assert len(promoted.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[0].inputs[0]
Beispiel #9
0
def test_variable_map():
    """A VariableMap must compare equal to its IDL round-tripped copy."""
    boolean = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean, 'asdf asdf asdf')
    original = interface.VariableMap({'vvv': variable})

    restored = interface.VariableMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
Beispiel #10
0
def test_launch_plan_spec():
    """Round-trip LaunchPlanSpec with both a set and an empty raw output data prefix."""
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")

    cron = schedule.Schedule("asdf", "1 3 4 5 6 7")
    lp_metadata = launch_plan.LaunchPlanMetadata(schedule=cron, notifications=[])

    boolean = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean, "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=variable)})

    one = literals.Literal(
        scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed_inputs = literals.LiteralMap({"a": one})

    labels = common.Labels({})
    annotations = common.Annotations({"my": "annotation"})
    auth_role = common.AuthRole(assumable_iam_role="my:iam:role")
    prefixed_output = common.RawOutputDataConfig("s3://bucket")
    unprefixed_output = common.RawOutputDataConfig("")
    max_parallelism = 100

    # Same spec twice, differing only in the raw output data config.
    for output_config in (prefixed_output, unprefixed_output):
        spec = launch_plan.LaunchPlanSpec(
            workflow_id,
            lp_metadata,
            default_inputs,
            fixed_inputs,
            labels,
            annotations,
            auth_role,
            output_config,
            max_parallelism,
        )
        restored = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert restored == spec
 def named_tuple_to_variable_map(
         cls, t: typing.NamedTuple) -> _interface_models.VariableMap:
     """
     Build a VariableMap from the fields of a typed named tuple.

     Each field becomes a Variable whose type is ``cls.to_literal_type`` of the field's
     annotation and whose description is the field's positional index as a string.

     :param t: The ``typing.NamedTuple`` to read field names and types from.
     :return: A VariableMap keyed by field name.
     """
     # ``_field_types`` was removed from typing.NamedTuple in Python 3.9. Keep using it
     # where it exists, and fall back to ``__annotations__`` otherwise.
     field_types = getattr(t, "_field_types", None)
     if field_types is None:
         field_types = t.__annotations__

     variables = {
         var_name: _interface_models.Variable(
             type=cls.to_literal_type(var_type), description=f"{idx}")
         for idx, (var_name, var_type) in enumerate(field_types.items())
     }
     return _interface_models.VariableMap(variables=variables)
Beispiel #12
0
def test_parameter_map():
    """A ParameterMap must compare equal to its IDL round-tripped copy."""
    boolean = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean, 'asdf asdf asdf')
    parameter = interface.Parameter(var=variable)

    original = interface.ParameterMap({'ppp': parameter})
    restored = interface.ParameterMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
Beispiel #13
0
def test_parameter():
    """A Parameter keeps its variable, both directly and after an IDL round trip."""
    boolean = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean, 'asdf asdf asdf')
    original = interface.Parameter(var=variable)
    assert original.var == variable

    restored = interface.Parameter.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    assert restored.var == variable
Beispiel #14
0
def test_lp_closure():
    """A LaunchPlanClosure keeps its expected inputs and outputs across an IDL round trip."""
    boolean = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean, 'asdf asdf asdf')
    expected_inputs = interface.ParameterMap({'ppp': interface.Parameter(var=variable)})
    # The result is discarded; the call only has to succeed.
    expected_inputs.to_flyte_idl()
    expected_outputs = interface.VariableMap({'vvv': variable})

    def check(closure):
        assert closure.expected_inputs == expected_inputs
        assert closure.expected_outputs == expected_outputs

    original = launch_plan.LaunchPlanClosure(
        state=launch_plan.LaunchPlanState.ACTIVE,
        expected_inputs=expected_inputs,
        expected_outputs=expected_outputs,
    )
    check(original)

    restored = launch_plan.LaunchPlanClosure.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    check(restored)
Beispiel #15
0
    def __init__(self, name, sdk_type, help=None, **kwargs):
        """
        :param Text name:
        :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: The SDK type needed to create an
            input to this workflow.
        :param Text help: Optional help text describing the input to users.
        :param bool required: If True, no default may be supplied.
        :param T default: Value used for the input when it is not required.
        :raises: FlyteAssertion if ``required`` is truthy and ``default`` is also given.
        """
        default = None
        param_default = None

        if "default" in kwargs:
            if kwargs.get("required", False):
                # A required input cannot also carry a default.
                raise _user_exceptions.FlyteAssertion(
                    "Default cannot be set when required is True")
            # A default was given, so the input is not required.
            required = None
            default = kwargs["default"]
            param_default = sdk_type.from_python_std(default)
        elif "required" in kwargs:
            # Only ``required`` was given; its truthiness decides the behavior.
            required = kwargs["required"]
            if not required:
                # Not required and no default given: the default is None.
                param_default = sdk_type.from_python_std(default)
                required = None
        else:
            # Neither was given: treat the input as required.
            required = True

        self._sdk_required = required or False
        self._sdk_default = default
        self._help = help
        self._sdk_type = sdk_type
        self._promise = _type_models.OutputReference(
            _constants.GLOBAL_INPUT_NODE_ID, name)
        self._name = name

        variable = _interface_models.Variable(
            type=sdk_type.to_flyte_literal_type(),
            description=help or "",
        )
        super(Input, self).__init__(
            variable,
            required=required,
            default=param_default,
        )
Beispiel #16
0
    def __init__(self, name, value, sdk_type=None, help=None):
        """
        :param Text name:
        :param T value:
        :param U sdk_type: If given, the value must cast to this type. Normally an instance of
            flytekit.common.types.base_sdk_types.FlyteSdkType, but it may also be something like:

            list[flytekit.common.types.base_sdk_types.FlyteSdkType],
            dict[flytekit.common.types.base_sdk_types.FlyteSdkType,flytekit.common.types.base_sdk_types.FlyteSdkType],
            (flytekit.common.types.base_sdk_types.FlyteSdkType, flytekit.common.types.base_sdk_types.FlyteSdkType, ...)
        :param Text help: Optional description stored on the output variable.
        """
        # Infer the type only when none was passed. This is an explicit ``is None``
        # check, not ``sdk_type or ...``.
        declared_type = Output._infer_type(value) if sdk_type is None else sdk_type
        resolved_type = _type_helpers.python_std_to_sdk_type(declared_type)

        self._binding_data = _interface.BindingData.from_python_std(
            resolved_type.to_flyte_literal_type(),
            value,
        )
        self._var = _interface_models.Variable(
            resolved_type.to_flyte_literal_type(),
            help or "",
        )
        self._name = name
Beispiel #17
0
def test_old_style_role():
    """Deserialize a LaunchPlanSpec IDL message that uses the ``auth`` field and check its auth role."""
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")

    cron = schedule.Schedule("asdf", "1 3 4 5 6 7")
    lp_metadata = launch_plan.LaunchPlanMetadata(schedule=cron, notifications=[])

    boolean = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean, "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=variable)})

    one = literals.Literal(
        scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed_inputs = literals.LiteralMap({"a": one})

    labels = common.Labels({})
    annotations = common.Annotations({"my": "annotation"})
    output_config = common.RawOutputDataConfig("s3://bucket")

    legacy_auth = _launch_plan_idl.Auth(
        kubernetes_service_account="my:service:account")

    # Assemble the IDL message directly so that it carries the ``auth`` field.
    spec_fields = dict(
        workflow_id=workflow_id.to_flyte_idl(),
        entity_metadata=lp_metadata.to_flyte_idl(),
        default_inputs=default_inputs.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=labels.to_flyte_idl(),
        annotations=annotations.to_flyte_idl(),
        raw_output_data_config=output_config.to_flyte_idl(),
        auth=legacy_auth,
    )
    legacy_spec = _launch_plan_idl.LaunchPlanSpec(**spec_fields)

    parsed = launch_plan.LaunchPlanSpec.from_flyte_idl(legacy_spec)

    assert parsed.auth_role.assumable_iam_role == "my:service:account"
Beispiel #18
0
    def apply_outputs_wrapper(task):
        """
        Add the outputs described by the enclosing scope's ``kwargs`` to ``task``.

        Each ``kwargs`` value is converted to an SDK type and wrapped in a Variable with an
        empty description before being passed to ``task.add_outputs``.

        :param task: The object to add outputs to; must be an SdkRunnableTask.
        :return: The same ``task`` that was passed in.
        :raises: FlyteTypeException if ``task`` is not an SdkRunnableTask.
        """
        if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask):
            additional_msg = \
                "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format(
                    task.__module__,
                    getattr(task, "__name__", "<unknown>")
                )
            raise _user_exceptions.FlyteTypeException(
                expected_type=_sdk_runnable_tasks.SdkRunnableTask,
                received_type=type(task),
                received_value=task,
                additional_msg=additional_msg)

        # Build a fresh mapping rather than overwriting ``kwargs`` in place. ``kwargs``
        # belongs to the enclosing scope, so rewriting it would leave Variables behind
        # and a second application of this wrapper would try to convert those.
        output_variables = {
            k: _interface_model.Variable(
                _type_helpers.python_std_to_sdk_type(
                    v).to_flyte_literal_type(),
                '')  # TODO: Support descriptions
            for k, v in _six.iteritems(kwargs)
        }

        task.add_outputs(output_variables)
        return task
Beispiel #19
0
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
        tunable_parameters: typing.List[str] = None,
    ):
        """
        :param int max_number_of_training_jobs: The maximum number of training jobs this hyperparameter tuning
            job may launch
        :param int max_parallel_training_jobs: The maximum number of training jobs this hyperparameter tuning job
            may launch in parallel
        :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference
            to the training job definition
        :param int retries: Number of retries to attempt
        :param bool cacheable: Set this if the output of the task execution should be cached
        :param str cache_version: String describing the caching version for task discovery purposes
        :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
                algorithm, refer to the algorithm's documentation for the possible values of the tunable
                parameters. E.g. see https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the
                list of hyperparameters of the Image Classification built-in algorithm. If you are passing a custom
                training job, the list of tunable parameters must be a strict subset of the list of inputs defined
                on that job. See
                https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
                for the list of supported hyperparameter types.
        """
        # Building the tuning job model doubles as type checking of the arguments.
        tuning_job_model = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        )
        hpo_job = tuning_job_model.to_flyte_idl()

        # The flyte-level timeout is set to 0 so that SageMaker honors the
        #   StoppingCondition of the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        # Start from the training job's inputs, then add the tuning config input.
        inputs = {}
        inputs.update(training_job.interface.inputs)
        inputs["hyperparameter_tuning_job_config"] = _interface_model.Variable(
            HyperparameterTuningJobConfig.to_flyte_literal_type(),
            "",
        )

        # Each tunable parameter becomes a ParameterRange-typed input.
        if tunable_parameters:
            tunable_inputs = {}
            for param in tunable_parameters:
                tunable_inputs[param] = _interface_model.Variable(
                    ParameterRange.to_flyte_literal_type(), "")
            inputs.update(tunable_inputs)

        task_metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        # The single output is a single-part blob named "model".
        model_blob = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        outputs = {
            "model": _interface_model.Variable(
                type=_types_models.LiteralType(blob=model_blob),
                description="",
            )
        }

        super().__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=task_metadata,
            interface=_interface.TypedInterface(inputs=inputs, outputs=outputs),
            custom=MessageToDict(hpo_job),
        )
    def __init__(
        self,
        training_job_resource_config: _training_job_models.
        TrainingJobResourceConfig,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        Define a SageMaker training job task with fixed "static_hyperparameters", "train" and
        "validation" inputs and a single "model" output.

        :param training_job_resource_config: The options to configure the training job
        :param algorithm_specification: The options to configure the target algorithm of the training
        :param retries: Number of retries to attempt
        :param cacheable: Set this if the output of the task execution should be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Building the training job model doubles as type checking of the arguments.
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        # The flyte-level timeout is set to 0 so that SageMaker applies the StoppingCondition and
        # terminates the training job gracefully
        timeout = _datetime.timedelta(seconds=0)

        def _dataset_input():
            # Multipart blob whose format follows the algorithm's input content type.
            blob = _core_types.BlobType(
                format=_content_type_to_blob_format(
                    algorithm_specification.input_content_type),
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
            return _interface_model.Variable(
                type=_idl_types.LiteralType(blob=blob),
                description="",
            )

        task_metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        task_inputs = {
            "static_hyperparameters": _interface_model.Variable(
                type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                description="",
            ),
            "train": _dataset_input(),
            "validation": _dataset_input(),
        }

        # The single output is a single-part blob named "model".
        model_blob = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        task_outputs = {
            "model": _interface_model.Variable(
                type=_idl_types.LiteralType(blob=model_blob),
                description="",
            )
        }

        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=task_metadata,
            interface=_interface.TypedInterface(
                inputs=task_inputs,
                outputs=task_outputs,
            ),
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )
Beispiel #21
0
    def __init__(
        self,
        region,
        role_arn,
        resource_config,
        algorithm_specification=None,
        stopping_condition=None,
        vpc_config=None,
        enable_spot_training=False,
        interruptible=False,
        retries=0,
        cacheable=False,
        cache_version="",
    ):
        """
        Define a SageMaker XGBoost hyperparameter-optimization task.

        The job configuration is parsed into a ``SagemakerHPOJob`` message and its dict form is stored as the
        task's ``custom`` payload. After construction the task has three inputs (``static_hyperparameters``,
        ``train``, ``validation``) and one output (``model``).

        :param Text region: The region in which to run the SageMaker job.
        :param Text role_arn: The ARN of the role to run in the SageMaker job.
        :param dict[Text,T] resource_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
        :param dict[Text,T] algorithm_specification: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
            The caller's dict is copied, never modified. ``TrainingImage`` falls back to a default XGBoost image
            when it is missing or empty, and ``TrainingInputMode`` is always forced to ``"File"``.
        :param dict[Text,T] stopping_condition: https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
        :param dict[Text,T] vpc_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
        :param bool enable_spot_training: https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
            NOTE(review): accepted but not read anywhere in this constructor -- confirm whether it should be
            forwarded into the job config.
        :param bool interruptible: Forwarded to the ``interruptible`` field of the task metadata.
        :param int retries: Number of time to retry.
        :param bool cacheable: Whether or not to use Flyte's caching system.
        :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
        """

        # Work on a shallow copy: only top-level keys are set below, and the caller's dict (which may be shared
        # between several task definitions) must not be altered as a side effect of constructing a task.
        algorithm_specification = dict(algorithm_specification or {})
        algorithm_specification["TrainingImage"] = (
            algorithm_specification.get("TrainingImage")
            or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1")
        algorithm_specification["TrainingInputMode"] = "File"

        job_config = ParseDict(
            {
                "Region": region,
                "ResourceConfig": resource_config,
                "StoppingCondition": stopping_condition,
                "VpcConfig": vpc_config,
                "AlgorithmSpecification": algorithm_specification,
                "RoleArn": role_arn,
            },
            sagemaker_pb2.SagemakerHPOJob(),
        )
        # Serialize once; this dict becomes the task's custom payload. (A leftover debug print of this value
        # was removed so that merely defining a task does not write to stdout.)
        custom = MessageToDict(job_config)

        # TODO: Optionally, pull timeout behavior from stopping condition and pass to Flyte task def.
        timeout = _datetime.timedelta(seconds=0)

        # TODO: The FlyteKit type engine is extensible so we can create a SagemakerInput type with custom
        # TODO:     parsing/casting logic. For now, we will use the Generic type since there is a little that needs
        # TODO:     to be done on Flyte side to unlock this cleanly.
        # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following a
        # TODO:     refactor.
        # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the task,
        # TODO:     but might not affect the outputs (i.e. Running on a bigger machine). These are currently static for
        # TODO:     a given definition of a task, but will be more dynamic in the future. Also, it is possible to
        # TODO:     make it dynamic by using our @dynamic_task.
        # TODO: You might want to inherit the role ARN from the execution at runtime.
        super(SagemakerXgBoostOptimizer, self).__init__(
            type=_TASK_TYPE,
            metadata=_task_models.TaskMetadata(
                discoverable=cacheable,
                runtime=_task_models.RuntimeMetadata(0, "0.1.0b0",
                                                     "sagemaker"),
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=interruptible,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface({}, {}),
            custom=custom,
        )

        # TODO: Add more inputs that we expect to change the outputs of the task.
        # TODO: We can add outputs too!
        # We use helper methods for adding to interface, thus overriding the one set above. This will be simplified post
        # refactor.
        self.add_inputs({
            "static_hyperparameters":
            _interface_model.Variable(
                _sdk_types.Types.Generic.to_flyte_literal_type(), ""),
            "train":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
            "validation":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
        })
        self.add_outputs({
            "model":
            _interface_model.Variable(
                _sdk_types.Types.Blob.to_flyte_literal_type(), "")
        })
Beispiel #22
0
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        Define a hyperparameter tuning job task that wraps an existing training job task.

        The task's inputs are the wrapped training job's inputs plus a
        ``hyperparameter_tuning_job_config`` proto input; its single output is ``model``.

        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
        :param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter
        tuning job in parallel
        :param training_job: The reference to the training job definition
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Building the model object doubles as type checking of the arguments.
        tuning_job_model = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        )
        hpo_job = tuning_job_model.to_flyte_idl()

        # A zero flyte-level timeout leaves it to SageMaker to honour the StoppingCondition of the
        #   underlying training job.
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        # Start from the tuning-config input, then layer the training job's own inputs on top.
        tuning_config_type = _sdk_types.Types.Proto(
            _pb2_hpo_job.HyperparameterTuningJobConfig)
        inputs = {
            "hyperparameter_tuning_job_config":
            _interface_model.Variable(
                tuning_config_type.to_flyte_literal_type(),
                "",
            ),
        }
        inputs.update(training_job.interface.inputs)

        metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        # The only output is a single-part blob with no declared format.
        model_blob_type = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        model_output = _interface_model.Variable(
            type=_types_models.LiteralType(blob=model_blob_type),
            description="",
        )
        task_interface = _interface.TypedInterface(
            inputs=inputs,
            outputs={"model": model_output},
        )

        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=metadata,
            interface=task_interface,
            custom=MessageToDict(hpo_job),
        )
Beispiel #23
0
def transform_type(x: type,
                   description: str = None) -> _interface_models.Variable:
    """
    Build an interface ``Variable`` for the Python type ``x``.

    :param x: Python type handed to ``TypeEngine.to_literal_type``.
    :param description: Description stored on the variable; passed through unchanged (``None`` by default).
    :return: A ``Variable`` carrying the resulting literal type and the description.
    """
    literal_type = TypeEngine.to_literal_type(x)
    return _interface_models.Variable(type=literal_type,
                                      description=description)
Beispiel #24
0
    def __init__(
        self,
        notebook_path,
        inputs,
        outputs,
        task_type,
        discovery_version,
        retries,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):
        """
        Define a task backed by a notebook file.

        A relative ``notebook_path`` is resolved against the directory of the module reported by
        ``_find_instance_module()``. ``inputs`` and ``outputs``, when not ``None``, are called with this task
        so they can extend its interface. An ``output_notebook`` Blob output is always added last.

        The resource, ``environment`` arguments are forwarded to ``self._get_container_definition``; the
        remaining arguments are forwarded to the task metadata and the parent constructor.
        """

        # Anchor relative notebook paths to the directory of the defining module.
        if not _os.path.isabs(notebook_path):
            defining_module = _importlib.import_module(_find_instance_module())
            base_dir = _os.path.dirname(defining_module.__file__)
            notebook_path = _os.path.normpath(
                _os.path.join(base_dir, notebook_path))

        self._notebook_path = notebook_path

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "notebook",
        )
        metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            False,
            discovery_version,
            deprecated,
        )
        # Start with an empty interface; the callables below fill it in.
        empty_interface = _interface2.TypedInterface({}, {})
        container = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkNotebookTask, self).__init__(
            task_type,
            metadata,
            empty_interface,
            custom,
            container=container,
        )

        # Let the caller-supplied callables register inputs first, then outputs.
        if inputs is not None:
            inputs(self)
        if outputs is not None:
            outputs(self)

        # Every notebook task also exposes the notebook itself as a Blob output.
        notebook_output = _interface.Variable(
            _Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)
        self.interface.outputs.update(output_notebook=notebook_output)
Beispiel #25
0
    types.LiteralType(collection_type=literal_type)
    for literal_type in LIST_OF_SCALAR_LITERAL_TYPES
]

# Each entry of LIST_OF_COLLECTION_LITERAL_TYPES wrapped in one more collection type.
LIST_OF_NESTED_COLLECTION_LITERAL_TYPES = [
    types.LiteralType(collection_type=collection_literal_type)
    for collection_literal_type in LIST_OF_COLLECTION_LITERAL_TYPES
]

# Scalars first, then collections, then nested collections.
LIST_OF_ALL_LITERAL_TYPES = (
    LIST_OF_SCALAR_LITERAL_TYPES
    + LIST_OF_COLLECTION_LITERAL_TYPES
    + LIST_OF_NESTED_COLLECTION_LITERAL_TYPES
)

# One interface per literal type: a single input 'a' and a single output 'b' of that type.
LIST_OF_INTERFACES = [
    interface.TypedInterface(
        {'a': interface.Variable(literal_type, "description 1")},
        {'b': interface.Variable(literal_type, "description 2")},
    )
    for literal_type in LIST_OF_ALL_LITERAL_TYPES
]

# One entry per resource name, in the order CPU, GPU, MEMORY, STORAGE.
LIST_OF_RESOURCE_ENTRIES = [
    task.Resources.ResourceEntry(resource_name, quantity)
    for resource_name, quantity in (
        (task.Resources.ResourceName.CPU, "1"),
        (task.Resources.ResourceName.GPU, "1"),
        (task.Resources.ResourceName.MEMORY, "1G"),
        (task.Resources.ResourceName.STORAGE, "1G"),
    )
]

# A list containing the single entry list above.
LIST_OF_RESOURCE_ENTRY_LISTS = [
    LIST_OF_RESOURCE_ENTRIES,
]

LIST_OF_RESOURCES = [
    task.Resources(request, limit) for request, limit in product(
Beispiel #26
0
    types.LiteralType(collection_type=literal_type)
    for literal_type in LIST_OF_SCALAR_LITERAL_TYPES
]

# Wrap every entry of LIST_OF_COLLECTION_LITERAL_TYPES in a further collection type.
LIST_OF_NESTED_COLLECTION_LITERAL_TYPES = [
    types.LiteralType(collection_type=inner_type)
    for inner_type in LIST_OF_COLLECTION_LITERAL_TYPES
]

# Concatenation order: scalar types, collection types, nested collection types.
LIST_OF_ALL_LITERAL_TYPES = (
    LIST_OF_SCALAR_LITERAL_TYPES
    + LIST_OF_COLLECTION_LITERAL_TYPES
    + LIST_OF_NESTED_COLLECTION_LITERAL_TYPES
)

# For each literal type, an interface with one input "a" and one output "b" of that type.
LIST_OF_INTERFACES = [
    interface.TypedInterface(
        {"a": interface.Variable(variable_type, "description 1")},
        {"b": interface.Variable(variable_type, "description 2")},
    )
    for variable_type in LIST_OF_ALL_LITERAL_TYPES
]

# Resource entries for CPU, GPU, MEMORY and STORAGE, in that order.
LIST_OF_RESOURCE_ENTRIES = [
    task.Resources.ResourceEntry(name, amount)
    for name, amount in (
        (task.Resources.ResourceName.CPU, "1"),
        (task.Resources.ResourceName.GPU, "1"),
        (task.Resources.ResourceName.MEMORY, "1G"),
        (task.Resources.ResourceName.STORAGE, "1G"),
    )
]

# Single-element list holding the entry list defined above.
LIST_OF_RESOURCE_ENTRY_LISTS = [
    LIST_OF_RESOURCE_ENTRIES,
]

LIST_OF_RESOURCES = [
    task.Resources(request, limit) for request, limit in product(
Beispiel #27
0
def test_variable_type(literal_type):
    """Check that a Variable exposes its fields and survives a round trip through its IDL form."""
    description = "abc"
    variable = interface.Variable(type=literal_type, description=description)

    assert variable.type == literal_type
    assert variable.description == description

    # Serialize to the IDL message and back; the result must compare equal.
    round_tripped = interface.Variable.from_flyte_idl(variable.to_flyte_idl())
    assert variable == round_tripped
Beispiel #28
0
    def __init__(
        self,
        statement,
        output_schema,
        routing_group=None,
        catalog=None,
        schema=None,
        task_inputs=None,
        interruptible=False,
        discoverable=False,
        discovery_version=None,
        retries=1,
        timeout=None,
        deprecated=None,
    ):
        """
        Define a task that runs a Presto query.

        The routing group, catalog and schema are registered as implicit string inputs, and the query result is
        exposed as the ``results`` output typed by ``output_schema``.

        :param Text statement: Presto query specification
        :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
        :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
        :param Text catalog: The catalog to set for the given Presto query
        :param Text schema: The schema to set for the given Presto query
        :param task_inputs: Optional callable that is invoked with this task after construction to register
            user-provided inputs. When ``None`` (the default) no additional inputs are added.
        :param bool interruptible: Forwarded to the task metadata.
        :param bool discoverable:
        :param Text discovery_version: String describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param datetime.timedelta timeout: Task timeout; a zero timedelta is used when not given.
        :param Text deprecated: This string can be used to mark the task as deprecated.  Consumers of the task will
            receive deprecation warnings.
        """

        # Set as class fields which are used down below to configure implicit
        # parameters. Unset values are normalized to empty strings.
        self._routing_group = routing_group or ""
        self._catalog = catalog or ""
        self._schema = schema or ""

        metadata = _task_model.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        presto_query = _presto_models.PrestoQuery(
            routing_group=self._routing_group,
            catalog=self._catalog,
            schema=self._schema,
            statement=statement,
        )

        # Here we set the routing_group, catalog, and schema as implicit
        # parameters for caching purposes
        i = _interface.TypedInterface(
            {
                "__implicit_routing_group": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The routing group set as an implicit input",
                ),
                "__implicit_catalog": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The catalog set as an implicit input",
                ),
                "__implicit_schema": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The schema set as an implicit input",
                ),
            },
            {
                # Set the schema for the Presto query as an output
                "results": _interface_model.Variable(
                    type=_types.LiteralType(schema=output_schema.schema_type),
                    description="The schema for the Presto query",
                )
            },
        )

        super(SdkPrestoTask, self).__init__(
            _constants.SdkTaskType.PRESTO_TASK,
            metadata,
            i,
            _MessageToDict(presto_query.to_flyte_idl()),
        )

        # Set user provided inputs. ``task_inputs`` is optional, so only call it when one was supplied;
        # calling the ``None`` default would raise ``TypeError``.
        if task_inputs is not None:
            task_inputs(self)
Beispiel #29
0
def test_workflow_closure():
    """Build a WorkflowClosure holding one workflow template and one task, and round-trip it through its IDL form."""
    integer_type = _types.LiteralType(_types.SimpleType.INTEGER)
    interface_inputs = {'a': _interface.Variable(integer_type, "description1")}
    interface_outputs = {
        'b': _interface.Variable(integer_type, "description2"),
        'c': _interface.Variable(integer_type, "description3"),
    }
    typed_interface = _interface.TypedInterface(interface_inputs,
                                                interface_outputs)

    # Node input: the constant integer 5 bound to 'a'.
    five = _literals.Scalar(primitive=_literals.Primitive(integer=5))
    input_binding = _literals.Binding('a', _literals.BindingData(scalar=five))

    # Workflow outputs: promises on the outputs 'b' and 'c' of 'my_node'.
    # NOTE(review): both bindings use the var name 'b'; the second presumably
    # was meant to be 'c' -- confirm before changing, kept as-is here.
    first_output_binding = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    second_output_binding = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(
        name='node1',
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )

    runtime_metadata = _task.RuntimeMetadata(
        _task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python")
    task_metadata = _task.TaskMetadata(
        True,
        runtime_metadata,
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    # The same single-CPU entry serves as both request and limit.
    cpu_entry = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_entry], limits=[cpu_entry])

    task_id = _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                                     "domain", "name", "version")
    custom = {'a': 1, 'b': {'c': 2, 'd': 3}}
    container = _task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {},
        {},
    )
    task = _task.TaskTemplate(
        task_id,
        "python",
        task_metadata,
        typed_interface,
        custom,
        container=container,
    )

    node = _workflow.Node(
        id='my_node',
        metadata=node_metadata,
        inputs=[input_binding],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )

    workflow_id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW,
                                         "project", "domain", "name",
                                         "version")
    template = _workflow.WorkflowTemplate(
        id=workflow_id,
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[first_output_binding, second_output_binding],
    )

    closure = _workflow_closure.WorkflowClosure(workflow=template,
                                                tasks=[task])
    assert len(closure.tasks) == 1

    # Serializing and deserializing must yield an equal closure.
    restored = _workflow_closure.WorkflowClosure.from_flyte_idl(
        closure.to_flyte_idl())
    assert closure == restored
Beispiel #30
0
def test_interface():
    """Exercise LiteralsResolver lookups: unknown keys, guessed Python values, caching, and type-hint updates."""

    def string_literal(value):
        # Literal wrapping a string primitive.
        return Literal(scalar=Scalar(primitive=Primitive(string_value=value)))

    def integer_literal(value):
        # Literal wrapping an integer primitive.
        return Literal(scalar=Scalar(primitive=Primitive(integer=value)))

    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(pd.DataFrame)
    df = pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})

    # An annotated StructuredDataset type should carry both column definitions.
    annotated_sd_type = Annotated[StructuredDataset, kwtypes(name=str, age=int)]
    df_literal_type = TypeEngine.to_literal_type(annotated_sd_type)
    sd_type = df_literal_type.structured_dataset_type
    assert sd_type is not None
    assert len(df_literal_type.structured_dataset_type.columns) == 2
    for index, expected_name in enumerate(("name", "age")):
        column = df_literal_type.structured_dataset_type.columns[index]
        assert column.name == expected_name
        assert column.literal_type.simple is not None

    sd = annotated_sd_type(df)
    sd_literal = TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=lt)

    map_literal = Literal(
        map=LiteralMap(
            literals={
                "k1": string_literal("v1"),
                "k2": string_literal("2"),
            },
        )
    )
    list_literal = Literal(
        collection=LiteralCollection(
            literals=[integer_literal(1), integer_literal(2), integer_literal(3)]
        )
    )
    lm = {
        "my_map": map_literal,
        "my_list": list_literal,
        "val_a": integer_literal(21828),
        "my_df": sd_literal,
    }

    variable_map = {
        "my_map": interface_models.Variable(type=TypeEngine.to_literal_type(typing.Dict[str, str]), description=""),
        "my_list": interface_models.Variable(type=TypeEngine.to_literal_type(typing.List[int]), description=""),
        "val_a": interface_models.Variable(type=TypeEngine.to_literal_type(int), description=""),
        "my_df": interface_models.Variable(type=df_literal_type, description=""),
    }

    lr = LiteralsResolver(lm, variable_map=variable_map, ctx=ctx)
    assert lr._ctx is ctx

    # Unknown keys raise ValueError through both item access and get_literal.
    with pytest.raises(ValueError):
        lr["not"]  # noqa

    with pytest.raises(ValueError):
        lr.get_literal("not")

    # Item access with no type hint: the Python type is guessed from the Flyte type.
    list_value = lr["my_list"]
    assert list_value == [1, 2, 3]

    # Same guessing behavior through get().
    map_value = lr.get("my_map")
    assert map_value == {"k1": "v1", "k2": "2"}

    # get_literal hands back the very Literal object that was stored.
    assert lr.get_literal("my_df") is sd_literal

    first_guess = lr["my_df"]
    # Based on guessing, so no column information
    assert len(first_guess.metadata.structured_dataset_type.columns) == 0
    # A second lookup returns the identical object.
    second_guess = lr["my_df"]
    assert first_guess is second_guess

    # Supply the annotated type as a hint and drop the stored native value so it is converted again.
    lr.update_type_hints({"my_df": annotated_sd_type})
    del lr._native_values["my_df"]
    hinted_df = lr.get("my_df")
    # Using the user specified type, so number of columns is correct.
    assert len(hinted_df.metadata.structured_dataset_type.columns) == 2