def test_runtime_metadata():
    obj = task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python")
    assert obj.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK
    assert obj.version == "1.0.0"
    assert obj.flavor == "python"
    assert obj == task.RuntimeMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj != task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.1", "python")
    assert obj != task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python")
    assert obj != task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "golang")
def test_task_template__k8s_pod_target():
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task.TaskMetadata(
            False,
            task.RuntimeMetadata(1, "v", "f"),
            timedelta(days=1),
            literal_models.RetryStrategy(5),
            False,
            "1.0",
            "deprecated",
            False,
        ),
        interface_models.TypedInterface(
            # inputs
            {"a": interface_models.Variable(int_type, "description1")},
            # outputs
            {
                "b": interface_models.Variable(int_type, "description2"),
                "c": interface_models.Variable(int_type, "description3"),
            },
        ),
        {"a": 1, "b": {"c": 2, "d": 3}},
        config={"a": "b"},
        k8s_pod=task.K8sPod(
            metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
            pod_spec={"str": "val", "int": 1},
        ),
    )
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.k8s_pod.metadata == task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"})
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}
    assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString(
        task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl()
    )
    assert obj.config == {"a": "b"}
def get_sample_task_metadata():
    """
    :rtype: flytekit.models.task.TaskMetadata
    """
    return _task_model.TaskMetadata(
        True,
        _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        True,
        "0.1.1b0",
        "This is deprecated!",
    )
def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
    """
    Converts to _task_model.TaskMetadata
    """
    return _task_model.TaskMetadata(
        discoverable=self.cache,
        # TODO: Fix the version circular dependency before beta
        runtime=_task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, "0.16.0", "python"),
        timeout=self.timeout,
        retries=self.retry_strategy,
        interruptible=self.interruptable,
        discovery_version=self.cache_version,
        deprecated_error_message=self.deprecated,
    )
def __init__(self, task_function, task_type, discovery_version, retries, interruptible, deprecated,
             storage_request, cpu_request, gpu_request, memory_request, storage_limit, cpu_limit,
             gpu_limit, memory_limit, discoverable, timeout, environment, custom):
    """
    :param task_function: Function containing user code. This will be executed via the SDK's engine.
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Specify whether the task is interruptible
    :param Text deprecated:
    :param Text storage_request:
    :param Text cpu_request:
    :param Text gpu_request:
    :param Text memory_request:
    :param Text storage_limit:
    :param Text cpu_limit:
    :param Text gpu_limit:
    :param Text memory_limit:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param dict[Text, Text] environment:
    :param dict[Text, T] custom:
    """
    self._task_function = task_function
    super(SdkRunnableTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                'python'),
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated),
        _interface.TypedInterface({}, {}),
        custom,
        container=self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment))
    self.id._name = "{}.{}".format(self.task_module, self.task_function_name)
def test_task_metadata():
    obj = task.TaskMetadata(
        True,
        task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        literals.RetryStrategy(3),
        True,  # interruptible (missing in the original call; the TaskMetadata signature used elsewhere in this section takes it between retries and discovery_version)
        "0.1.1b0",
        "This is deprecated!",
    )
    assert obj.discoverable is True
    assert obj.retries.retries == 3
    assert obj.timeout == timedelta(days=1)
    assert obj.runtime.flavor == "python"
    assert obj.runtime.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK
    assert obj.runtime.version == "1.0.0"
    assert obj.deprecated_error_message == "This is deprecated!"
    assert obj.discovery_version == "0.1.1b0"
    assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl())
def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
    """
    Converts to _task_model.TaskMetadata
    """
    from flytekit import __version__

    return _task_model.TaskMetadata(
        discoverable=self.cache,
        runtime=_task_model.RuntimeMetadata(
            _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"
        ),
        timeout=self.timeout,
        retries=self.retry_strategy,
        interruptible=self.interruptable,
        discovery_version=self.cache_version,
        deprecated_error_message=self.deprecated,
    )
def test_workflow_closure():
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")},
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3"),
        },
    )

    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))),
    )
    b1 = _literals.Binding('b', _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding('b', _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(
        name='node1', timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0)
    )

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        True,  # interruptible (missing in the original call; added to match the 7-argument signature used elsewhere in this section)
        "0.1.1b0",
        "This is deprecated!",
    )

    cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface,
        {'a': 1, 'b': {'c': 2, 'd': 3}},
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {},
            {},
        ),
    )

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(
        id='my_node',
        metadata=node_metadata,
        inputs=[b0],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
def __init__(
    self,
    max_number_of_training_jobs: int,
    max_parallel_training_jobs: int,
    training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
    tunable_parameters: typing.List[str] = None,
):
    """
    :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
    :param int max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job in parallel
    :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference
        to the training job definition
    :param int retries: Number of retries to attempt
    :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
    :param str cache_version: String describing the caching version for task discovery purposes
    :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
        algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
        parameters. E.g. refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for
        the list of hyperparameters for the Image Classification built-in algorithm. If you are passing a
        custom training job, the list of tunable parameters must be a strict subset of the list of inputs
        defined on that job. Refer to
        https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
        for the list of supported hyperparameter types.
    """
    # Use the training job model as a measure of type checking
    hpo_job = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=max_number_of_training_jobs,
        max_parallel_training_jobs=max_parallel_training_jobs,
        training_job=training_job.training_job_model,
    ).to_flyte_idl()

    # Set the flyte-level timeout to 0 and let SageMaker respect the StoppingCondition of
    # the underlying training job
    # TODO: Discuss whether this is a viable interface or contract
    timeout = _datetime.timedelta(seconds=0)

    inputs = {}
    inputs.update(training_job.interface.inputs)
    inputs.update(
        {
            "hyperparameter_tuning_job_config": _interface_model.Variable(
                HyperparameterTuningJobConfig.to_flyte_literal_type(),
                "",
            ),
        }
    )
    if tunable_parameters:
        inputs.update(
            {
                param: _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "")
                for param in tunable_parameters
            }
        )

    super().__init__(
        type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
        metadata=_task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface(
            inputs=inputs,
            outputs={
                "model": _interface_model.Variable(
                    type=_types_models.LiteralType(
                        blob=_core_types.BlobType(
                            format="",
                            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                        )
                    ),
                    description="",
                )
            },
        ),
        custom=MessageToDict(hpo_job),
    )
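The tunable-parameter contract above is easiest to see in use. Below is a hedged usage sketch, not from the source: `xgboost_train_task` is a hypothetical training job task (for example, an SdkBuiltinAlgorithmTrainingJobTask like the one constructed later in this section), and the parameter names are taken from the XGBoost hyperparameter documentation.

simple_xgboost_hpo_job_task = SdkSimpleHyperparameterTuningJobTask(
    training_job=xgboost_train_task,  # hypothetical, assumed built elsewhere
    max_number_of_training_jobs=10,
    max_parallel_training_jobs=5,
    cacheable=True,
    cache_version="1",
    retries=2,
    # For a custom training job, these must be a strict subset of the job's inputs.
    tunable_parameters=["num_round", "max_depth", "gamma"],
)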
def __init__(
    self,
    max_number_of_training_jobs: int,
    max_parallel_training_jobs: int,
    training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
):
    """
    :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
    :param max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job in parallel
    :param training_job: The reference to the training job definition
    :param retries: Number of retries to attempt
    :param cacheable: The flag to set if the user wants the output of the task execution to be cached
    :param cache_version: String describing the caching version for task discovery purposes
    """
    # Use the training job model as a measure of type checking
    hpo_job = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=max_number_of_training_jobs,
        max_parallel_training_jobs=max_parallel_training_jobs,
        training_job=training_job.training_job_model,
    ).to_flyte_idl()

    # Set the flyte-level timeout to 0 and let SageMaker respect the StoppingCondition of
    # the underlying training job
    # TODO: Discuss whether this is a viable interface or contract
    timeout = _datetime.timedelta(seconds=0)

    inputs = {
        "hyperparameter_tuning_job_config": _interface_model.Variable(
            _sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(),
            "",
        ),
    }
    inputs.update(training_job.interface.inputs)

    super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
        type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
        metadata=_task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface(
            inputs=inputs,
            outputs={
                "model": _interface_model.Variable(
                    type=_types_models.LiteralType(
                        blob=_core_types.BlobType(
                            format="",
                            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                        )
                    ),
                    description="",
                )
            },
        ),
        custom=MessageToDict(hpo_job),
    )
def __init__(
    self,
    region,
    role_arn,
    resource_config,
    algorithm_specification=None,
    stopping_condition=None,
    vpc_config=None,
    enable_spot_training=False,
    interruptible=False,
    retries=0,
    cacheable=False,
    cache_version="",
):
    """
    :param Text region: The region in which to run the SageMaker job.
    :param Text role_arn: The ARN of the role to run in the SageMaker job.
    :param dict[Text,T] algorithm_specification:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
    :param dict[Text,T] resource_config:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
    :param dict[Text,T] stopping_condition:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
    :param dict[Text,T] vpc_config:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
    :param bool enable_spot_training:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
    :param int retries: Number of times to retry.
    :param bool cacheable: Whether or not to use Flyte's caching system.
    :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
    """
    algorithm_specification = algorithm_specification or {}
    algorithm_specification["TrainingImage"] = (
        algorithm_specification.get("TrainingImage")
        or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1"
    )
    algorithm_specification["TrainingInputMode"] = "File"

    job_config = ParseDict(
        {
            "Region": region,
            "ResourceConfig": resource_config,
            "StoppingCondition": stopping_condition,
            "VpcConfig": vpc_config,
            "AlgorithmSpecification": algorithm_specification,
            "RoleArn": role_arn,
        },
        sagemaker_pb2.SagemakerHPOJob(),
    )
    print(MessageToDict(job_config))

    # TODO: Optionally, pull timeout behavior from the stopping condition and pass it to the Flyte task def.
    timeout = _datetime.timedelta(seconds=0)

    # TODO: The FlyteKit type engine is extensible, so we can create a SagemakerInput type with custom
    #       parsing/casting logic. For now, we will use the Generic type since there is little that needs
    #       to be done on the Flyte side to unlock this cleanly.
    # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following
    #       a refactor.
    # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the
    #       task but might not affect the outputs (i.e. running on a bigger machine). These are currently
    #       static for a given definition of a task, but will be more dynamic in the future. Also, it is
    #       possible to make it dynamic by using our @dynamic_task.
    # TODO: You might want to inherit the role ARN from the execution at runtime.
    super(SagemakerXgBoostOptimizer, self).__init__(
        type=_TASK_TYPE,
        metadata=_task_models.TaskMetadata(
            discoverable=cacheable,
            runtime=_task_models.RuntimeMetadata(0, "0.1.0b0", "sagemaker"),
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=interruptible,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface({}, {}),
        custom=MessageToDict(job_config),
    )

    # TODO: Add more inputs that we expect to change the outputs of the task.
    # TODO: We can add outputs too!
    # We use helper methods for adding to the interface, thus overriding the one set above. This will be
    # simplified post refactor.
self.add_inputs(
    {
        "static_hyperparameters": _interface_model.Variable(
            _sdk_types.Types.Generic.to_flyte_literal_type(), ""
        ),
        "train": _interface_model.Variable(
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""
        ),
        "validation": _interface_model.Variable(
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""
        ),
    }
)
self.add_outputs(
    {
        "model": _interface_model.Variable(
            _sdk_types.Types.Blob.to_flyte_literal_type(), ""
        )
    }
)
def _op_to_task(self, dag_id, image, op, node_map):
    """
    Generate a task given an operator inherited from dsl.ContainerOp.

    :param airflow.models.BaseOperator op:
    :param dict(Text, SdkNode) node_map:
    :rtype: Tuple(base_tasks.SdkTask, SdkNode)
    """
    interface_inputs = {}
    interface_outputs = {}
    input_mappings = {}
    processed_args = None

    # for key, val in six.iteritems(op.params):
    #     interface_inputs[key] = interface_model.Variable(
    #         _type_helpers.python_std_to_sdk_type(Types.String).to_flyte_literal_type(),
    #         ''
    #     )
    #
    #     if param.op_name == '':
    #         binding = promise_common.Input(sdk_type=Types.String, name=param.name)
    #     else:
    #         binding = promise_common.NodeOutput(
    #             sdk_node=node_map[param.op_name],
    #             sdk_type=Types.String,
    #             var=param.name)
    #     input_mappings[param.name] = binding
    #
    # for param in op.outputs.values():
    #     interface_outputs[param.name] = interface_model.Variable(
    #         _type_helpers.python_std_to_sdk_type(Types.String).to_flyte_literal_type(),
    #         ''
    #     )

    requests = []
    if op.resources:
        # Note: ResourceName constants are spelled in uppercase elsewhere in this section
        # (CPU, MEMORY, GPU, STORAGE); the original mixed-case spellings (Cpu, Memory, ...)
        # appear to be a bug and are corrected here, as is RuntimeType.Other below.
        requests.append(
            task_model.Resources.ResourceEntry(task_model.Resources.ResourceName.CPU, op.resources.cpus)
        )
        requests.append(
            task_model.Resources.ResourceEntry(task_model.Resources.ResourceName.MEMORY, op.resources.ram)
        )
        requests.append(
            task_model.Resources.ResourceEntry(task_model.Resources.ResourceName.GPU, op.resources.gpus)
        )
        requests.append(
            task_model.Resources.ResourceEntry(task_model.Resources.ResourceName.STORAGE, op.resources.disk)
        )

    task_instance = TaskInstance(op, datetime.datetime.now())
    command = task_instance.command_as_list(
        local=True,
        mark_success=False,
        ignore_all_deps=True,
        ignore_depends_on_past=True,
        ignore_task_deps=True,
        ignore_ti_state=True,
        pool=task_instance.pool,
        pickle_id=dag_id,
        cfg_path=None,
    )

    task = base_tasks.SdkTask(
        op.task_id,
        SingleStepTask,
        "airflow_op",
        task_model.TaskMetadata(
            False,
            task_model.RuntimeMetadata(
                type=task_model.RuntimeMetadata.RuntimeType.OTHER,
                version=airflow.version.version,
                flavor='airflow',
            ),
            datetime.timedelta(seconds=0),
            literals_model.RetryStrategy(0),
            False,  # interruptible (missing in the original call; added to match the 7-argument signature used elsewhere in this section)
            '1',
            None,
        ),
        interface_common.TypedInterface(inputs=interface_inputs, outputs=interface_outputs),
        custom=None,
        container=task_model.Container(
            image=image,
            command=command,
            args=[],
            resources=task_model.Resources(limits=[], requests=requests),
            env={},
            config={},
        ),
    )
    return task, task(**input_mappings).assign_id_and_return(op.task_id)
def __init__(
    self,
    task_function,
    task_type,
    discovery_version,
    retries,
    interruptible,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    discoverable,
    timeout,
    environment,
    custom,
):
    """
    :param task_function: Function containing user code. This will be executed via the SDK's engine.
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Specify whether the task is interruptible
    :param Text deprecated:
    :param Text storage_request:
    :param Text cpu_request:
    :param Text gpu_request:
    :param Text memory_request:
    :param Text storage_limit:
    :param Text cpu_limit:
    :param Text gpu_limit:
    :param Text memory_limit:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param dict[Text, Text] environment:
    :param dict[Text, T] custom:
    """
    # Circular dependency
    from flytekit import __version__

    self._task_function = task_function
    super(SdkRunnableTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                "python",
            ),
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        ),
        # TODO: If we end up using SdkRunnableTask for the new code, make sure this is set correctly.
        _interface.TypedInterface({}, {}),
        custom,
        container=self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        ),
    )
    self.id._name = "{}.{}".format(self.task_module, self.task_function_name)
    self._has_fast_registered = False

    # TODO: Remove this in the future, I don't think we'll be using this.
    self._task_style = SdkRunnableTaskStyle.V0
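User code does not normally call this constructor directly; in flytekit 0.x the decorator stack builds the SdkRunnableTask and its TaskMetadata. A minimal sketch, assuming the 0.x SDK entry points inputs, outputs, and python_task:

from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types


@inputs(a=Types.Integer)
@outputs(b=Types.Integer)
@python_task(cache=True, cache_version="1", retries=2)
def double(wf_params, a, b):
    # The decorator stack wraps this function in an SdkRunnableTask, deriving
    # TaskMetadata (discoverable, retries, discovery_version, ...) from the
    # keyword arguments above.
    b.set(a * 2)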
def __init__(
    self,
    statement,
    output_schema,
    routing_group=None,
    catalog=None,
    schema=None,
    task_inputs=None,
    interruptible=False,
    discoverable=False,
    discovery_version=None,
    retries=1,
    timeout=None,
    deprecated=None,
):
    """
    :param Text statement: Presto query specification
    :param flytekit.common.types.schema.Schema output_schema: Schema that represents the data queried from Presto
    :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
    :param Text catalog: The catalog to set for the given Presto query
    :param Text schema: The schema to set for the given Presto query
    :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the
        Presto task
    :param bool discoverable:
    :param Text discovery_version: String describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param datetime.timedelta timeout:
    :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will
        receive deprecation warnings.
    """
    # Set as class fields which are used down below to configure implicit parameters
    self._routing_group = routing_group or ""
    self._catalog = catalog or ""
    self._schema = schema or ""

    metadata = _task_model.TaskMetadata(
        discoverable,
        # This needs to have the proper version reflected in it
        _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
        timeout or _datetime.timedelta(seconds=0),
        _literals.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )

    presto_query = _presto_models.PrestoQuery(
        routing_group=routing_group or "",
        catalog=catalog or "",
        schema=schema or "",
        statement=statement,
    )

    # Here we set the routing_group, catalog, and schema as implicit inputs for caching purposes
    i = _interface.TypedInterface(
        {
            "__implicit_routing_group": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The routing group set as an implicit input",
            ),
            "__implicit_catalog": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The catalog set as an implicit input",
            ),
            "__implicit_schema": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The schema set as an implicit input",
            ),
        },
        {
            # Set the schema for the Presto query as an output
            "results": _interface_model.Variable(
                type=_types.LiteralType(schema=output_schema.schema_type),
                description="The schema for the Presto query",
            )
        },
    )

    super(SdkPrestoTask, self).__init__(
        _constants.SdkTaskType.PRESTO_TASK,
        metadata,
        i,
        _MessageToDict(presto_query.to_flyte_idl()),
    )

    # Set user-provided inputs
    task_inputs(self)
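A hedged usage sketch of the class above, modeled on the Presto examples in the Flyte documentation; the statement, output schema columns, and routing group are illustrative values:

from flytekit.sdk.tasks import inputs
from flytekit.sdk.types import Types

presto_task = SdkPrestoTask(
    task_inputs=inputs(ds=Types.String),
    statement="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
    output_schema=Types.Schema([("a", Types.String), ("b", Types.Integer)]),
    routing_group="etl",
    retries=1,
)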
LIST_OF_RESOURCE_ENTRIES = [
    task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1"),
    task.Resources.ResourceEntry(task.Resources.ResourceName.GPU, "1"),
    task.Resources.ResourceEntry(task.Resources.ResourceName.MEMORY, "1G"),
    task.Resources.ResourceEntry(task.Resources.ResourceName.STORAGE, "1G"),
]

LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES]

LIST_OF_RESOURCES = [
    task.Resources(request, limit)
    for request, limit in product(LIST_OF_RESOURCE_ENTRY_LISTS, LIST_OF_RESOURCE_ENTRY_LISTS)
]

LIST_OF_RUNTIME_METADATA = [
    task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python"),
    task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0", "golang"),
]

LIST_OF_RETRY_POLICIES = [literals.RetryStrategy(retries=i) for i in [0, 1, 3, 100]]

LIST_OF_INTERRUPTIBLE = [None, True, False]

LIST_OF_TASK_METADATA = [
    task.TaskMetadata(
        discoverable,
        runtime_metadata,
        timeout,
        retry_strategy,
        interruptible,
        discovery_version,
        deprecated,
    )
    for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated in
    # The source was truncated here; a plausible completion iterates over the cross
    # product of the parameter lists defined above (the timeout and the version and
    # deprecation strings are assumed values):
    product(
        [True, False],
        LIST_OF_RUNTIME_METADATA,
        [timedelta(seconds=10)],
        LIST_OF_RETRY_POLICIES,
        LIST_OF_INTERRUPTIBLE,
        ["1.0"],
        ["deprecated"],
    )
]
def __init__(
    self,
    notebook_path,
    inputs,
    outputs,
    task_type,
    discovery_version,
    retries,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    discoverable,
    timeout,
    environment,
    custom,
):
    if not _os.path.isabs(notebook_path):
        # Find the absolute path for the notebook.
        task_module = _importlib.import_module(_find_instance_module())
        module_path = _os.path.dirname(task_module.__file__)
        notebook_path = _os.path.normpath(_os.path.join(module_path, notebook_path))
    self._notebook_path = notebook_path

    super(SdkNotebookTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                "notebook",
            ),
            timeout,
            _literal_models.RetryStrategy(retries),
            False,
            discovery_version,
            deprecated,
        ),
        _interface2.TypedInterface({}, {}),
        custom,
        container=self._get_container_definition(
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        ),
    )

    # Add inputs
    if inputs is not None:
        inputs(self)

    # Add outputs
    if outputs is not None:
        outputs(self)

    # Add a notebook output as a Blob.
    self.interface.outputs.update(
        output_notebook=_interface.Variable(_Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)
    )
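A minimal construction sketch built directly against the signature above; every value is illustrative, and the inputs/outputs wrappers are assumed to be the flytekit 0.x SDK helpers (callables applied to the task, matching the inputs(self)/outputs(self) calls in the constructor):

from datetime import timedelta

from flytekit.sdk.tasks import inputs, outputs
from flytekit.sdk.types import Types

nb_task = SdkNotebookTask(
    notebook_path="./demo.ipynb",
    inputs=inputs(x=Types.Integer),
    outputs=outputs(y=Types.Integer),
    task_type="python-task",  # illustrative task type string
    discovery_version="1",
    retries=0,
    deprecated="",
    storage_request=None,
    cpu_request="1",
    gpu_request=None,
    memory_request="1G",
    storage_limit=None,
    cpu_limit=None,
    gpu_limit=None,
    memory_limit=None,
    discoverable=False,
    timeout=timedelta(minutes=10),
    environment={},
    custom={},
)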
def __init__(
    self,
    training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
    algorithm_specification: _training_job_models.AlgorithmSpecification,
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
):
    """
    :param training_job_resource_config: The options to configure the training job
    :param algorithm_specification: The options to configure the target algorithm of the training
    :param retries: Number of retries to attempt
    :param cacheable: The flag to set if the user wants the output of the task execution to be cached
    :param cache_version: String describing the caching version for task discovery purposes
    """
    # Use the training job model as a measure of type checking
    self._training_job_model = _training_job_models.TrainingJob(
        algorithm_specification=algorithm_specification,
        training_job_resource_config=training_job_resource_config,
    )

    # Set the flyte-level timeout to 0 and let SageMaker take the StoppingCondition and terminate the
    # training job gracefully
    timeout = _datetime.timedelta(seconds=0)

    super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
        type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
        metadata=_task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface(
            inputs={
                "static_hyperparameters": _interface_model.Variable(
                    type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                    description="",
                ),
                "train": _interface_model.Variable(
                    type=_idl_types.LiteralType(
                        blob=_core_types.BlobType(
                            format=_content_type_to_blob_format(algorithm_specification.input_content_type),
                            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                        ),
                    ),
                    description="",
                ),
                "validation": _interface_model.Variable(
                    type=_idl_types.LiteralType(
                        blob=_core_types.BlobType(
                            format=_content_type_to_blob_format(algorithm_specification.input_content_type),
                            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                        ),
                    ),
                    description="",
                ),
            },
            outputs={
                "model": _interface_model.Variable(
                    type=_idl_types.LiteralType(
                        blob=_core_types.BlobType(
                            format="",
                            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                        )
                    ),
                    description="",
                )
            },
        ),
        custom=MessageToDict(self._training_job_model.to_flyte_idl()),
    )
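A hedged construction sketch for the class above; the helper classes and enum values are assumed from flytekit's 0.x SageMaker models (flytekit.models.sagemaker.training_job), and the instance sizes and algorithm version are illustrative:

from flytekit.models.sagemaker.training_job import (
    AlgorithmName,
    AlgorithmSpecification,
    InputContentType,
    InputMode,
    TrainingJobResourceConfig,
)

xgboost_train_task = SdkBuiltinAlgorithmTrainingJobTask(
    training_job_resource_config=TrainingJobResourceConfig(
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=25,
    ),
    algorithm_specification=AlgorithmSpecification(
        input_mode=InputMode.FILE,
        algorithm_name=AlgorithmName.XGBOOST,
        algorithm_version="0.72",
        input_content_type=InputContentType.TEXT_CSV,
    ),
    cacheable=True,
    cache_version="1",
)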
def __init__(
    self,
    task_type,
    discovery_version,
    retries,
    interruptible,
    task_inputs,
    deprecated,
    discoverable,
    timeout,
    spark_type,
    main_class,
    main_application_file,
    spark_conf,
    hadoop_conf,
    environment,
):
    """
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not the task is interruptible
    :param Text deprecated:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param Text spark_type: Type of Spark job: Scala/Java
    :param Text main_class: Main class to execute for Scala/Java jobs
    :param Text main_application_file: Main application file
    :param dict[Text,Text] spark_conf:
    :param dict[Text,Text] hadoop_conf:
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    spark_job = _task_models.SparkJob(
        spark_conf=spark_conf,
        hadoop_conf=hadoop_conf,
        spark_type=spark_type,
        application_file=main_application_file,
        main_class=main_class,
        executor_path=_sys.executable,
    ).to_flyte_idl()

    super(SdkGenericSparkTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                'spark',
            ),
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        ),
        _interface.TypedInterface({}, {}),
        _MessageToDict(spark_job),
    )

    # Add inputs
    if task_inputs is not None:
        task_inputs(self)

    # The container is set after the inputs have been updated.
    self._container = self._get_container_definition(environment=environment)
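Since every parameter in the signature above is positional-or-keyword, a construction sketch follows mechanically; the jar path, class name, and conf values are illustrative, and the plain "scala" string stands in for whatever SparkType value the surrounding plugin defines:

from datetime import timedelta

scala_spark = SdkGenericSparkTask(
    task_type="spark",
    discovery_version="1",
    retries=2,
    interruptible=False,
    task_inputs=None,
    deprecated="",
    discoverable=False,
    timeout=timedelta(minutes=30),
    spark_type="scala",  # assumed stand-in for the plugin's SparkType value
    main_class="org.apache.spark.examples.SparkPi",
    main_application_file="local:///opt/spark/examples/jars/spark-examples.jar",
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    environment={},
)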
def __init__(
    self,
    task_type,
    discovery_version,
    retries,
    interruptible,
    task_inputs,
    deprecated,
    discoverable,
    timeout,
    jar_file,
    main_class,
    args,
    flink_properties,
    environment,
):
    """
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not the task is interruptible
    :param Text deprecated:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param Text main_class: Main class to execute for Scala/Java jobs
    :param Text jar_file: fat jar file
    :param dict[Text,Text] flink_properties:
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    flink_job = _task_models.FlinkJob(
        flink_properties=flink_properties,
        jar_file=jar_file,
        main_class=main_class,
        args=args,
    ).to_flyte_idl()

    super(SdkGenericFlinkTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                "flink",
            ),
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        ),
        _interface.TypedInterface({}, {}),
        _MessageToDict(flink_job),
    )

    # Add inputs
    if task_inputs is not None:
        task_inputs(self)

    # The container is set after the inputs have been updated.
    self._container = self._get_container_definition(environment=environment)
def __init__(
    self,
    inputs: Dict[str, FlyteSdkType],
    image: str,
    outputs: Dict[str, FlyteSdkType] = None,
    input_data_dir: str = None,
    output_data_dir: str = None,
    metadata_format: int = METADATA_FORMAT_JSON,
    io_strategy: _task_models.IOStrategy = None,
    command: List[str] = None,
    args: List[str] = None,
    storage_request: str = None,
    cpu_request: str = None,
    gpu_request: str = None,
    memory_request: str = None,
    storage_limit: str = None,
    cpu_limit: str = None,
    gpu_limit: str = None,
    memory_limit: str = None,
    environment: Dict[str, str] = None,
    interruptible: bool = False,
    discoverable: bool = False,
    discovery_version: str = None,
    retries: int = 1,
    timeout: _datetime.timedelta = None,
):
    """
    :param inputs:
    :param outputs:
    :param image:
    :param command:
    :param args:
    :param storage_request:
    :param cpu_request:
    :param gpu_request:
    :param memory_request:
    :param storage_limit:
    :param cpu_limit:
    :param gpu_limit:
    :param memory_limit:
    :param environment:
    :param interruptible:
    :param discoverable:
    :param discovery_version:
    :param retries:
    :param timeout:
    :param input_data_dir: This is the directory where data will be downloaded to
    :param output_data_dir: This is the directory where data will be uploaded from
    :param metadata_format: Format in which the metadata will be available for the script
    """
    # Set as class fields which are used down below to configure implicit parameters
    self._data_loading_config = _task_models.DataLoadingConfig(
        input_path=input_data_dir,
        output_path=output_data_dir,
        format=metadata_format,
        enabled=True,
        io_strategy=io_strategy,
    )

    metadata = _task_models.TaskMetadata(
        discoverable,
        # This needs to have the proper version reflected in it
        _task_models.RuntimeMetadata(
            _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            flytekit.__version__,
            "python",
        ),
        timeout or _datetime.timedelta(seconds=0),
        _literals.RetryStrategy(retries),
        interruptible,
        discovery_version,
        None,
    )

    # The interface is defined using the inputs and outputs
    i = _interface.TypedInterface(inputs=types_to_variable(inputs), outputs=types_to_variable(outputs))

    # This sets up the base SdkTask with the container, etc.
    super(SdkRawContainerTask, self).__init__(
        _constants.SdkTaskType.RAW_CONTAINER_TASK,
        metadata,
        i,
        None,
        container=_get_container_definition(
            image=image,
            args=args,
            command=command,
            data_loading_config=self._data_loading_config,
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            environment=environment,
        ),
    )
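A hedged usage sketch following the raw-container pattern from the Flyte examples; the image, directories, and shell command are illustrative:

from flytekit.sdk.types import Types

square = SdkRawContainerTask(
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"val": Types.Integer},
    outputs={"out": Types.Integer},
    image="alpine",
    # The command reads the {{.Inputs.val}} value injected by Flyte and writes
    # the result into the declared output directory.
    command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
)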