Example #1
    def with_overrides(self, *args, **kwargs):
        if "node_name" in kwargs:
            self._id = kwargs["node_name"]
        if "aliases" in kwargs:
            alias_dict = kwargs["aliases"]
            if not isinstance(alias_dict, dict):
                raise AssertionError("Aliases should be specified as dict[str, str]")
            self._aliases = []
            for k, v in alias_dict.items():
                self._aliases.append(_workflow_model.Alias(var=k, alias=v))
        if "requests" in kwargs or "limits" in kwargs:
            requests = _convert_resource_overrides(kwargs.get("requests"), "requests")
            limits = _convert_resource_overrides(kwargs.get("limits"), "limits")
            self._resources = _resources_model(requests=requests, limits=limits)
        if "timeout" in kwargs:
            timeout = kwargs["timeout"]
            if timeout is None:
                self._metadata._timeout = datetime.timedelta()
            elif isinstance(timeout, int):
                self._metadata._timeout = datetime.timedelta(seconds=timeout)
            elif isinstance(timeout, datetime.timedelta):
                self._metadata._timeout = timeout
            else:
                raise ValueError("timeout should be a duration represented as either a datetime.timedelta or int seconds")
        if "retries" in kwargs:
            retries = kwargs["retries"]
            self._metadata._retries = (
                _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
            )
        if "interruptible" in kwargs:
            self._metadata._interruptible = kwargs["interruptible"]
        return self
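
A minimal usage sketch for the method above. Hedged: `node` stands for any object exposing `with_overrides`, and the exact payloads accepted for `requests`/`limits` depend on `_convert_resource_overrides`, which is not shown here, so they are omitted.

import datetime

# Illustrative only: `node` is a hypothetical node object; each kwarg
# exercises one branch of with_overrides above.
node.with_overrides(
    node_name="my-node",              # becomes the node id
    aliases={"out": "renamed_out"},   # must be dict[str, str]
    timeout=datetime.timedelta(minutes=5),  # an int number of seconds also works
    retries=3,                        # wrapped in RetryStrategy(3)
    interruptible=True,
)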
Example #2
def test_retry_strategy():
    obj = literals.RetryStrategy(3)
    assert obj.retries == 3
    assert literals.RetryStrategy.from_flyte_idl(obj.to_flyte_idl()) == obj
    assert obj != literals.RetryStrategy(4)

    with pytest.raises(Exception):
        obj = literals.RetryStrategy(-1)
        obj.to_flyte_idl()
Example #3
    def __call__(self, *args, **input_map):
        """
        :param list[T] args: Do not specify.  Kwargs only are supported for this function.
        :param dict[Text,T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if len(args) > 0:
            raise _user_exceptions.FlyteAssertion(
                "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        # Take the default values from the launch plan
        default_inputs = {
            k: v.sdk_default
            for k, v in _six.iteritems(self.default_inputs.parameters)
            if not v.required
        }
        default_inputs.update(input_map)

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(
            default_inputs)

        return _nodes.SdkNode(
            id=None,
            metadata=_workflow_models.NodeMetadata(
                "", _datetime.timedelta(), _literal_models.RetryStrategy(0)),
            bindings=sorted(bindings, key=lambda b: b.var),
            upstream_nodes=upstream_nodes,
            sdk_launch_plan=self,
        )
Example #4
    def __call__(self, *args, **input_map):
        if len(args) > 0:
            raise _user_exceptions.FlyteAssertion(
                "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args))
            )

        # Take the default values from the Inputs
        compiled_inputs = {
            v.name: v.sdk_default
            for v in self.user_inputs if not v.sdk_required
        }
        compiled_inputs.update(input_map)

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(compiled_inputs)

        node = _nodes.SdkNode(
            id=None,
            metadata=_workflow_models.NodeMetadata("placeholder", _datetime.timedelta(),
                                                   _literal_models.RetryStrategy(0)),
            upstream_nodes=upstream_nodes,
            bindings=sorted(bindings, key=lambda b: b.var),
            sdk_workflow=self
        )
        return node
Example #5
def get_sample_node_metadata(node_id):
    """
    :param Text node_id:
    :rtype: flytekit.models.core.workflow.NodeMetadata
    """

    return _workflow_model.NodeMetadata(name=node_id,
                                        timeout=timedelta(seconds=10),
                                        retries=_literals.RetryStrategy(0))
Example #6
def test_task_template__k8s_pod_target():
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project",
                              "domain", "name", "version"),
        "python",
        task.TaskMetadata(
            False,
            task.RuntimeMetadata(1, "v", "f"),
            timedelta(days=1),
            literal_models.RetryStrategy(5),
            False,
            "1.0",
            "deprecated",
            False,
        ),
        interface_models.TypedInterface(
            # inputs
            {"a": interface_models.Variable(int_type, "description1")},
            # outputs
            {
                "b": interface_models.Variable(int_type, "description2"),
                "c": interface_models.Variable(int_type, "description3"),
            },
        ),
        {
            "a": 1,
            "b": {
                "c": 2,
                "d": 3
            }
        },
        config={"a": "b"},
        k8s_pod=task.K8sPod(
            metadata=task.K8sObjectMetadata(labels={"label": "foo"},
                                            annotations={"anno": "bar"}),
            pod_spec={
                "str": "val",
                "int": 1
            },
        ),
    )
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.k8s_pod.metadata == task.K8sObjectMetadata(
        labels={"label": "foo"}, annotations={"anno": "bar"})
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}
    assert text_format.MessageToString(
        obj.to_flyte_idl()) == text_format.MessageToString(
            task.TaskTemplate.from_flyte_idl(
                obj.to_flyte_idl()).to_flyte_idl())
    assert obj.config == {"a": "b"}
Example #7
def test_node_metadata():
    obj = _workflow.NodeMetadata(name="node1",
                                 timeout=timedelta(seconds=10),
                                 retries=_literals.RetryStrategy(0))
    assert obj.timeout.seconds == 10
    assert obj.retries.retries == 0
    obj2 = _workflow.NodeMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.timeout.seconds == 10
    assert obj2.retries.retries == 0
Example #8
def get_sample_task_metadata():
    """
    :rtype: flytekit.models.task.TaskMetadata
    """
    return _task_model.TaskMetadata(
        True,
        _task_model.RuntimeMetadata(
            _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0",
            "python"), timedelta(days=1), _literals.RetryStrategy(3), True,
        "0.1.1b0", "This is deprecated!")
Example #9
    def __init__(self, task_function, task_type, discovery_version, retries,
                 interruptible, deprecated, storage_request, cpu_request,
                 gpu_request, memory_request, storage_limit, cpu_limit,
                 gpu_limit, memory_limit, discoverable, timeout, environment,
                 custom):
        """
        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Specify whether task is interruptible
        :param Text deprecated:
        :param Text storage_request:
        :param Text cpu_request:
        :param Text gpu_request:
        :param Text memory_request:
        :param Text storage_limit:
        :param Text cpu_limit:
        :param Text gpu_limit:
        :param Text memory_limit:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param dict[Text, Text] environment:
        :param dict[Text, T] custom:
        """
        self._task_function = task_function

        super(SdkRunnableTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__, 'python'), timeout,
                _literal_models.RetryStrategy(retries), interruptible,
                discovery_version, deprecated),
            _interface.TypedInterface({}, {}),
            custom,
            container=self._get_container_definition(
                storage_request=storage_request,
                cpu_request=cpu_request,
                gpu_request=gpu_request,
                memory_request=memory_request,
                storage_limit=storage_limit,
                cpu_limit=cpu_limit,
                gpu_limit=gpu_limit,
                memory_limit=memory_limit,
                environment=environment))
        self.id._name = "{}.{}".format(self.task_module,
                                       self.task_function_name)
Example #10
def _create_hive_job_node(name, hive_job, metadata):
    """
    :param Text name:
    :param _qubole.QuboleHiveJob hive_job: Hive job spec
    :param flytekit.models.task.TaskMetadata metadata: This contains information needed at runtime to determine
        behavior such as whether or not outputs are discoverable, timeouts, and retries.
    :rtype: _nodes.SdkNode
    """
    return _nodes.SdkNode(id=_six.text_type(_uuid.uuid4()),
                          upstream_nodes=[],
                          bindings=[],
                          metadata=_workflow_model.NodeMetadata(
                              name, metadata.timeout,
                              _literal_models.RetryStrategy(0)),
                          sdk_task=SdkHiveJob(hive_job, metadata))
Example #11
def test_task_metadata():
    obj = task.TaskMetadata(
        True,
        task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                             "1.0.0", "python"), timedelta(days=1),
        literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    assert obj.discoverable is True
    assert obj.retries.retries == 3
    assert obj.timeout == timedelta(days=1)
    assert obj.runtime.flavor == "python"
    assert obj.runtime.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK
    assert obj.runtime.version == "1.0.0"
    assert obj.deprecated_error_message == "This is deprecated!"
    assert obj.discovery_version == "0.1.1b0"
    assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl())
Example #12
def test_sdk_node_from_lp():
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=a)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    lp = test_workflow.create_launch_plan()

    n1 = _nodes.SdkNode(
        "n1",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_launch_plan=lp,
    )

    assert n1.id == "n1"
    assert len(n1.inputs) == 1
    assert n1.inputs[0].var == "a"
    assert n1.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n1.outputs) == 1
    assert "b" in n1.outputs
    assert n1.outputs["b"].node_id == "n1"
    assert n1.outputs["b"].var == "b"
    assert n1.outputs["b"].sdk_node == n1
    assert n1.outputs["b"].sdk_type == _types.Types.Integer
    assert n1.metadata.name == "abc"
    assert n1.metadata.retries.retries == 3
    assert len(n1.upstream_nodes) == 0
    assert len(n1.upstream_node_ids) == 0
    assert len(n1.output_aliases) == 0
Example #13
    def __call__(self, *args, **input_map):
        if len(args) > 0:
            raise _user_exceptions.FlyteAssertion(
                "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))
        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(
            input_map)

        node = _nodes.SdkNode(
            id=None,
            metadata=_workflow_models.NodeMetadata(
                "placeholder", _datetime.timedelta(),
                _literal_models.RetryStrategy(0)),
            upstream_nodes=upstream_nodes,
            bindings=sorted(bindings, key=lambda b: b.var),
            sdk_workflow=self,
        )
        return node
Example #14
def test_future_task_document(task):
    rs = _literals.RetryStrategy(0)
    nm = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), rs)
    n = _workflow.Node(id="id",
                       metadata=nm,
                       inputs=[],
                       upstream_node_ids=[],
                       output_aliases=[],
                       task_node=_workflow.TaskNode(task.id))
    n.to_flyte_idl()
    doc = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[n],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[])
    assert text_format.MessageToString(
        doc.to_flyte_idl()) == text_format.MessageToString(
            _dynamic_job.DynamicJobSpec.from_flyte_idl(
                doc.to_flyte_idl()).to_flyte_idl())
Example #15
    def _op_to_task(self, dag_id, image, op, node_map):
        """
        Generate a task given an operator inherited from airflow.models.BaseOperator.

        :param airflow.models.BaseOperator op:
        :param dict(Text, SdkNode) node_map:
        :rtype: Tuple(base_tasks.SdkTask, SdkNode)
        """

        interface_inputs = {}
        interface_outputs = {}
        input_mappings = {}
        processed_args = None

        # for key, val in six.iteritems(op.params):
        #     interface_inputs[key] = interface_model.Variable(
        #         _type_helpers.python_std_to_sdk_type(Types.String).to_flyte_literal_type(),
        #         ''
        #     )
        #
        #     if param.op_name == '':
        #         binding = promise_common.Input(sdk_type=Types.String, name=param.name)
        #     else:
        #         binding = promise_common.NodeOutput(
        #             sdk_node=node_map[param.op_name],
        #             sdk_type=Types.String,
        #             var=param.name)
        #     input_mappings[param.name] = binding
        #
        # for param in op.outputs.values():
        #     interface_outputs[param.name] = interface_model.Variable(
        #         _type_helpers.python_std_to_sdk_type(Types.String).to_flyte_literal_type(),
        #         ''
        #     )

        requests = []
        if op.resources:
            requests.append(
                task_model.Resources.ResourceEntry(
                    task_model.Resources.ResourceName.Cpu, op.resources.cpus))

            requests.append(
                task_model.Resources.ResourceEntry(
                    task_model.Resources.ResourceName.Memory,
                    op.resources.ram))

            requests.append(
                task_model.Resources.ResourceEntry(
                    task_model.Resources.ResourceName.Gpu, op.resources.gpus))

            requests.append(
                task_model.Resources.ResourceEntry(
                    task_model.Resources.ResourceName.Storage,
                    op.resources.disk))

        task_instance = TaskInstance(op, datetime.datetime.now())
        command = task_instance.command_as_list(local=True,
                                                mark_success=False,
                                                ignore_all_deps=True,
                                                ignore_depends_on_past=True,
                                                ignore_task_deps=True,
                                                ignore_ti_state=True,
                                                pool=task_instance.pool,
                                                pickle_id=dag_id,
                                                cfg_path=None)

        task = base_tasks.SdkTask(
            op.task_id,
            SingleStepTask,
            "airflow_op",
            task_model.TaskMetadata(
                False,
                task_model.RuntimeMetadata(
                    type=task_model.RuntimeMetadata.RuntimeType.Other,
                    version=airflow.version.version,
                    flavor='airflow'),
                datetime.timedelta(seconds=0),
                literals_model.RetryStrategy(0),
                '1',
                None,
            ),
            interface_common.TypedInterface(inputs=interface_inputs,
                                            outputs=interface_outputs),
            custom=None,
            container=task_model.Container(
                image=image,
                command=command,
                args=[],
                resources=task_model.Resources(limits=[], requests=requests),
                env={},
                config={},
            ))

        return task, task(**input_mappings).assign_id_and_return(op.task_id)
Example #16
    def __init__(
        self,
        notebook_path,
        inputs,
        outputs,
        task_type,
        discovery_version,
        retries,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):

        if _os.path.isabs(notebook_path) is False:
            # Find absolute path for the notebook.
            task_module = _importlib.import_module(_find_instance_module())
            module_path = _os.path.dirname(task_module.__file__)
            notebook_path = _os.path.normpath(
                _os.path.join(module_path, notebook_path))

        self._notebook_path = notebook_path

        super(SdkNotebookTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    "notebook",
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                False,
                discovery_version,
                deprecated,
            ),
            _interface2.TypedInterface({}, {}),
            custom,
            container=self._get_container_definition(
                storage_request=storage_request,
                cpu_request=cpu_request,
                gpu_request=gpu_request,
                memory_request=memory_request,
                storage_limit=storage_limit,
                cpu_limit=cpu_limit,
                gpu_limit=gpu_limit,
                memory_limit=memory_limit,
                environment=environment,
            ),
        )
        # Add Inputs
        if inputs is not None:
            inputs(self)

        # Add outputs
        if outputs is not None:
            outputs(self)

        # Add a Notebook output as a Blob.
        self.interface.outputs.update(output_notebook=_interface.Variable(
            _Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK))
Example #17
    def __init__(
        self,
        inputs: Dict[str, FlyteSdkType],
        image: str,
        outputs: Dict[str, FlyteSdkType] = None,
        input_data_dir: str = None,
        output_data_dir: str = None,
        metadata_format: int = METADATA_FORMAT_JSON,
        io_strategy: _task_models.IOStrategy = None,
        command: List[str] = None,
        args: List[str] = None,
        storage_request: str = None,
        cpu_request: str = None,
        gpu_request: str = None,
        memory_request: str = None,
        storage_limit: str = None,
        cpu_limit: str = None,
        gpu_limit: str = None,
        memory_limit: str = None,
        environment: Dict[str, str] = None,
        interruptible: bool = False,
        discoverable: bool = False,
        discovery_version: str = None,
        retries: int = 1,
        timeout: _datetime.timedelta = None,
    ):
        """
        :param inputs:
        :param outputs:
        :param image:
        :param command:
        :param args:
        :param storage_request:
        :param cpu_request:
        :param gpu_request:
        :param memory_request:
        :param storage_limit:
        :param cpu_limit:
        :param gpu_limit:
        :param memory_limit:
        :param environment:
        :param interruptible:
        :param discoverable:
        :param discovery_version:
        :param retries:
        :param timeout:
        :param input_data_dir: This is the directory where data will be downloaded to
        :param output_data_dir: This is the directory where data will be uploaded from
        :param metadata_format: Format in which the metadata will be available for the script
        """

        # Set as class fields which are used down below to configure implicit
        # parameters
        self._data_loading_config = _task_models.DataLoadingConfig(
            input_path=input_data_dir,
            output_path=output_data_dir,
            format=metadata_format,
            enabled=True,
            io_strategy=io_strategy,
        )

        metadata = _task_models.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                flytekit.__version__,
                "python",
            ),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            None,
        )

        # The interface is defined using the inputs and outputs
        i = _interface.TypedInterface(inputs=types_to_variable(inputs),
                                      outputs=types_to_variable(outputs))

        # This sets the base SDKTask with container etc
        super(SdkRawContainerTask, self).__init__(
            _constants.SdkTaskType.RAW_CONTAINER_TASK,
            metadata,
            i,
            None,
            container=_get_container_definition(
                image=image,
                args=args,
                command=command,
                data_loading_config=self._data_loading_config,
                storage_request=storage_request,
                cpu_request=cpu_request,
                gpu_request=gpu_request,
                memory_request=memory_request,
                storage_limit=storage_limit,
                cpu_limit=cpu_limit,
                gpu_limit=gpu_limit,
                memory_limit=memory_limit,
                environment=environment,
            ),
        )
Example #18
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
        tunable_parameters: typing.List[str] = None,
    ):
        """
        :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
        :param int max_parallel_training_jobs: The maximum number of training jobs that can be launched by this hyperparameter
        tuning job in parallel
        :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to the training job definition
        :param int retries: Number of retries to attempt
        :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param str cache_version: String describing the caching version for task discovery purposes
        :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
                algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
                parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the
                list of hyperparameters for Image Classification built-in algorithm. If you are passing a custom
                training job, the list of tunable parameters must be a strict subset of the list of inputs defined on
                that job. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
                for the list of supported hyperparameter types.
        """
        # Use the training job model as a measure of type checking
        hpo_job = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        ).to_flyte_idl()

        # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
        #   the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        inputs = {}
        inputs.update(training_job.interface.inputs)
        inputs.update({
            "hyperparameter_tuning_job_config":
            _interface_model.Variable(
                HyperparameterTuningJobConfig.to_flyte_literal_type(),
                "",
            ),
        })

        if tunable_parameters:
            inputs.update({
                param: _interface_model.Variable(
                    ParameterRange.to_flyte_literal_type(), "")
                for param in tunable_parameters
            })

        super().__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor="sagemaker",
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=False,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs=inputs,
                outputs={
                    "model":
                    _interface_model.Variable(
                        type=_types_models.LiteralType(
                            blob=_core_types.BlobType(
                                format="",
                                dimensionality=_core_types.BlobType.
                                BlobDimensionality.SINGLE,
                            )),
                        description="",
                    )
                },
            ),
            custom=MessageToDict(hpo_job),
        )
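
A hedged instantiation sketch for the constructor above, using `SdkSimpleHyperparameterTuningJobTask` as the class name (Example #25 shows an earlier variant of the same constructor under that name); `training_task` is a hypothetical, already-built training-job task.

# Hypothetical names throughout; retries=2 is what ends up wrapped in
# RetryStrategy(retries=2) inside the TaskMetadata.
hpo_task = SdkSimpleHyperparameterTuningJobTask(
    max_number_of_training_jobs=10,
    max_parallel_training_jobs=2,
    training_job=training_task,
    retries=2,
    cacheable=True,
    cache_version="1.0",
    tunable_parameters=["learning_rate", "num_round"],  # must be a subset of the job's inputs
)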
Example #19
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    This method is used to generate a node with bindings. This is not used in the execution path.
    """
    if ctx.compilation_state is None:
        raise _user_exceptions.FlyteAssertion(
            "Cannot create node when not compiling...")

    used_inputs = set()
    bindings = []

    typed_interface = flyte_interface.transform_interface_to_typed_interface(
        interface)

    for k in sorted(interface.inputs):
        var = typed_interface.inputs[k]
        if k not in kwargs:
            raise _user_exceptions.FlyteAssertion(
                "Input was not specified for: {} of type {}".format(
                    k, var.type))
        v = kwargs[k]
        # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
        # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
        # into the function.
        if isinstance(v, tuple):
            raise AssertionError(
                f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}."
                f" Check if the predecessor function returning more than one value?"
            )
        bindings.append(
            binding_from_python_std(ctx,
                                    var_name=k,
                                    expected_literal_type=var.type,
                                    t_value=v,
                                    t_value_type=interface.inputs[k]))
        used_inputs.add(k)

    extra_inputs = used_inputs ^ set(kwargs.keys())
    if len(extra_inputs) > 0:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface.  Extra inputs were: {}"
            .format(extra_inputs))

    # Detect upstream nodes
    # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
    upstream_nodes = list(
        set([
            input_val.ref.node for input_val in kwargs.values()
            if isinstance(input_val, Promise)
            and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        ]))

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    non_sdk_node = Node(
        # TODO: Better naming, probably a derivative of the function name.
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    ctx.compilation_state.add_node(non_sdk_node)

    if len(typed_interface.outputs) == 0:
        return VoidPromise(entity.name)

    # Create a node output object for each output, they should all point to this node of course.
    node_outputs = []
    for output_name, output_var_model in typed_interface.outputs.items():
        # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
        #  is currently just a static str
        node_outputs.append(
            Promise(output_name, NodeOutput(node=non_sdk_node,
                                            var=output_name)))
        # Don't print this, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break

    return create_task_output(node_outputs, interface)
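
A hedged sketch of how the optional `timeout` and `retry_strategy` parameters reach the generated node's `NodeMetadata` (hypothetical names: `ctx` is a `FlyteContext` with an active compilation state, `my_task` the entity being linked, and `my_task_interface` its flytekit `Interface`):

import datetime
from flytekit.models import literals as _literal_models

# One kwarg per input declared on the interface; extra or missing inputs
# raise FlyteAssertion, as shown above.
promises = create_and_link_node(
    ctx,
    entity=my_task,
    interface=my_task_interface,
    timeout=datetime.timedelta(minutes=10),           # default: timedelta()
    retry_strategy=_literal_models.RetryStrategy(3),  # default: RetryStrategy(0)
    a=5,
)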
Example #20
    def __init__(
        self,
        region,
        role_arn,
        resource_config,
        algorithm_specification=None,
        stopping_condition=None,
        vpc_config=None,
        enable_spot_training=False,
        interruptible=False,
        retries=0,
        cacheable=False,
        cache_version="",
    ):
        """
        :param Text region: The region in which to run the SageMaker job.
        :param Text role_arn: The ARN of the role to run in the SageMaker job.
        :param dict[Text,T] algorithm_specification: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
        :param dict[Text,T] resource_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
        :param dict[Text,T] stopping_condition: https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
        :param dict[Text,T] vpc_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
        :param bool enable_spot_training: https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
        :param int retries: Number of times to retry.
        :param bool cacheable: Whether or not to use Flyte's caching system.
        :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
        """

        algorithm_specification = algorithm_specification or {}
        algorithm_specification["TrainingImage"] = (
            algorithm_specification.get("TrainingImage")
            or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1")
        algorithm_specification["TrainingInputMode"] = "File"

        job_config = ParseDict(
            {
                "Region": region,
                "ResourceConfig": resource_config,
                "StoppingCondition": stopping_condition,
                "VpcConfig": vpc_config,
                "AlgorithmSpecification": algorithm_specification,
                "RoleArn": role_arn,
            },
            sagemaker_pb2.SagemakerHPOJob(),
        )
        print(MessageToDict(job_config))

        # TODO: Optionally, pull timeout behavior from stopping condition and pass to Flyte task def.
        timeout = _datetime.timedelta(seconds=0)

        # TODO: The FlyteKit type engine is extensible so we can create a SagemakerInput type with custom
        # TODO:     parsing/casting logic. For now, we will use the Generic type since there is a little that needs
        # TODO:     to be done on Flyte side to unlock this cleanly.
        # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following a
        # TODO:     refactor.
        # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the task,
        # TODO:     but might not affect the outputs (i.e. Running on a bigger machine). These are currently static for
        # TODO:     a given definition of a task, but will be more dynamic in the future. Also, it is possible to
        # TODO:     make it dynamic by using our @dynamic_task.
        # TODO: You might want to inherit the role ARN from the execution at runtime.
        super(SagemakerXgBoostOptimizer, self).__init__(
            type=_TASK_TYPE,
            metadata=_task_models.TaskMetadata(
                discoverable=cacheable,
                runtime=_task_models.RuntimeMetadata(0, "0.1.0b0",
                                                     "sagemaker"),
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=interruptible,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface({}, {}),
            custom=MessageToDict(job_config),
        )

        # TODO: Add more inputs that we expect to change the outputs of the task.
        # TODO: We can add outputs too!
        # We use helper methods for adding to interface, thus overriding the one set above. This will be simplified post
        # refactor.
        self.add_inputs({
            "static_hyperparameters":
            _interface_model.Variable(
                _sdk_types.Types.Generic.to_flyte_literal_type(), ""),
            "train":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
            "validation":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
        })
        self.add_outputs({
            "model":
            _interface_model.Variable(
                _sdk_types.Types.Blob.to_flyte_literal_type(), "")
        })
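
A hedged construction sketch for the optimizer above; the dictionary keys follow the AWS ResourceConfig/StoppingCondition schemas linked in the docstring, and all concrete values (including the role ARN) are illustrative.

# Hypothetical values throughout; retries=1 becomes RetryStrategy(retries=1).
optimizer = SagemakerXgBoostOptimizer(
    region="us-east-2",
    role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
    resource_config={"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 25},
    stopping_condition={"MaxRuntimeInSeconds": 43200},
    retries=1,
    cacheable=True,
    cache_version="1.0",
)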
Example #21
    def __init__(
        self,
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):
        """
        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Specify whether task is interruptible
        :param Text deprecated:
        :param Text storage_request:
        :param Text cpu_request:
        :param Text gpu_request:
        :param Text memory_request:
        :param Text storage_limit:
        :param Text cpu_limit:
        :param Text gpu_limit:
        :param Text memory_limit:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param dict[Text, Text] environment:
        :param dict[Text, T] custom:
        """
        # Circular dependency
        from flytekit import __version__

        self._task_function = task_function
        super(SdkRunnableTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    "python",
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated,
            ),
            # TODO: If we end up using SdkRunnableTask for the new code, make sure this is set correctly.
            _interface.TypedInterface({}, {}),
            custom,
            container=self._get_container_definition(
                storage_request=storage_request,
                cpu_request=cpu_request,
                gpu_request=gpu_request,
                memory_request=memory_request,
                storage_limit=storage_limit,
                cpu_limit=cpu_limit,
                gpu_limit=gpu_limit,
                memory_limit=memory_limit,
                environment=environment,
            ),
        )
        self.id._name = "{}.{}".format(self.task_module, self.task_function_name)
        self._has_fast_registered = False

        # TODO: Remove this in the future, I don't think we'll be using this.
        self._task_style = SdkRunnableTaskStyle.V0
Example #22
LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES]

LIST_OF_RESOURCES = [
    task.Resources(request, limit) for request, limit in product(
        LIST_OF_RESOURCE_ENTRY_LISTS, LIST_OF_RESOURCE_ENTRY_LISTS)
]

LIST_OF_RUNTIME_METADATA = [
    task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0",
                         "python"),
    task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0",
                         "golang")
]

LIST_OF_RETRY_POLICIES = [
    literals.RetryStrategy(retries=i) for i in [0, 1, 3, 100]
]

LIST_OF_INTERRUPTIBLE = [None, True, False]

LIST_OF_TASK_METADATA = [
    task.TaskMetadata(discoverable, runtime_metadata, timeout, retry_strategy,
                      interruptible, discovery_version, deprecated)
    for discoverable, runtime_metadata, timeout, retry_strategy, interruptible,
    discovery_version, deprecated in
    product([True, False], LIST_OF_RUNTIME_METADATA,
            [timedelta(days=i) for i in range(3)], LIST_OF_RETRY_POLICIES,
            LIST_OF_INTERRUPTIBLE, ["1.0"], ["deprecated"])
]

LIST_OF_TASK_TEMPLATES = [
Example #23
    def __init__(
        self,
        statement,
        output_schema,
        routing_group=None,
        catalog=None,
        schema=None,
        task_inputs=None,
        interruptible=False,
        discoverable=False,
        discovery_version=None,
        retries=1,
        timeout=None,
        deprecated=None,
    ):
        """
        :param Text statement: Presto query specification
        :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
        :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
        :param Text catalog: The catalog to set for the given Presto query
        :param Text schema: The schema to set for the given Presto query
        :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the Presto task
        :param bool discoverable:
        :param Text discovery_version: String describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param datetime.timedelta timeout:
        :param Text deprecated: This string can be used to mark the task as deprecated.  Consumers of the task will
            receive deprecation warnings.
        """

        # Set as class fields which are used down below to configure implicit
        # parameters
        self._routing_group = routing_group or ""
        self._catalog = catalog or ""
        self._schema = schema or ""

        metadata = _task_model.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        presto_query = _presto_models.PrestoQuery(
            routing_group=routing_group or "",
            catalog=catalog or "",
            schema=schema or "",
            statement=statement,
        )

        # Here we set the routing_group, catalog, and schema as implicit
        # parameters for caching purposes
        i = _interface.TypedInterface(
            {
                "__implicit_routing_group": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The routing group set as an implicit input",
                ),
                "__implicit_catalog": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The catalog set as an implicit input",
                ),
                "__implicit_schema": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The schema set as an implicit input",
                ),
            },
            {
                # Set the schema for the Presto query as an output
                "results": _interface_model.Variable(
                    type=_types.LiteralType(schema=output_schema.schema_type),
                    description="The schema for the Presto query",
                )
            },
        )

        super(SdkPrestoTask, self).__init__(
            _constants.SdkTaskType.PRESTO_TASK,
            metadata,
            i,
            _MessageToDict(presto_query.to_flyte_idl()),
        )

        # Set user provided inputs
        task_inputs(self)
Example #24
def _get_sample_node_metadata():
    return _workflow.NodeMetadata(name="node1",
                                  timeout=timedelta(seconds=10),
                                  retries=_literals.RetryStrategy(0))
Example #25
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """

        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
        :param max_parallel_training_jobs: The maximum number of training jobs that can be launched by this hyperparameter
        tuning job in parallel
        :param training_job: The reference to the training job definition
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Use the training job model as a measure of type checking
        hpo_job = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        ).to_flyte_idl()

        # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
        #   the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        inputs = {
            "hyperparameter_tuning_job_config":
            _interface_model.Variable(
                _sdk_types.Types.Proto(
                    _pb2_hpo_job.HyperparameterTuningJobConfig).
                to_flyte_literal_type(),
                "",
            ),
        }
        inputs.update(training_job.interface.inputs)

        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor="sagemaker",
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=False,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs=inputs,
                outputs={
                    "model":
                    _interface_model.Variable(
                        type=_types_models.LiteralType(
                            blob=_core_types.BlobType(
                                format="",
                                dimensionality=_core_types.BlobType.
                                BlobDimensionality.SINGLE,
                            )),
                        description="",
                    )
                },
            ),
            custom=MessageToDict(hpo_job),
        )
Example #26
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        jar_file,
        main_class,
        args,
        flink_properties,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text jar_file: fat jar file
        :param dict[Text,Text] flink_properties:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        flink_job = _task_models.FlinkJob(flink_properties=flink_properties,
                                          jar_file=jar_file,
                                          main_class=main_class,
                                          args=args).to_flyte_idl()

        super(SdkGenericFlinkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    "flink",
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated,
            ),
            _interface.TypedInterface({}, {}),
            _MessageToDict(flink_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)
Example #27
def test_workflow_closure():
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")}, {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
Example #28
def test_sdk_node_from_task():
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    n = _nodes.SdkNode(
        "n",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n.id == "n"
    assert len(n.inputs) == 1
    assert n.inputs[0].var == "a"
    assert n.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n.outputs) == 1
    assert "b" in n.outputs
    assert n.outputs["b"].node_id == "n"
    assert n.outputs["b"].var == "b"
    assert n.outputs["b"].sdk_node == n
    assert n.outputs["b"].sdk_type == _types.Types.Integer
    assert n.metadata.name == "abc"
    assert n.metadata.retries.retries == 3
    assert n.metadata.interruptible is None
    assert len(n.upstream_nodes) == 0
    assert len(n.upstream_node_ids) == 0
    assert len(n.output_aliases) == 0

    n2 = _nodes.SdkNode(
        "n2",
        [n],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), n.outputs.b),
            )
        ],
        _core_workflow_models.NodeMetadata("abc2",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n2.id == "n2"
    assert len(n2.inputs) == 1
    assert n2.inputs[0].var == "a"
    assert n2.inputs[0].binding.promise.var == "b"
    assert n2.inputs[0].binding.promise.node_id == "n"
    assert len(n2.outputs) == 1
    assert "b" in n2.outputs
    assert n2.outputs["b"].node_id == "n2"
    assert n2.outputs["b"].var == "b"
    assert n2.outputs["b"].sdk_node == n2
    assert n2.outputs["b"].sdk_type == _types.Types.Integer
    assert n2.metadata.name == "abc2"
    assert n2.metadata.retries.retries == 3
    assert "n" in n2.upstream_node_ids
    assert n in n2.upstream_nodes
    assert len(n2.upstream_nodes) == 1
    assert len(n2.upstream_node_ids) == 1
    assert len(n2.output_aliases) == 0

    # Test right shift operator and late binding
    n3 = _nodes.SdkNode(
        "n3",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc3",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )
    n2 >> n3
    n >> n2 >> n3
    n3 << n2
    n3 << n2 << n

    assert n3.id == "n3"
    assert len(n3.inputs) == 1
    assert n3.inputs[0].var == "a"
    assert n3.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n3.outputs) == 1
    assert "b" in n3.outputs
    assert n3.outputs["b"].node_id == "n3"
    assert n3.outputs["b"].var == "b"
    assert n3.outputs["b"].sdk_node == n3
    assert n3.outputs["b"].sdk_type == _types.Types.Integer
    assert n3.metadata.name == "abc3"
    assert n3.metadata.retries.retries == 3
    assert "n2" in n3.upstream_node_ids
    assert n2 in n3.upstream_nodes
    assert len(n3.upstream_nodes) == 1
    assert len(n3.upstream_node_ids) == 1
    assert len(n3.output_aliases) == 0

    # Test left shift operator and late binding
    n4 = _nodes.SdkNode(
        "n4",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(
                    _types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc4",
                                           _datetime.timedelta(minutes=15),
                                           _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    n4 << n3

    # Test that implicit dependencies don't cause direct dependencies
    n4 << n3 << n2 << n
    n >> n2 >> n3 >> n4

    assert n4.id == "n4"
    assert len(n4.inputs) == 1
    assert n4.inputs[0].var == "a"
    assert n4.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n4.outputs) == 1
    assert "b" in n4.outputs
    assert n4.outputs["b"].node_id == "n4"
    assert n4.outputs["b"].var == "b"
    assert n4.outputs["b"].sdk_node == n4
    assert n4.outputs["b"].sdk_type == _types.Types.Integer
    assert n4.metadata.name == "abc4"
    assert n4.metadata.retries.retries == 3
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 1
    assert len(n4.upstream_node_ids) == 1
    assert len(n4.output_aliases) == 0

    # Add another dependency
    n4 << n2
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert "n2" in n4.upstream_node_ids
    assert n2 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 2
    assert len(n4.upstream_node_ids) == 2
Example #29
    def retry_strategy(self) -> _literal_models.RetryStrategy:
        return _literal_models.RetryStrategy(self.retries)
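
A minimal sketch of the context this accessor assumes. Hedged: the host class here is reconstructed for illustration (any object with an integer `retries` attribute), not taken from the source.

from flytekit.models import literals as _literal_models

class _RetriesHolder:
    """Illustrative stand-in for the class that defines retry_strategy above."""

    def __init__(self, retries: int = 0):
        self.retries = retries

    @property
    def retry_strategy(self) -> _literal_models.RetryStrategy:
        # Wrap the plain int in the IDL model that NodeMetadata/TaskMetadata expect.
        return _literal_models.RetryStrategy(self.retries)

assert _RetriesHolder(3).retry_strategy.retries == 3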
Example #30
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf:
        :param dict[Text,Text] hadoop_conf:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__, 'spark'), timeout,
                _literal_models.RetryStrategy(retries), interruptible,
                discovery_version, deprecated),
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)