示例#1
0
def test_task_template__k8s_pod_target():
    """Build a TaskTemplate carrying a K8sPod target and check its accessors and IDL round trip."""
    integer_type = types.LiteralType(types.SimpleType.INTEGER)

    # Assemble each constructor argument separately so the final call reads top-down.
    task_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    input_variables = {
        "a": interface_models.Variable(integer_type, "description1"),
    }
    output_variables = {
        "b": interface_models.Variable(integer_type, "description2"),
        "c": interface_models.Variable(integer_type, "description3"),
    }
    typed_interface = interface_models.TypedInterface(input_variables,
                                                      output_variables)
    custom = {"a": 1, "b": {"c": 2, "d": 3}}
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(labels={"label": "foo"},
                                        annotations={"anno": "bar"}),
        pod_spec={"str": "val", "int": 1},
    )

    obj = task.TaskTemplate(
        task_id,
        "python",
        metadata,
        typed_interface,
        custom,
        config={"a": "b"},
        k8s_pod=pod,
    )

    # Identifier fields are exposed unchanged through ``obj.id``.
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}

    # The pod target keeps both its metadata and its raw pod spec.
    expected_pod_metadata = task.K8sObjectMetadata(labels={"label": "foo"},
                                                   annotations={"anno": "bar"})
    assert obj.k8s_pod.metadata == expected_pod_metadata
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}

    # Serialising, deserialising and serialising again must give the same text.
    serialized = text_format.MessageToString(obj.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert serialized == text_format.MessageToString(
        round_tripped.to_flyte_idl())
    assert obj.config == {"a": "b"}
def get_sample_task_metadata():
    """
    Return a fixed TaskMetadata sample built from hard-coded values.

    :rtype: flytekit.models.task.TaskMetadata
    """
    sdk_runtime = _task_model.RuntimeMetadata(
        _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        "1.0.0",
        "python",
    )
    return _task_model.TaskMetadata(
        True,
        sdk_runtime,
        timedelta(days=1),
        _literals.RetryStrategy(3),
        True,
        "0.1.1b0",
        "This is deprecated!",
    )
示例#3
0
 def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
     """
     Converts to _task_model.TaskMetadata

     The cache, timeout, retry, interruptible, cache-version and deprecation settings held on this object
     are copied onto the model; the runtime section reports the Flyte SDK with the installed flytekit version.

     :rtype: _task_model.TaskMetadata
     """
     # Imported inside the function to avoid the circular import between this module and the flytekit
     # package root, so the reported version always matches the installed SDK instead of a literal.
     from flytekit import __version__

     return _task_model.TaskMetadata(
         discoverable=self.cache,
         runtime=_task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
         timeout=self.timeout,
         retries=self.retry_strategy,
         interruptible=self.interruptable,
         discovery_version=self.cache_version,
         deprecated_error_message=self.deprecated,
     )
示例#4
0
    def __init__(self, task_function, task_type, discovery_version, retries,
                 interruptible, deprecated, storage_request, cpu_request,
                 gpu_request, memory_request, storage_limit, cpu_limit,
                 gpu_limit, memory_limit, discoverable, timeout, environment,
                 custom):
        """
        Store the user function, then initialise the base task with SDK metadata, an empty interface and a
        container definition built from the resource and environment arguments.

        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Specify whether task is interruptible
        :param Text deprecated: passed through to the task metadata
        :param Text storage_request: forwarded to the container definition
        :param Text cpu_request: forwarded to the container definition
        :param Text gpu_request: forwarded to the container definition
        :param Text memory_request: forwarded to the container definition
        :param Text storage_limit: forwarded to the container definition
        :param Text cpu_limit: forwarded to the container definition
        :param Text gpu_limit: forwarded to the container definition
        :param Text memory_limit: forwarded to the container definition
        :param bool discoverable: passed through to the task metadata
        :param datetime.timedelta timeout: passed through to the task metadata
        :param dict[Text, Text] environment: forwarded to the container definition
        :param dict[Text, T] custom: passed to the base task as its custom payload
        """
        # Must be set before anything below runs; later steps are methods on self.
        self._task_function = task_function

        sdk_runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            'python',
        )
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            sdk_runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )
        # The interface starts out empty.
        empty_interface = _interface.TypedInterface({}, {})
        container_definition = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkRunnableTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            custom,
            container=container_definition,
        )

        # Name the task "<module>.<function name>".
        qualified_name = "{}.{}".format(self.task_module,
                                        self.task_function_name)
        self.id._name = qualified_name
