예제 #1
0
    def __init__(self, inputs, outputs, nodes):
        """
        Build a workflow from its declared inputs, outputs and nodes.

        Every upstream node referenced by ``nodes`` must already have an id.  The parent class is then
        initialized with a newly generated WORKFLOW identifier (project, domain and version are read from
        the internal configuration, combined with a random uuid4 hex string), default workflow metadata,
        a typed interface derived from ``inputs``/``outputs`` and one binding per output.

        :param list[flytekit.common.promise.Input] inputs: Workflow inputs; each contributes ``name -> var``
            to the interface.
        :param list[Output] outputs: Workflow outputs; each contributes ``name -> var`` to the interface and
            a binding built from ``name`` and ``binding_data``.
        :param list[flytekit.common.nodes.SdkNode] nodes: Nodes that make up the workflow.
        :raises _user_exceptions.FlyteAssertion: If any upstream node of any node has ``id`` set to None.
        """
        # An upstream node without an id was never picked up as part of the workflow description.
        has_unassigned_upstream = any(
            upstream.id is None
            for node in nodes
            for upstream in node.upstream_nodes
        )
        if has_unassigned_upstream:
            raise _user_exceptions.FlyteAssertion(
                "Some nodes contained in the workflow were not found in the workflow description.  Please "
                "ensure all nodes are either assigned to attributes within the class or an element in a "
                "list, dict, or tuple which is stored as an attribute in the class."
            )

        workflow_id = _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(),
            _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex,
            _internal_config.VERSION.get(),
        )
        workflow_metadata = _workflow_models.WorkflowMetadata()

        input_variables = {declared.name: declared.var for declared in inputs}
        output_variables = {declared.name: declared.var for declared in outputs}
        typed_interface = _interface.TypedInterface(input_variables, output_variables)

        output_bindings = [
            _literal_models.Binding(declared.name, declared.binding_data)
            for declared in outputs
        ]

        super(SdkWorkflow, self).__init__(
            id=workflow_id,
            metadata=workflow_metadata,
            interface=typed_interface,
            nodes=nodes,
            outputs=output_bindings,
        )
        self._user_inputs = inputs
        self._upstream_entities = {node.executable_sdk_object for node in nodes}
예제 #2
0
    def __init__(self,
                 inputs,
                 outputs,
                 nodes,
                 id=None,
                 metadata=None,
                 metadata_defaults=None,
                 interface=None,
                 output_bindings=None):
        """
        Build a workflow, deriving every parent-constructor argument that is not explicitly supplied.

        :param list[flytekit.common.promise.Input] inputs:
        :param list[Output] outputs:
        :param list[flytekit.common.nodes.SdkNode] nodes:
        :param flytekit.models.core.identifier.Identifier id: This is an autogenerated id by the system. The id is
            globally unique across Flyte.  When None, a WORKFLOW identifier is generated from the internal
            project/domain/version configuration and a random uuid4 hex string.
        :param WorkflowMetadata metadata: This contains information on how to run the workflow.  When None, a
            default ``WorkflowMetadata()`` is used.
        :param WorkflowMetadataDefaults metadata_defaults: Passed to the parent constructor as
            ``metadata_defaults``.  When None, a default ``WorkflowMetadataDefaults()`` is used.
        :param flytekit.models.interface.TypedInterface interface: Defines a strongly typed interface for the
            Workflow (inputs, outputs).  This can include some optional parameters.  When None, it is derived
            from ``inputs`` and ``outputs``.
        :param list[flytekit.models.literals.Binding] output_bindings: A list of output bindings that specify how to construct
            workflow outputs. Bindings can pull node outputs or specify literals. All workflow outputs specified in
            the interface field must be bound
            in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes
            to execute successfully in order to bind final outputs.  When None, one binding per entry of
            ``outputs`` is created.
        :raises _user_exceptions.FlyteAssertion: If any upstream node of any node has ``id`` set to None.
        """
        for n in nodes:
            for upstream in n.upstream_nodes:
                if upstream.id is None:
                    raise _user_exceptions.FlyteAssertion(
                        "Some nodes contained in the workflow were not found in the workflow description.  Please "
                        "ensure all nodes are either assigned to attributes within the class or an element in a "
                        "list, dict, or tuple which is stored as an attribute in the class."
                    )

        # Allow overrides if specified for all the arguments to the parent class constructor
        id = id if id is not None else _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(), _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex, _internal_config.VERSION.get())
        metadata = metadata if metadata is not None else _workflow_models.WorkflowMetadata()

        # Previously a caller-supplied value was silently dropped in favour of a fresh default object.
        metadata_defaults = metadata_defaults if metadata_defaults is not None else \
            _workflow_models.WorkflowMetadataDefaults()

        interface = interface if interface is not None else _interface.TypedInterface(
            {v.name: v.var for v in inputs},
            {v.name: v.var for v in outputs})

        output_bindings = output_bindings if output_bindings is not None else \
            [_literal_models.Binding(v.name, v.binding_data) for v in outputs]

        super(SdkWorkflow, self).__init__(
            id=id,
            metadata=metadata,
            metadata_defaults=metadata_defaults,
            interface=interface,
            nodes=nodes,
            outputs=output_bindings,
        )
        self._user_inputs = inputs
        self._upstream_entities = set(n.executable_sdk_object for n in nodes)