示例#5
0
def test_task_metadata():
    """Construct a TaskMetadata, check each accessor, and verify it survives an IDL round trip."""
    sdk_runtime = task.RuntimeMetadata(
        task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        "1.0.0",
        "python",
    )
    obj = task.TaskMetadata(
        True,
        sdk_runtime,
        timedelta(days=1),
        literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    assert obj.discoverable is True
    assert obj.retries.retries == 3
    assert obj.timeout == timedelta(days=1)
    assert obj.runtime.flavor == "python"
    assert obj.runtime.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK
    assert obj.runtime.version == "1.0.0"
    assert obj.deprecated_error_message == "This is deprecated!"
    assert obj.discovery_version == "0.1.1b0"

    # Serialise to the IDL message and back; the result must compare equal.
    restored = task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == restored
示例#6
0
    def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
        """
        Build the _task_model.TaskMetadata equivalent of this object.

        The runtime section reports the Flyte SDK, the installed flytekit version and the "python" flavor;
        every other field is copied from the matching attribute on self.
        """
        from flytekit import __version__

        runtime_type = _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK
        sdk_runtime = _task_model.RuntimeMetadata(runtime_type, __version__,
                                                  "python")

        return _task_model.TaskMetadata(
            discoverable=self.cache,
            runtime=sdk_runtime,
            timeout=self.timeout,
            retries=self.retry_strategy,
            interruptible=self.interruptable,
            discovery_version=self.cache_version,
            deprecated_error_message=self.deprecated,
        )
示例#7
0
def test_workflow_closure():
    """
    Build a single-node workflow template together with its task template, wrap both in a WorkflowClosure and
    check that the closure compares equal after a round trip through its Flyte IDL form.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    # One input ('a') and two outputs ('b', 'c'); shared by the task and the workflow.
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")}, {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # b0 feeds the constant 5 into the node's input 'a'.
    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    # b1 and b2 bind the workflow outputs 'b' and 'c' to the matching outputs of 'my_node'.
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    # Bound to 'c' (not a second 'b') so that every output declared in the interface has a binding.
    b2 = _literals.Binding(
        'c',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    # The same CPU entry is used for both the request and the limit.
    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    # The node references the task by id only; the template itself travels in the closure.
    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
示例#8
0
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        Initialise the task with a SparkJob definition as its custom payload, apply the optional inputs
        callback, and finally attach the container definition.

        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param task_inputs: callable invoked with this task once the base initialiser has run; skipped when None
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf:
        :param dict[Text,Text] hadoop_conf:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """
        # Describe the Spark job and convert it to its IDL message.
        spark_job_model = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        )
        spark_job = spark_job_model.to_flyte_idl()

        sdk_runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            'spark',
        )
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            sdk_runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            task_metadata,
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Let the caller-supplied callback add the inputs, if there is one.
        if task_inputs is not None:
            task_inputs(self)

        # The container is built only after the inputs have been updated.
        container_definition = self._get_container_definition(
            environment=environment)
        self._container = container_definition