예제 #3
0
    def construct_from_class_definition(
        cls,
        inputs: List[_promise.Input],
        outputs: List[Output],
        nodes: List[_nodes.SdkNode],
        metadata: _workflow_models.WorkflowMetadata = None,
        metadata_defaults: _workflow_models.WorkflowMetadataDefaults = None,
        disable_default_launch_plan: bool = False,
    ) -> "SdkRunnableWorkflow":
        """
        Alternate constructor kept for backwards-compatibility with class-defined Workflows.

        Validates the nodes, generates a new WORKFLOW identifier, derives the typed interface and the output
        bindings from ``inputs``/``outputs``, and forwards everything to ``cls(...)``.

        :param list[flytekit.common.promise.Input] inputs: Each entry contributes ``name -> var`` to the
            interface inputs.
        :param list[Output] outputs: Each entry contributes ``name -> var`` to the interface outputs and one
            binding built from ``name`` and ``binding_data``.
        :param list[flytekit.common.nodes.SdkNode] nodes:
        :param WorkflowMetadata metadata: This contains information on how to run the workflow.
        :param flytekit.models.core.workflow.WorkflowMetadataDefaults metadata_defaults: Defaults to be passed
            to nodes contained within workflow.
        :param bool disable_default_launch_plan: Determines whether to create a default launch plan for the
            workflow or not.
        :raises _user_exceptions.FlyteAssertion: If any upstream node of any node has ``id`` set to None.
        :rtype: SdkRunnableWorkflow
        """
        # An upstream node without an id was never picked up as part of the workflow description.
        has_unassigned_upstream = any(
            upstream.id is None
            for node in nodes
            for upstream in node.upstream_nodes
        )
        if has_unassigned_upstream:
            raise _user_exceptions.FlyteAssertion(
                "Some nodes contained in the workflow were not found in the workflow description.  Please "
                "ensure all nodes are either assigned to attributes within the class or an element in a "
                "list, dict, or tuple which is stored as an attribute in the class."
            )

        workflow_id = _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(),
            _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex,
            _internal_config.VERSION.get(),
        )

        input_variables = {declared.name: declared.var for declared in inputs}
        output_variables = {declared.name: declared.var for declared in outputs}
        typed_interface = _interface.TypedInterface(input_variables, output_variables)

        bindings = [
            _literal_models.Binding(declared.name, declared.binding_data)
            for declared in outputs
        ]

        return cls(
            inputs=inputs,
            nodes=nodes,
            interface=typed_interface,
            output_bindings=bindings,
            id=workflow_id,
            metadata=metadata,
            metadata_defaults=metadata_defaults,
            disable_default_launch_plan=disable_default_launch_plan,
        )
예제 #4
0
    def __init__(self, task_function, task_type, discovery_version, retries,
                 interruptible, deprecated, storage_request, cpu_request,
                 gpu_request, memory_request, storage_limit, cpu_limit,
                 gpu_limit, memory_limit, discoverable, timeout, environment,
                 custom):
        """
        Store the user function and initialize the parent task with its metadata, an empty typed interface,
        the custom payload and a container definition.  Afterwards the task id's name is set to
        ``"<task_module>.<task_function_name>"``.

        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Specify whether task is interruptible
        :param Text deprecated:
        :param Text storage_request:
        :param Text cpu_request:
        :param Text gpu_request:
        :param Text memory_request:
        :param Text storage_limit:
        :param Text cpu_limit:
        :param Text gpu_limit:
        :param Text memory_limit:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param dict[Text, Text] environment:
        :param dict[Text, T] custom:
        """
        # Must be set before the container definition is requested below.
        self._task_function = task_function

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            'python',
        )
        retry_strategy = _literal_models.RetryStrategy(retries)
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            retry_strategy,
            interruptible,
            discovery_version,
            deprecated,
        )
        empty_interface = _interface.TypedInterface({}, {})
        container_definition = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkRunnableTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            custom,
            container=container_definition,
        )

        qualified_name = "{}.{}".format(self.task_module, self.task_function_name)
        self.id._name = qualified_name
예제 #5
0
    def __init__(
        self,
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):
        """
        Store the user function and initialize the parent task with its metadata, an empty typed interface,
        the custom payload and a container definition.  Afterwards the task id's name is set to
        ``"<task_module>.<task_function_name>"``, the fast-registration flag is cleared and the task style
        is set to ``SdkRunnableTaskStyle.V0``.

        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Specify whether task is interruptible
        :param Text deprecated:
        :param Text storage_request:
        :param Text cpu_request:
        :param Text gpu_request:
        :param Text memory_request:
        :param Text storage_limit:
        :param Text cpu_limit:
        :param Text gpu_limit:
        :param Text memory_limit:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param dict[Text, Text] environment:
        :param dict[Text, T] custom:
        """
        # Imported here rather than at module level because of a circular dependency.
        from flytekit import __version__

        self._task_function = task_function

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "python",
        )
        retry_strategy = _literal_models.RetryStrategy(retries)
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            retry_strategy,
            interruptible,
            discovery_version,
            deprecated,
        )
        # TODO: If we end up using SdkRunnableTask for the new code, make sure this is set correctly.
        empty_interface = _interface.TypedInterface({}, {})
        container_definition = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkRunnableTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            custom,
            container=container_definition,
        )

        qualified_name = "{}.{}".format(self.task_module, self.task_function_name)
        self.id._name = qualified_name
        self._has_fast_registered = False

        # TODO: Remove this in the future, I don't think we'll be using this.
        self._task_style = SdkRunnableTaskStyle.V0
예제 #6
0
    def __init__(
        self,
        statement,
        output_schema,
        routing_group=None,
        catalog=None,
        schema=None,
        task_inputs=None,
        interruptible=False,
        discoverable=False,
        discovery_version=None,
        retries=1,
        timeout=None,
        deprecated=None,
    ):
        """
        Create a Presto task whose custom payload is the serialized Presto query and whose interface carries
        the routing group, catalog and schema as implicit string inputs plus a single ``results`` output.

        :param Text statement: Presto query specification
        :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
        :param Text routing_group: The routing group that a Presto query should be sent to for the given environment.
            A falsy value is replaced by the empty string.
        :param Text catalog: The catalog to set for the given Presto query.  A falsy value is replaced by the
            empty string.
        :param Text schema: The schema to set for the given Presto query.  A falsy value is replaced by the
            empty string.
        :param task_inputs: Optional inputs to the Presto task.  When not None it is called with this task
            instance after the parent class has been initialized; when None no user inputs are added.
        :param bool interruptible: Forwarded to the task metadata.
        :param bool discoverable:
        :param Text discovery_version: String describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param datetime.timedelta timeout: A falsy value is replaced by a zero-second timedelta.
        :param Text deprecated: This string can be used to mark the task as deprecated.  Consumers of the task will
            receive deprecation warnings.
        """

        # Set as class fields which are used down below to configure implicit
        # parameters
        self._routing_group = routing_group or ""
        self._catalog = catalog or ""
        self._schema = schema or ""

        metadata = _task_model.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        presto_query = _presto_models.PrestoQuery(
            routing_group=routing_group or "",
            catalog=catalog or "",
            schema=schema or "",
            statement=statement,
        )

        # Here we set the routing_group, catalog, and schema as implicit
        # parameters for caching purposes
        i = _interface.TypedInterface(
            {
                "__implicit_routing_group": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The routing group set as an implicit input",
                ),
                "__implicit_catalog": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The catalog set as an implicit input",
                ),
                "__implicit_schema": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The schema set as an implicit input",
                ),
            },
            {
                # Set the schema for the Presto query as an output
                "results": _interface_model.Variable(
                    type=_types.LiteralType(schema=output_schema.schema_type),
                    description="The schema for the Presto query",
                )
            },
        )

        super(SdkPrestoTask, self).__init__(
            _constants.SdkTaskType.PRESTO_TASK,
            metadata,
            i,
            _MessageToDict(presto_query.to_flyte_idl()),
        )

        # Set user provided inputs.  ``task_inputs`` defaults to None, so it must be checked before being
        # called; an unconditional call raised TypeError for tasks declared without extra inputs.
        if task_inputs is not None:
            task_inputs(self)
예제 #7
0
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
        tunable_parameters: typing.List[str] = None,
    ):
        """
        Create a hyperparameter tuning task that wraps a training job.

        The task's inputs are the training job's inputs plus a ``hyperparameter_tuning_job_config`` input and
        one input per tunable parameter; its only output is ``model``.

        :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job
        :param int max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job in parallel
        :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The
            reference to the training job definition
        :param int retries: Number of retries to attempt
        :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param str cache_version: String describing the caching version for task discovery purposes
        :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
            algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
            parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for
            the list of hyperparameters for Image Classification built-in algorithm. If you are passing a custom
            training job, the list of tunable parameters must be a strict subset of the list of inputs defined
            on that job. Refer to
            https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
            for the list of supported hyperparameter types.
        """
        # Use the training job model as a measure of type checking
        tuning_job_model = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        )
        hpo_job = tuning_job_model.to_flyte_idl()

        # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
        #   the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        # Start from the training job's own inputs, then layer the tuning-specific ones on top.
        inputs = {}
        inputs.update(training_job.interface.inputs)
        inputs["hyperparameter_tuning_job_config"] = _interface_model.Variable(
            HyperparameterTuningJobConfig.to_flyte_literal_type(),
            "",
        )

        if tunable_parameters:
            for param in tunable_parameters:
                inputs[param] = _interface_model.Variable(
                    ParameterRange.to_flyte_literal_type(),
                    "",
                )

        runtime = _task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor="sagemaker",
        )
        task_metadata = _task_models.TaskMetadata(
            runtime=runtime,
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        model_blob = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        model_output = _interface_model.Variable(
            type=_types_models.LiteralType(blob=model_blob),
            description="",
        )
        typed_interface = _interface.TypedInterface(
            inputs=inputs,
            outputs={"model": model_output},
        )

        super().__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=task_metadata,
            interface=typed_interface,
            custom=MessageToDict(hpo_job),
        )