示例#9
0
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        Wrap a training job in a hyperparameter tuning task.

        The task's inputs are the training job's inputs plus a "hyperparameter_tuning_job_config" input; its
        only output is a single-blob "model".

        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
        :param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter
        tuning job in parallel
        :param training_job: The reference to the training job definition
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Use the training job model as a measure of type checking
        hpo_job_model = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        )
        hpo_job = hpo_job_model.to_flyte_idl()

        # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
        #   the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        # The tuning config goes in first; the training job's own inputs are merged on top of it.
        tuning_config_type = _sdk_types.Types.Proto(
            _pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type()
        inputs = {
            "hyperparameter_tuning_job_config":
            _interface_model.Variable(tuning_config_type, ""),
        }
        inputs.update(training_job.interface.inputs)

        task_metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        # Single output: a "model" blob with an empty format string.
        model_blob_type = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        model_output = _interface_model.Variable(
            type=_types_models.LiteralType(blob=model_blob_type),
            description="",
        )
        task_interface = _interface.TypedInterface(
            inputs=inputs,
            outputs={"model": model_output},
        )

        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=task_metadata,
            interface=task_interface,
            custom=MessageToDict(hpo_job),
        )
示例#10
0
    def __init__(
        self,
        region,
        role_arn,
        resource_config,
        algorithm_specification=None,
        stopping_condition=None,
        vpc_config=None,
        enable_spot_training=False,
        interruptible=False,
        retries=0,
        cacheable=False,
        cache_version="",
    ):
        """
        Build the SageMaker HPO job configuration and register it as this task's custom payload, then declare
        the task's inputs and outputs.

        :param Text region: The region in which to run the SageMaker job.
        :param Text role_arn: The ARN of the role to run in the SageMaker job.
        :param dict[Text,T] algorithm_specification: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
            The mapping passed in is copied and never modified. "TrainingImage" falls back to a default XGBoost
            image when missing or empty, and "TrainingInputMode" is always set to "File".
        :param dict[Text,T] resource_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
        :param dict[Text,T] stopping_condition: https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
        :param dict[Text,T] vpc_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
        :param bool enable_spot_training: https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
            NOTE(review): accepted but not read anywhere in this constructor.
        :param bool interruptible: Passed through to the task metadata.
        :param int retries: Number of time to retry.
        :param bool cacheable: Whether or not to use Flyte's caching system.
        :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
        """

        # Work on a copy: the defaults below must not leak back into the caller's dictionary, which may be
        # shared between several task definitions.
        algorithm_specification = dict(algorithm_specification or {})
        algorithm_specification["TrainingImage"] = (
            algorithm_specification.get("TrainingImage")
            or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1")
        algorithm_specification["TrainingInputMode"] = "File"

        job_config = ParseDict(
            {
                "Region": region,
                "ResourceConfig": resource_config,
                "StoppingCondition": stopping_condition,
                "VpcConfig": vpc_config,
                "AlgorithmSpecification": algorithm_specification,
                "RoleArn": role_arn,
            },
            sagemaker_pb2.SagemakerHPOJob(),
        )
        # NOTE(review): this writes the job config to stdout on every construction; consider a logger instead.
        print(MessageToDict(job_config))

        # TODO: Optionally, pull timeout behavior from stopping condition and pass to Flyte task def.
        timeout = _datetime.timedelta(seconds=0)

        # TODO: The FlyteKit type engine is extensible so we can create a SagemakerInput type with custom
        # TODO:     parsing/casting logic. For now, we will use the Generic type since there is a little that needs
        # TODO:     to be done on Flyte side to unlock this cleanly.
        # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following a
        # TODO:     refactor.
        # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the task,
        # TODO:     but might not affect the outputs (i.e. Running on a bigger machine). These are currently static for
        # TODO:     a given definition of a task, but will be more dynamic in the future. Also, it is possible to
        # TODO:     make it dynamic by using our @dynamic_task.
        # TODO: You might want to inherit the role ARN from the execution at runtime.
        super(SagemakerXgBoostOptimizer, self).__init__(
            type=_TASK_TYPE,
            metadata=_task_models.TaskMetadata(
                discoverable=cacheable,
                runtime=_task_models.RuntimeMetadata(0, "0.1.0b0",
                                                     "sagemaker"),
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=interruptible,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface({}, {}),
            custom=MessageToDict(job_config),
        )

        # TODO: Add more inputs that we expect to change the outputs of the task.
        # TODO: We can add outputs too!
        # We use helper methods for adding to interface, thus overriding the one set above. This will be simplified post
        # refactor.
        self.add_inputs({
            "static_hyperparameters":
            _interface_model.Variable(
                _sdk_types.Types.Generic.to_flyte_literal_type(), ""),
            "train":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
            "validation":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
        })
        self.add_outputs({
            "model":
            _interface_model.Variable(
                _sdk_types.Types.Blob.to_flyte_literal_type(), "")
        })
示例#11
0
# Two runtime descriptions: one per runtime type, with differing version and flavor strings.
LIST_OF_RUNTIME_METADATA = [
    task.RuntimeMetadata(
        task.RuntimeMetadata.RuntimeType.OTHER,
        "1.0.0",
        "python",
    ),
    task.RuntimeMetadata(
        task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        "1.0.0b0",
        "golang",
    ),
]

# Retry strategies for a handful of retry counts.
LIST_OF_RETRY_POLICIES = [
    literals.RetryStrategy(retries=count) for count in (0, 1, 3, 100)
]

LIST_OF_INTERRUPTIBLE = [None, True, False]

# Every combination of the option lists below. Each tuple yielded by product() is already in
# TaskMetadata's positional order (discoverable, runtime, timeout, retries, interruptible,
# discovery version, deprecated message), so it is unpacked straight into the constructor.
LIST_OF_TASK_METADATA = [
    task.TaskMetadata(*combination)
    for combination in product(
        [True, False],
        LIST_OF_RUNTIME_METADATA,
        [timedelta(days=days) for days in range(3)],
        LIST_OF_RETRY_POLICIES,
        LIST_OF_INTERRUPTIBLE,
        ["1.0"],
        ["deprecated"],
    )
]

LIST_OF_TASK_TEMPLATES = [
    task.TaskTemplate(identifier.Identifier(identifier.ResourceType.TASK,
                                            "project", "domain", "name",
                                            "version"),
                      "python",
                      task_metadata,
                      interfaces, {
                          'a': 1,
示例#12
0
    def __init__(
        self,
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):
        """
        Store the user function, initialise the base task with SDK metadata, an empty interface and a
        container definition, then set the task name and the bookkeeping flags.

        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Specify whether task is interruptible
        :param Text deprecated: passed through to the task metadata
        :param Text storage_request: forwarded to the container definition
        :param Text cpu_request: forwarded to the container definition
        :param Text gpu_request: forwarded to the container definition
        :param Text memory_request: forwarded to the container definition
        :param Text storage_limit: forwarded to the container definition
        :param Text cpu_limit: forwarded to the container definition
        :param Text gpu_limit: forwarded to the container definition
        :param Text memory_limit: forwarded to the container definition
        :param bool discoverable: passed through to the task metadata
        :param datetime.timedelta timeout: passed through to the task metadata
        :param dict[Text, Text] environment: forwarded to the container definition
        :param dict[Text, T] custom: passed to the base task as its custom payload
        """
        # Circular dependency
        from flytekit import __version__

        self._task_function = task_function

        sdk_runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "python",
        )
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            sdk_runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )
        # TODO: If we end up using SdkRunnableTask for the new code, make sure this is set correctly.
        empty_interface = _interface.TypedInterface({}, {})
        container_definition = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkRunnableTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            custom,
            container=container_definition,
        )

        # Name the task "<module>.<function name>".
        qualified_name = "{}.{}".format(self.task_module, self.task_function_name)
        self.id._name = qualified_name
        self._has_fast_registered = False

        # TODO: Remove this in the future, I don't think we'll be using this.
        self._task_style = SdkRunnableTaskStyle.V0
示例#13
0
    def _op_to_task(self, dag_id, image, op, node_map):
        """
        Generate task given an operator inherited from dsl.ContainerOp.

        The operator is wrapped in an Airflow TaskInstance whose run command becomes the container command of
        the generated task; the operator's resources, when set, become the container's resource requests.

        :param dag_id: passed as ``pickle_id`` when building the task instance command
        :param image: container image the generated task runs in
        :param airflow.models.BaseOperator op:
        :param dict(Text, SdkNode) node_map: NOTE(review): only referenced by the commented-out input mapping
        :rtype: Tuple(base_tasks.SdkTask, SdkNode)
        """

        interface_inputs = {}
        interface_outputs = {}
        input_mappings = {}

        # for key, val in six.iteritems(op.params):
        #     interface_inputs[key] = interface_model.Variable(
        #         _type_helpers.python_std_to_sdk_type(Types.String).to_flyte_literal_type(),
        #         ''
        #     )
        #
        #     if param.op_name == '':
        #         binding = promise_common.Input(sdk_type=Types.String, name=param.name)
        #     else:
        #         binding = promise_common.NodeOutput(
        #             sdk_node=node_map[param.op_name],
        #             sdk_type=Types.String,
        #             var=param.name)
        #     input_mappings[param.name] = binding
        #
        # for param in op.outputs.values():
        #     interface_outputs[param.name] = interface_model.Variable(
        #         _type_helpers.python_std_to_sdk_type(Types.String).to_flyte_literal_type(),
        #         ''
        #     )

        # One request entry per resource kind, in CPU / memory / GPU / storage order.
        # The enum members are upper-case (CPU, MEMORY, ...); mixed-case names raise AttributeError.
        requests = []
        if op.resources:
            resource_name = task_model.Resources.ResourceName
            requests = [
                task_model.Resources.ResourceEntry(name, value)
                for name, value in (
                    (resource_name.CPU, op.resources.cpus),
                    (resource_name.MEMORY, op.resources.ram),
                    (resource_name.GPU, op.resources.gpus),
                    (resource_name.STORAGE, op.resources.disk),
                )
            ]

        task_instance = TaskInstance(op, datetime.datetime.now())
        command = task_instance.command_as_list(local=True,
                                                mark_success=False,
                                                ignore_all_deps=True,
                                                ignore_depends_on_past=True,
                                                ignore_task_deps=True,
                                                ignore_ti_state=True,
                                                pool=task_instance.pool,
                                                pickle_id=dag_id,
                                                cfg_path=None)

        task = base_tasks.SdkTask(
            op.task_id,
            SingleStepTask,
            "airflow_op",
            task_model.TaskMetadata(
                False,
                task_model.RuntimeMetadata(
                    # The runtime is Airflow rather than the Flyte SDK, hence OTHER.
                    type=task_model.RuntimeMetadata.RuntimeType.OTHER,
                    version=airflow.version.version,
                    flavor='airflow'),
                datetime.timedelta(seconds=0),
                literals_model.RetryStrategy(0),
                '1',
                None,
            ),
            interface_common.TypedInterface(inputs=interface_inputs,
                                            outputs=interface_outputs),
            custom=None,
            container=task_model.Container(
                image=image,
                command=command,
                args=[],
                resources=task_model.Resources(limits=[], requests=requests),
                env={},
                config={},
            ))

        # Calling the task with its (currently empty) input mappings produces the node for this operator.
        return task, task(**input_mappings).assign_id_and_return(op.task_id)
示例#14
0
    def __init__(
        self,
        statement,
        output_schema,
        routing_group=None,
        catalog=None,
        schema=None,
        task_inputs=None,
        interruptible=False,
        discoverable=False,
        discovery_version=None,
        retries=1,
        timeout=None,
        deprecated=None,
    ):
        """
        Define a Presto query task whose routing group, catalog and schema are exposed as implicit string
        inputs and whose single output, "results", carries the given output schema.

        :param Text statement: Presto query specification
        :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
        :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
        :param Text catalog: The catalog to set for the given Presto query
        :param Text schema: The schema to set for the given Presto query
        :param task_inputs: Optional callable invoked with this task after the base initialiser has run, used to
            add the user-provided inputs. Skipped when None.
        :param bool interruptible: Passed through to the task metadata.
        :param bool discoverable:
        :param Text discovery_version: String describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param datetime.timedelta timeout: Defaults to a zero-second timedelta when not given.
        :param Text deprecated: This string can be used to mark the task as deprecated.  Consumers of the task will
            receive deprecation warnings.
        """

        # Set as class fields which are used down below to configure implicit
        # parameters
        self._routing_group = routing_group or ""
        self._catalog = catalog or ""
        self._schema = schema or ""

        metadata = _task_model.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        presto_query = _presto_models.PrestoQuery(
            routing_group=routing_group or "",
            catalog=catalog or "",
            schema=schema or "",
            statement=statement,
        )

        # Here we set the routing_group, catalog, and schema as implicit
        # parameters for caching purposes
        i = _interface.TypedInterface(
            {
                "__implicit_routing_group": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The routing group set as an implicit input",
                ),
                "__implicit_catalog": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The catalog set as an implicit input",
                ),
                "__implicit_schema": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The schema set as an implicit input",
                ),
            },
            {
                # Set the schema for the Presto query as an output
                "results": _interface_model.Variable(
                    type=_types.LiteralType(schema=output_schema.schema_type),
                    description="The schema for the Presto query",
                )
            },
        )

        super(SdkPrestoTask, self).__init__(
            _constants.SdkTaskType.PRESTO_TASK,
            metadata,
            i,
            _MessageToDict(presto_query.to_flyte_idl()),
        )

        # Set user provided inputs. task_inputs defaults to None, so calling it unconditionally would raise
        # TypeError for every task defined without extra inputs.
        if task_inputs is not None:
            task_inputs(self)