예제 #8
0
    def __init__(
        self,
        sdk_workflow,
        default_inputs=None,
        fixed_inputs=None,
        role=None,
        schedule=None,
        notifications=None,
        labels=None,
        annotations=None,
        auth=None,
    ):
        """
        Create a launch plan for ``sdk_workflow``.

        The resulting interface exposes the default inputs as inputs and the workflow's outputs as outputs.

        :param flytekit.common.workflow.SdkWorkflow sdk_workflow:
        :param dict[Text,flytekit.common.promise.Input] default_inputs: A falsy value is treated as an empty dict.
        :param dict[Text,Any] fixed_inputs: These inputs will be fixed and not need to be set when executing this
            launch plan.  A falsy value is treated as an empty dict.
        :param Text role: Deprecated. IAM role to execute this launch plan with.  When given, it is converted
            into an ``Auth`` with ``assumable_iam_role`` set.
        :param flytekit.models.schedule.Schedule schedule: Schedule to apply to this workflow.
        :param list[flytekit.models.common.Notification] notifications: List of notifications to apply to this
            launch plan.
        :param flytekit.models.common.Labels labels: Any custom kubernetes labels to apply to workflows executed by this
            launch plan.
        :param flytekit.models.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows
            executed by this launch plan.
        :param flytekit.models.launch_plan.Auth auth: The auth method with which to execute the workflow.
        :raises ValueError: If both ``role`` and ``auth`` are set.
        """
        if role and auth:
            raise ValueError(
                "Cannot set both role and auth. Role is deprecated, use auth instead."
            )

        if not fixed_inputs:
            fixed_inputs = {}
        if not default_inputs:
            default_inputs = {}

        # The deprecated ``role`` argument is translated into the equivalent auth object.
        if role:
            auth = _launch_plan_models.Auth(assumable_iam_role=role)

        workflow_identifier = _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(),
            _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex,
            _internal_config.VERSION.get(),
        )
        plan_metadata = _launch_plan_models.LaunchPlanMetadata(
            schedule=schedule or _schedule_model.Schedule(""),
            notifications=notifications or [],
        )
        default_parameter_map = _interface_models.ParameterMap(default_inputs)

        # Only the workflow inputs that are actually being fixed need an SDK type for packing.
        fixed_input_types = {}
        for input_name, variable in _six.iteritems(sdk_workflow.interface.inputs):
            if input_name in fixed_inputs:
                fixed_input_types[input_name] = _type_helpers.get_sdk_type_from_literal_type(variable.type)
        fixed_literal_map = _type_helpers.pack_python_std_map_to_literal_map(
            fixed_inputs,
            fixed_input_types,
        )

        super(SdkRunnableLaunchPlan, self).__init__(
            workflow_identifier,
            plan_metadata,
            default_parameter_map,
            fixed_literal_map,
            labels or _common_models.Labels({}),
            annotations or _common_models.Annotations({}),
            auth,
        )

        default_variables = {
            input_name: promise_input.var
            for input_name, promise_input in _six.iteritems(default_inputs)
        }
        self._interface = _interface.TypedInterface(
            default_variables,
            sdk_workflow.interface.outputs,
        )
        self._upstream_entities = {sdk_workflow}
        self._sdk_workflow = sdk_workflow
    def __init__(
        self,
        training_job_resource_config: _training_job_models.
        TrainingJobResourceConfig,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        Create a SageMaker training job task for a built-in algorithm.

        The task exposes three inputs (``static_hyperparameters``, ``train``, ``validation``) and one output
        (``model``).  The blob format of ``train`` and ``validation`` is derived from
        ``algorithm_specification.input_content_type``.

        :param training_job_resource_config: The options to configure the training job
        :param algorithm_specification: The options to configure the target algorithm of the training
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """

        def _multipart_input_variable():
            # ``train`` and ``validation`` share the same shape: a multipart blob in the algorithm's format.
            blob_type = _core_types.BlobType(
                format=_content_type_to_blob_format(
                    algorithm_specification.input_content_type),
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
            return _interface_model.Variable(
                type=_idl_types.LiteralType(blob=blob_type),
                description="",
            )

        # Use the training job model as a measure of type checking
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        # Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training
        # job gracefully
        timeout = _datetime.timedelta(seconds=0)

        runtime = _task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor="sagemaker",
        )
        task_metadata = _task_models.TaskMetadata(
            runtime=runtime,
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        hyperparameters_input = _interface_model.Variable(
            type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
            description="",
        )
        train_input = _multipart_input_variable()
        validation_input = _multipart_input_variable()

        model_blob = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        model_output = _interface_model.Variable(
            type=_idl_types.LiteralType(blob=model_blob),
            description="",
        )

        typed_interface = _interface.TypedInterface(
            inputs={
                "static_hyperparameters": hyperparameters_input,
                "train": train_input,
                "validation": validation_input,
            },
            outputs={"model": model_output},
        )

        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=task_metadata,
            interface=typed_interface,
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )
예제 #10
0
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        jar_file,
        main_class,
        args,
        flink_properties,
        environment,
    ):
        """
        Create a generic Flink task whose custom payload is the serialized Flink job definition.

        The parent is initialized with an empty typed interface; user inputs are applied afterwards and the
        container definition is computed last.

        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param task_inputs: When not None, it is called with this task instance to add the inputs.
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text jar_file: fat jar file
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param args: Forwarded to the Flink job definition as ``args``.
        :param dict[Text,Text] flink_properties:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        flink_job_model = _task_models.FlinkJob(
            flink_properties=flink_properties,
            jar_file=jar_file,
            main_class=main_class,
            args=args,
        )
        flink_job = flink_job_model.to_flyte_idl()

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "flink",
        )
        retry_strategy = _literal_models.RetryStrategy(retries)
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            retry_strategy,
            interruptible,
            discovery_version,
            deprecated,
        )
        empty_interface = _interface.TypedInterface({}, {})

        super(SdkGenericFlinkTask, self).__init__(
            task_type,
            task_metadata,
            empty_interface,
            _MessageToDict(flink_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        container_definition = self._get_container_definition(environment=environment)
        self._container = container_definition
예제 #11
0
    def __init__(
        self,
        inputs: Dict[str, FlyteSdkType],
        image: str,
        outputs: Dict[str, FlyteSdkType] = None,
        input_data_dir: str = None,
        output_data_dir: str = None,
        metadata_format: int = METADATA_FORMAT_JSON,
        io_strategy: _task_models.IOStrategy = None,
        command: List[str] = None,
        args: List[str] = None,
        storage_request: str = None,
        cpu_request: str = None,
        gpu_request: str = None,
        memory_request: str = None,
        storage_limit: str = None,
        cpu_limit: str = None,
        gpu_limit: str = None,
        memory_limit: str = None,
        environment: Dict[str, str] = None,
        interruptible: bool = False,
        discoverable: bool = False,
        discovery_version: str = None,
        retries: int = 1,
        timeout: _datetime.timedelta = None,
    ):
        """
        Create a raw container task: a task that runs ``image`` with data loading enabled.

        :param inputs: Converted with ``types_to_variable`` to form the interface inputs.
        :param outputs: Converted with ``types_to_variable`` to form the interface outputs.
        :param image:
        :param command:
        :param args:
        :param storage_request:
        :param cpu_request:
        :param gpu_request:
        :param memory_request:
        :param storage_limit:
        :param cpu_limit:
        :param gpu_limit:
        :param memory_limit:
        :param environment:
        :param interruptible:
        :param discoverable:
        :param discovery_version:
        :param retries:
        :param timeout: A falsy value is replaced by a zero-second timedelta.
        :param input_data_dir: This is the directory where data will be downloaded to
        :param output_data_dir: This is the directory where data will be uploaded from
        :param metadata_format: Format in which the metadata will be available for the script
        :param io_strategy: Forwarded to the data loading config as ``io_strategy``.
        """

        # Kept on the instance and also handed to the container definition below.
        self._data_loading_config = _task_models.DataLoadingConfig(
            input_path=input_data_dir,
            output_path=output_data_dir,
            format=metadata_format,
            enabled=True,
            io_strategy=io_strategy,
        )

        # This needs to have the proper version reflected in it
        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            flytekit.__version__,
            "python",
        )
        effective_timeout = timeout or _datetime.timedelta(seconds=0)
        retry_strategy = _literals.RetryStrategy(retries)
        task_metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            effective_timeout,
            retry_strategy,
            interruptible,
            discovery_version,
            None,
        )

        # The interface is defined using the inputs and outputs
        typed_interface = _interface.TypedInterface(
            inputs=types_to_variable(inputs),
            outputs=types_to_variable(outputs),
        )

        container_definition = _get_container_definition(
            image=image,
            args=args,
            command=command,
            data_loading_config=self._data_loading_config,
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        # This sets the base SDKTask with container etc
        super(SdkRawContainerTask, self).__init__(
            _constants.SdkTaskType.RAW_CONTAINER_TASK,
            task_metadata,
            typed_interface,
            None,
            container=container_definition,
        )
예제 #12
0
    def __init__(
        self,
        notebook_path,
        inputs,
        outputs,
        task_type,
        discovery_version,
        retries,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        discoverable,
        timeout,
        environment,
        custom,
    ):
        """
        Build a notebook-backed task.

        :param notebook_path: Path of the notebook. A relative path is resolved against the directory of the
            module named by ``_find_instance_module()``; an absolute path is used as given.
        :param inputs: Optional callable; when not None it is invoked with this task after the base constructor ran.
        :param outputs: Optional callable; when not None it is invoked with this task after ``inputs``.
        :param task_type: Forwarded to the base constructor as the task type.
        :param discovery_version: Forwarded to the task metadata.
        :param retries: Wrapped in a ``RetryStrategy`` for the task metadata.
        :param deprecated: Forwarded to the task metadata.
        :param storage_request: Forwarded to ``self._get_container_definition``.
        :param cpu_request: Forwarded to ``self._get_container_definition``.
        :param gpu_request: Forwarded to ``self._get_container_definition``.
        :param memory_request: Forwarded to ``self._get_container_definition``.
        :param storage_limit: Forwarded to ``self._get_container_definition``.
        :param cpu_limit: Forwarded to ``self._get_container_definition``.
        :param gpu_limit: Forwarded to ``self._get_container_definition``.
        :param memory_limit: Forwarded to ``self._get_container_definition``.
        :param discoverable: Forwarded to the task metadata.
        :param timeout: Forwarded to the task metadata.
        :param environment: Forwarded to ``self._get_container_definition``.
        :param custom: Forwarded to the base constructor.
        """

        if not _os.path.isabs(notebook_path):
            # Anchor relative notebook paths at the directory of the module
            # reported by _find_instance_module().
            owning_module = _importlib.import_module(_find_instance_module())
            base_dir = _os.path.dirname(owning_module.__file__)
            notebook_path = _os.path.normpath(
                _os.path.join(base_dir, notebook_path))

        # Stored before the container definition is requested below.
        self._notebook_path = notebook_path

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            "notebook",
        )
        metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            # Hard-coded; presumably the interruptible flag -- TODO confirm
            # against the TaskMetadata signature.
            False,
            discovery_version,
            deprecated,
        )
        # The interface starts out empty; it is filled in after the base
        # constructor has run.
        empty_interface = _interface2.TypedInterface({}, {})
        container = self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        )

        super(SdkNotebookTask, self).__init__(
            task_type,
            metadata,
            empty_interface,
            custom,
            container=container,
        )

        # Let the inputs hook run first, then the outputs hook.
        for attach in (inputs, outputs):
            if attach is not None:
                attach(self)

        # The notebook itself is always exposed as a Blob output.
        notebook_variable = _interface.Variable(
            _Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)
        self.interface.outputs.update(output_notebook=notebook_variable)
예제 #13
0
    def _op_to_task(self, dag_id, image, op, node_map):
        """
        Generate a task, and the node that invokes it, for one Airflow operator.

        The task runs ``image`` with the command line that Airflow's
        ``TaskInstance.command_as_list`` produces for ``op``. The typed interface
        is left empty, so the node is created without any input bindings.

        :param dag_id: Passed to ``TaskInstance.command_as_list`` as ``pickle_id``.
        :param image: Container image used for the task's container.
        :param airflow.models.BaseOperator op: Operator to convert.
        :param dict(Text, SdkNode) node_map: Currently unused by this method.
        :rtype: Tuple(base_tasks.SdkTask, SdkNode)
        """

        # TODO: derive the interface variables and node bindings from the
        # operator; until then all three mappings stay empty.
        interface_inputs = {}
        interface_outputs = {}
        input_mappings = {}

        # Resource requests are only emitted when the operator declares
        # resources; order is cpu, memory, gpu, storage.
        requests = []
        if op.resources:
            resource_name = task_model.Resources.ResourceName
            requests = [
                task_model.Resources.ResourceEntry(name, value)
                for name, value in (
                    (resource_name.Cpu, op.resources.cpus),
                    (resource_name.Memory, op.resources.ram),
                    (resource_name.Gpu, op.resources.gpus),
                    (resource_name.Storage, op.resources.disk),
                )
            ]

        # The container command is Airflow's own command line for running
        # this operator as a task instance dated "now".
        task_instance = TaskInstance(op, datetime.datetime.now())
        command = task_instance.command_as_list(local=True,
                                                mark_success=False,
                                                ignore_all_deps=True,
                                                ignore_depends_on_past=True,
                                                ignore_task_deps=True,
                                                ignore_ti_state=True,
                                                pool=task_instance.pool,
                                                pickle_id=dag_id,
                                                cfg_path=None)

        task = base_tasks.SdkTask(
            op.task_id,
            SingleStepTask,
            "airflow_op",
            task_model.TaskMetadata(
                False,
                task_model.RuntimeMetadata(
                    type=task_model.RuntimeMetadata.RuntimeType.Other,
                    version=airflow.version.version,
                    flavor='airflow'),
                datetime.timedelta(seconds=0),
                literals_model.RetryStrategy(0),
                '1',
                None,
            ),
            interface_common.TypedInterface(inputs=interface_inputs,
                                            outputs=interface_outputs),
            custom=None,
            container=task_model.Container(
                image=image,
                command=command,
                args=[],
                resources=task_model.Resources(limits=[], requests=requests),
                env={},
                config={},
            ))

        # Calling the task yields its node; the node id mirrors the operator's
        # task_id.
        return task, task(**input_mappings).assign_id_and_return(op.task_id)