示例#15
0
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
        tunable_parameters: typing.List[str] = None,
    ):
        """
        Define a SageMaker hyperparameter tuning job task that wraps an existing training job task.

        The resulting interface consists of the wrapped training job's inputs, an additional
        ``hyperparameter_tuning_job_config`` input, and one ``ParameterRange``-typed input per tunable parameter.
        The only output is ``model``.

        :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job
        :param int max_parallel_training_jobs: The maximum number of training jobs that can be launched in parallel
            by this hyperparameter tuning job
        :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to
            the training job definition
        :param int retries: Number of retries to attempt
        :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param str cache_version: String describing the caching version for task discovery purposes
        :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
            algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
            parameters. E.g. refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the
            list of hyperparameters for the Image Classification built-in algorithm. If you are passing a custom
            training job, the list of tunable parameters must be a strict subset of the list of inputs defined on
            that job. Refer to
            https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
            for the list of supported hyperparameter types.
        """
        # Building the model object doubles as a type check on the arguments.
        tuning_job_idl = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        ).to_flyte_idl()

        # The flyte-level timeout is 0 so that SageMaker honours the StoppingCondition of the
        #   underlying training job instead
        # TODO: Discuss whether this is a viable interface or contract
        no_timeout = _datetime.timedelta(seconds=0)

        # Start from a copy of the training job's inputs, then layer the tuning-specific ones on top.
        task_inputs = dict(training_job.interface.inputs)
        task_inputs["hyperparameter_tuning_job_config"] = _interface_model.Variable(
            HyperparameterTuningJobConfig.to_flyte_literal_type(),
            "",
        )

        # A tunable parameter that shares its name with an existing input replaces that input.
        for parameter_name in tunable_parameters or ():
            task_inputs[parameter_name] = _interface_model.Variable(
                ParameterRange.to_flyte_literal_type(), "")

        task_metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=no_timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        model_output = _interface_model.Variable(
            type=_types_models.LiteralType(
                blob=_core_types.BlobType(
                    format="",
                    dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                ),
            ),
            description="",
        )

        super().__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=task_metadata,
            interface=_interface.TypedInterface(
                inputs=task_inputs,
                outputs={"model": model_output},
            ),
            custom=MessageToDict(tuning_job_idl),
        )
    def __init__(
        self,
        training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        Define a SageMaker training job task that runs a built-in algorithm.

        The interface exposes three inputs (``static_hyperparameters``, ``train`` and ``validation``) and one
        output (``model``). The blob format of ``train`` and ``validation`` is derived from
        ``algorithm_specification.input_content_type``.

        :param training_job_resource_config: The options to configure the training job
        :param algorithm_specification: The options to configure the target algorithm of the training
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Building the model object doubles as a type check on the arguments.
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        # The flyte-level timeout is 0 so that SageMaker applies the StoppingCondition and terminates the
        # training job gracefully
        no_timeout = _datetime.timedelta(seconds=0)

        def _multipart_data_input():
            # Each call builds a fresh Variable so "train" and "validation" never share objects.
            return _interface_model.Variable(
                type=_idl_types.LiteralType(
                    blob=_core_types.BlobType(
                        format=_content_type_to_blob_format(
                            algorithm_specification.input_content_type),
                        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                    ),
                ),
                description="",
            )

        task_metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=no_timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        task_inputs = {
            "static_hyperparameters": _interface_model.Variable(
                type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                description="",
            ),
            "train": _multipart_data_input(),
            "validation": _multipart_data_input(),
        }

        task_outputs = {
            "model": _interface_model.Variable(
                type=_idl_types.LiteralType(
                    blob=_core_types.BlobType(
                        format="",
                        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                    ),
                ),
                description="",
            ),
        }

        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=task_metadata,
            interface=_interface.TypedInterface(
                inputs=task_inputs,
                outputs=task_outputs,
            ),
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )
示例#17
0
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        jar_file,
        main_class,
        args,
        flink_properties,
        environment,
    ):
        """
        Define a generic Flink task from a jar file, a main class and Flink properties.

        The task starts with an empty typed interface; inputs are attached afterwards through ``task_inputs``,
        and the container definition is created last.

        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param task_inputs: [optional] callable invoked with this task once the base task is initialised, used to
            add inputs. Skipped when None.
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text jar_file: fat jar file
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param args: arguments forwarded to the Flink job model
        :param dict[Text,Text] flink_properties:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        job_idl = _task_models.FlinkJob(
            flink_properties=flink_properties,
            jar_file=jar_file,
            main_class=main_class,
            args=args,
        ).to_flyte_idl()

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "flink",
        )
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )
        empty_interface = _interface.TypedInterface({}, {})

        super(SdkGenericFlinkTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            _MessageToDict(job_idl),
        )

        # Attach the user-declared inputs, if any were supplied.
        if task_inputs is not None:
            task_inputs(self)

        # The container is built only after the inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)
示例#18
0
    def __init__(
        self,
        inputs: Dict[str, FlyteSdkType],
        image: str,
        outputs: Dict[str, FlyteSdkType] = None,
        input_data_dir: str = None,
        output_data_dir: str = None,
        metadata_format: int = METADATA_FORMAT_JSON,
        io_strategy: _task_models.IOStrategy = None,
        command: List[str] = None,
        args: List[str] = None,
        storage_request: str = None,
        cpu_request: str = None,
        gpu_request: str = None,
        memory_request: str = None,
        storage_limit: str = None,
        cpu_limit: str = None,
        gpu_limit: str = None,
        memory_limit: str = None,
        environment: Dict[str, str] = None,
        interruptible: bool = False,
        discoverable: bool = False,
        discovery_version: str = None,
        retries: int = 1,
        timeout: _datetime.timedelta = None,
    ):
        """
        Define a task that runs an arbitrary container image with data loading enabled.

        :param inputs: Mapping of input name to type, converted into the interface inputs
        :param outputs: Mapping of output name to type, converted into the interface outputs
        :param image: Image passed to the container definition
        :param command: Command passed to the container definition
        :param args: Arguments passed to the container definition
        :param storage_request: Passed to the container definition
        :param cpu_request: Passed to the container definition
        :param gpu_request: Passed to the container definition
        :param memory_request: Passed to the container definition
        :param storage_limit: Passed to the container definition
        :param cpu_limit: Passed to the container definition
        :param gpu_limit: Passed to the container definition
        :param memory_limit: Passed to the container definition
        :param environment: Passed to the container definition
        :param interruptible: Recorded in the task metadata
        :param discoverable: Recorded in the task metadata
        :param discovery_version: Recorded in the task metadata
        :param retries: Number of retries used for the retry strategy in the task metadata
        :param timeout: Recorded in the task metadata; a falsy value is replaced by a zero timedelta
        :param input_data_dir: This is the directory where data will be downloaded to
        :param output_data_dir: This is the directory where data will be uploaded from
        :param metadata_format: Format in which the metadata will be available for the script
        :param io_strategy: Passed to the data loading config
        """

        # Kept on the instance and also passed to the container definition built below.
        self._data_loading_config = _task_models.DataLoadingConfig(
            input_path=input_data_dir,
            output_path=output_data_dir,
            format=metadata_format,
            enabled=True,
            io_strategy=io_strategy,
        )

        # This needs to have the proper version reflected in it
        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            flytekit.__version__,
            "python",
        )
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            None,
        )

        # The interface comes straight from the declared inputs and outputs.
        typed_interface = _interface.TypedInterface(
            inputs=types_to_variable(inputs),
            outputs=types_to_variable(outputs),
        )

        resource_settings = {
            "storage_request": storage_request,
            "cpu_request": cpu_request,
            "gpu_request": gpu_request,
            "memory_request": memory_request,
            "storage_limit": storage_limit,
            "cpu_limit": cpu_limit,
            "gpu_limit": gpu_limit,
            "memory_limit": memory_limit,
        }
        container = _get_container_definition(
            image=image,
            args=args,
            command=command,
            data_loading_config=self._data_loading_config,
            **resource_settings,
            environment=environment,
        )

        # Initialise the base SDK task with the container attached.
        super(SdkRawContainerTask, self).__init__(
            _constants.SdkTaskType.RAW_CONTAINER_TASK,
            task_metadata,
            typed_interface,
            None,
            container=container,
        )
示例#19
0
    def __init__(
        self,
        notebook_path,
        inputs,
        outputs,
        task_type,
        discovery_version,
        retries,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):
        """
        Define a task backed by a notebook file.

        The task starts with an empty typed interface. Inputs and outputs are attached afterwards through the
        ``inputs`` and ``outputs`` callables, and an ``output_notebook`` blob output is always added last.

        :param notebook_path: Path to the notebook. A relative path is resolved against the directory of the
            module returned by ``_find_instance_module()``.
        :param inputs: [optional] callable invoked with this task to add inputs. Skipped when None.
        :param outputs: [optional] callable invoked with this task to add outputs. Skipped when None.
        :param task_type: Task type passed to the base task
        :param discovery_version: Recorded in the task metadata
        :param retries: Number of retries used for the retry strategy in the task metadata
        :param deprecated: Recorded in the task metadata
        :param storage_request: Passed to the container definition
        :param cpu_request: Passed to the container definition
        :param gpu_request: Passed to the container definition
        :param memory_request: Passed to the container definition
        :param storage_limit: Passed to the container definition
        :param cpu_limit: Passed to the container definition
        :param gpu_limit: Passed to the container definition
        :param memory_limit: Passed to the container definition
        :param discoverable: Recorded in the task metadata
        :param timeout: Recorded in the task metadata
        :param environment: Passed to the container definition
        :param custom: Custom payload passed to the base task
        """

        if not _os.path.isabs(notebook_path):
            # Anchor the relative notebook path at the directory of the instantiating module.
            owning_module = _importlib.import_module(_find_instance_module())
            module_dir = _os.path.dirname(owning_module.__file__)
            notebook_path = _os.path.normpath(
                _os.path.join(module_dir, notebook_path))

        # Stored before the container definition is built below.
        self._notebook_path = notebook_path

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "notebook",
        )
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            False,
            discovery_version,
            deprecated,
        )
        empty_interface = _interface2.TypedInterface({}, {})
        container = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkNotebookTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            custom,
            container=container,
        )

        # Attach user-declared inputs first, then outputs; either may be omitted.
        for attach in (inputs, outputs):
            if attach is not None:
                attach(self)

        # Add a Notebook output as a Blob.
        notebook_output = _interface.Variable(
            _Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)
        self.interface.outputs.update(output_notebook=notebook_output)