예제 #14
0
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        Build a hyperparameter tuning task around an existing training job task.

        The task's inputs are a ``hyperparameter_tuning_job_config`` variable plus every input of
        ``training_job``; its single output is ``model``.

        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job
        :param max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job in parallel
        :param training_job: The reference to the training job definition
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Constructing the model doubles as a type check of the arguments.
        tuning_job_idl = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        ).to_flyte_idl()

        # The Flyte-level timeout is zero so that SageMaker honours the
        # StoppingCondition of the underlying training job.
        # TODO: Discuss whether this is a viable interface or contract
        no_timeout = _datetime.timedelta(seconds=0)

        config_variable = _interface_model.Variable(
            _sdk_types.Types.Proto(
                _pb2_hpo_job.HyperparameterTuningJobConfig
            ).to_flyte_literal_type(),
            "",
        )
        task_inputs = {"hyperparameter_tuning_job_config": config_variable}
        # The training job's own inputs are appended and win on a name clash.
        task_inputs.update(training_job.interface.inputs)

        metadata = _task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=no_timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        single_blob = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        model_output = _interface_model.Variable(
            type=_types_models.LiteralType(blob=single_blob),
            description="",
        )
        interface = _interface.TypedInterface(
            inputs=task_inputs,
            outputs={"model": model_output},
        )

        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=metadata,
            interface=interface,
            custom=MessageToDict(tuning_job_idl),
        )
예제 #15
0
    def __init__(
        self,
        sdk_workflow,
        default_inputs=None,
        fixed_inputs=None,
        role=None,
        schedule=None,
        notifications=None,
        labels=None,
        annotations=None,
        auth_role=None,
        raw_output_data_config=None,
    ):
        """
        Build a launch plan for a runnable workflow.

        :param flytekit.common.local_workflow.SdkRunnableWorkflow sdk_workflow:
        :param dict[Text,flytekit.common.promise.Input] default_inputs:
        :param dict[Text,Any] fixed_inputs: These inputs will be fixed and not need to be set when executing this
            launch plan.
        :param Text role: Deprecated. IAM role to execute this launch plan with.
        :param flytekit.models.schedule.Schedule schedule: Schedule to apply to this workflow.
        :param list[flytekit.models.common.Notification] notifications: List of notifications to apply to this
            launch plan.
        :param flytekit.models.common.Labels labels: Any custom kubernetes labels to apply to workflows executed by this
            launch plan.
        :param flytekit.models.common.Annotations annotations: Any custom kubernetes annotations to apply to workflows
            executed by this launch plan.
        :param flytekit.models.common.AuthRole auth_role: The auth method with which to execute the workflow.
        :param flytekit.models.common.RawOutputDataConfig raw_output_data_config: Config for offloading data
        :raises ValueError: If both ``role`` and ``auth_role`` are given.
        """
        if role and auth_role:
            raise ValueError(
                "Cannot set both role and auth. Role is deprecated, use auth instead."
            )

        if not fixed_inputs:
            fixed_inputs = {}
        if not default_inputs:
            default_inputs = {}

        # The deprecated ``role`` is translated into an AuthRole.
        if role:
            auth_role = _common_models.AuthRole(assumable_iam_role=role)

        metadata = _launch_plan_models.LaunchPlanMetadata(
            schedule=schedule or _schedule_model.Schedule(""),
            notifications=notifications or [],
        )
        parameter_map = _interface_models.ParameterMap(default_inputs)

        # Only workflow inputs that are actually fixed need an SDK type for
        # packing the fixed values into literals.
        fixed_input_types = {
            name: _type_helpers.get_sdk_type_from_literal_type(variable.type)
            for name, variable in _six.iteritems(sdk_workflow.interface.inputs)
            if name in fixed_inputs
        }
        fixed_literals = _type_helpers.pack_python_std_map_to_literal_map(
            fixed_inputs,
            fixed_input_types,
        )

        # No id is passed (first argument is None). Per the original author's
        # note the id is set later, in one of three places:
        #   1) when the object is registered,
        #   2) by the dynamic task code after this runnable object has been
        #      called; the SdkNode produced keeps a link to this object and
        #      sets the id from the configuration variables present,
        #   3) when SdkLaunchPlan.fetch() is run.
        super(SdkRunnableLaunchPlan, self).__init__(
            None,
            metadata,
            parameter_map,
            fixed_literals,
            labels or _common_models.Labels({}),
            annotations or _common_models.Annotations({}),
            auth_role,
            raw_output_data_config or _common_models.RawOutputDataConfig(""),
        )

        # The callable interface exposes only the default inputs, paired with
        # the workflow's outputs.
        default_variables = {
            name: promise.var
            for name, promise in _six.iteritems(default_inputs)
        }
        self._interface = _interface.TypedInterface(
            default_variables,
            sdk_workflow.interface.outputs,
        )
        self._upstream_entities = {sdk_workflow}
        self._sdk_workflow = sdk_workflow
예제 #16
0
    def __init__(
        self,
        region,
        role_arn,
        resource_config,
        algorithm_specification=None,
        stopping_condition=None,
        vpc_config=None,
        enable_spot_training=False,
        interruptible=False,
        retries=0,
        cacheable=False,
        cache_version="",
    ):
        """
        Build a SageMaker XGBoost hyperparameter-optimization task.

        ``TrainingInputMode`` is always forced to ``"File"`` and ``TrainingImage`` falls back to a built-in
        XGBoost image when missing or empty. These defaults are applied to a copy, so the caller's
        ``algorithm_specification`` dict is left untouched.

        :param Text region: The region in which to run the SageMaker job.
        :param Text role_arn: The ARN of the role to run in the SageMaker job.
        :param dict[Text,T] algorithm_specification: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
        :param dict[Text,T] resource_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
        :param dict[Text,T] stopping_condition: https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
        :param dict[Text,T] vpc_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
        :param bool enable_spot_training: https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
            NOTE(review): accepted but not used anywhere in this constructor.
        :param bool interruptible: Forwarded to the task metadata's ``interruptible`` field.
        :param int retries: Number of time to retry.
        :param bool cacheable: Whether or not to use Flyte's caching system.
        :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
        """

        # Work on a shallow copy: the defaults written below must not leak
        # into a dict the caller may reuse for other tasks.
        algorithm_specification = dict(algorithm_specification or {})
        algorithm_specification["TrainingImage"] = (
            algorithm_specification.get("TrainingImage")
            or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1")
        algorithm_specification["TrainingInputMode"] = "File"

        job_config = ParseDict(
            {
                "Region": region,
                "ResourceConfig": resource_config,
                "StoppingCondition": stopping_condition,
                "VpcConfig": vpc_config,
                "AlgorithmSpecification": algorithm_specification,
                "RoleArn": role_arn,
            },
            sagemaker_pb2.SagemakerHPOJob(),
        )
        # NOTE(review): writes the resolved job config to stdout on every
        # construction; looks like leftover debugging output -- confirm
        # whether it should become a logger call.
        print(MessageToDict(job_config))

        # TODO: Optionally, pull timeout behavior from stopping condition and pass to Flyte task def.
        timeout = _datetime.timedelta(seconds=0)

        # TODO: The FlyteKit type engine is extensible so we can create a SagemakerInput type with custom
        # TODO:     parsing/casting logic. For now, we will use the Generic type since there is a little that needs
        # TODO:     to be done on Flyte side to unlock this cleanly.
        # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following a
        # TODO:     refactor.
        # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the task,
        # TODO:     but might not affect the outputs (i.e. Running on a bigger machine). These are currently static for
        # TODO:     a given definition of a task, but will be more dynamic in the future. Also, it is possible to
        # TODO:     make it dynamic by using our @dynamic_task.
        # TODO: You might want to inherit the role ARN from the execution at runtime.
        super(SagemakerXgBoostOptimizer, self).__init__(
            type=_TASK_TYPE,
            metadata=_task_models.TaskMetadata(
                discoverable=cacheable,
                # The leading 0 is a raw runtime-type value (presumably the
                # "other" runtime type -- TODO confirm against RuntimeType).
                runtime=_task_models.RuntimeMetadata(0, "0.1.0b0",
                                                     "sagemaker"),
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=interruptible,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface({}, {}),
            custom=MessageToDict(job_config),
        )

        # TODO: Add more inputs that we expect to change the outputs of the task.
        # TODO: We can add outputs too!
        # We use helper methods for adding to interface, thus overriding the one set above. This will be simplified post
        # refactor.
        self.add_inputs({
            "static_hyperparameters":
            _interface_model.Variable(
                _sdk_types.Types.Generic.to_flyte_literal_type(), ""),
            "train":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
            "validation":
            _interface_model.Variable(
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
        })
        self.add_outputs({
            "model":
            _interface_model.Variable(
                _sdk_types.Types.Blob.to_flyte_literal_type(), "")
        })
예제 #17
0
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        Build a generic Spark task.

        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param task_inputs: Optional callable; when not None it is invoked with this task after the base
            constructor has run and before the container definition is built.
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf:
        :param dict[Text,Text] hadoop_conf:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        # The Spark job description becomes the task's ``custom`` payload; the
        # executor path is the interpreter running this code.
        spark_job_idl = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        runtime = _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            __version__,
            'spark',
        )
        metadata = _task_models.TaskMetadata(
            discoverable,
            runtime,
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            metadata,
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job_idl),
        )

        # Let the caller-supplied hook register the task inputs.
        if task_inputs is not None:
            task_inputs(self)

        # The container is built only after the inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)