def with_overrides(self, *args, **kwargs):
    """
    Apply per-node overrides passed as keyword arguments and return this node.

    Only keys that are present in ``kwargs`` are acted on. Recognized keys are ``node_name``, ``aliases``,
    ``requests``, ``limits``, ``timeout``, ``retries`` and ``interruptible``. Positional arguments and any
    other keys are ignored.

    :param args: Unused.
    :param kwargs: Overrides to apply.
    :raises AssertionError: If ``aliases`` is given but is not a dict.
    :raises ValueError: If ``timeout`` is neither None, an int, nor a datetime.timedelta.
    :return: This same object, so calls can be chained.
    """
    if "node_name" in kwargs:
        self._id = kwargs["node_name"]

    if "aliases" in kwargs:
        alias_dict = kwargs["aliases"]
        if not isinstance(alias_dict, dict):
            raise AssertionError("Aliases should be specified as dict[str, str]")
        self._aliases = [_workflow_model.Alias(var=var, alias=alias) for var, alias in alias_dict.items()]

    # Either key on its own triggers a rebuild of the whole resources model; the missing side is
    # converted from None.
    if any(key in kwargs for key in ("requests", "limits")):
        self._resources = _resources_model(
            requests=_convert_resource_overrides(kwargs.get("requests"), "requests"),
            limits=_convert_resource_overrides(kwargs.get("limits"), "limits"),
        )

    if "timeout" in kwargs:
        timeout = kwargs["timeout"]
        if timeout is None:
            # An explicit None resets the timeout to a zero duration.
            duration = datetime.timedelta()
        elif isinstance(timeout, int):
            duration = datetime.timedelta(seconds=timeout)
        elif isinstance(timeout, datetime.timedelta):
            duration = timeout
        else:
            raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
        self._metadata._timeout = duration

    if "retries" in kwargs:
        retries = kwargs["retries"]
        # An explicit None means "no retries".
        self._metadata._retries = _literal_models.RetryStrategy(0 if retries is None else retries)

    if "interruptible" in kwargs:
        self._metadata._interruptible = kwargs["interruptible"]

    return self
def test_retry_strategy():
    """Check RetryStrategy field access, IDL round-trip and inequality, and that -1 retries raises."""
    strategy = literals.RetryStrategy(3)
    assert strategy.retries == 3

    round_tripped = literals.RetryStrategy.from_flyte_idl(strategy.to_flyte_idl())
    assert round_tripped == strategy
    assert strategy != literals.RetryStrategy(4)

    # Either constructing or serializing a negative retry count is expected to fail.
    with pytest.raises(Exception):
        literals.RetryStrategy(-1).to_flyte_idl()
def __call__(self, *args, **input_map):
    """
    Create a workflow node that invokes this launch plan.

    Parameters that are not required start from their declared defaults; values given in ``input_map``
    take precedence over those defaults.

    :param list[T] args: Do not specify. Kwargs only are supported for this function.
    :param dict[Text,T] input_map: Map of inputs. Can be statically defined or OutputReference links.
    :raises FlyteAssertion: If any positional argument is supplied.
    :rtype: flytekit.common.nodes.SdkNode
    """
    if args:
        raise _user_exceptions.FlyteAssertion(
            "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. We "
            "detected {} positional args.".format(len(args))
        )

    # Seed with the launch plan's defaults, then let the caller's values win.
    resolved_inputs = {}
    for name, parameter in _six.iteritems(self.default_inputs.parameters):
        if not parameter.required:
            resolved_inputs[name] = parameter.sdk_default
    resolved_inputs.update(input_map)

    bindings, upstream_nodes = self.interface.create_bindings_for_inputs(resolved_inputs)

    node_metadata = _workflow_models.NodeMetadata("", _datetime.timedelta(), _literal_models.RetryStrategy(0))
    return _nodes.SdkNode(
        id=None,
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        sdk_launch_plan=self,
    )
def __call__(self, *args, **input_map):
    """
    Create a node that runs this workflow as a sub-workflow of another workflow.

    Workflow inputs that are not required start from their defaults; values given in ``input_map``
    take precedence over those defaults.

    :param list[T] args: Must be empty; only keyword inputs are supported.
    :param dict[Text,T] input_map: Map of input name to the value or reference to bind.
    :raises FlyteAssertion: If any positional argument is supplied.
    :return: The newly created node, with no id assigned yet.
    """
    if args:
        raise _user_exceptions.FlyteAssertion(
            "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only. We "
            "detected {} positional args.".format(len(args))
        )

    # Seed with the defaults declared on the workflow's inputs, then let the caller's values win.
    resolved_inputs = {}
    for user_input in self.user_inputs:
        if not user_input.sdk_required:
            resolved_inputs[user_input.name] = user_input.sdk_default
    resolved_inputs.update(input_map)

    bindings, upstream_nodes = self.interface.create_bindings_for_inputs(resolved_inputs)

    placeholder_metadata = _workflow_models.NodeMetadata(
        "placeholder",
        _datetime.timedelta(),
        _literal_models.RetryStrategy(0),
    )
    return _nodes.SdkNode(
        id=None,
        metadata=placeholder_metadata,
        upstream_nodes=upstream_nodes,
        bindings=sorted(bindings, key=lambda b: b.var),
        sdk_workflow=self,
    )
def get_sample_node_metadata(node_id):
    """
    Build sample node metadata named after ``node_id``, with a 10 second timeout and zero retries.

    :param Text node_id: Used as the metadata name.
    :rtype: flytekit.models.core.workflow.NodeMetadata
    """
    ten_seconds = timedelta(seconds=10)
    no_retries = _literals.RetryStrategy(0)
    return _workflow_model.NodeMetadata(name=node_id, timeout=ten_seconds, retries=no_retries)
def test_task_template__k8s_pod_target():
    """Build a TaskTemplate with a K8s pod target, check its fields, and check it round-trips through IDL."""
    int_type = types.LiteralType(types.SimpleType.INTEGER)

    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    typed_interface = interface_models.TypedInterface(
        # inputs
        {"a": interface_models.Variable(int_type, "description1")},
        # outputs
        {
            "b": interface_models.Variable(int_type, "description2"),
            "c": interface_models.Variable(int_type, "description3"),
        },
    )
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
        pod_spec={"str": "val", "int": 1},
    )

    obj = task.TaskTemplate(
        task_id,
        "python",
        metadata,
        typed_interface,
        {"a": 1, "b": {"c": 2, "d": 3}},
        config={"a": "b"},
        k8s_pod=pod,
    )

    # Identifier fields
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"

    # Template payload
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    expected_pod_metadata = task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"})
    assert obj.k8s_pod.metadata == expected_pod_metadata
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}

    # Serializing, deserializing and serializing again must produce the same text form.
    serialized = text_format.MessageToString(obj.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert serialized == text_format.MessageToString(round_tripped.to_flyte_idl())
    assert obj.config == {"a": "b"}
def test_node_metadata():
    """Check NodeMetadata field access and that the values survive an IDL round-trip."""
    original = _workflow.NodeMetadata(
        name="node1",
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )
    assert original.timeout.seconds == 10
    assert original.retries.retries == 0

    restored = _workflow.NodeMetadata.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    assert restored.timeout.seconds == 10
    assert restored.retries.retries == 0
def get_sample_task_metadata():
    """
    Build a fixed sample TaskMetadata for use in tests.

    :rtype: flytekit.models.task.TaskMetadata
    """
    runtime = _task_model.RuntimeMetadata(
        _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        "1.0.0",
        "python",
    )
    one_day = timedelta(days=1)
    three_retries = _literals.RetryStrategy(3)
    return _task_model.TaskMetadata(
        True,
        runtime,
        one_day,
        three_retries,
        True,
        "0.1.1b0",
        "This is deprecated!",
    )
def __init__(
    self,
    task_function,
    task_type,
    discovery_version,
    retries,
    interruptible,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    discoverable,
    timeout,
    environment,
    custom,
):
    """
    Wrap a user function as a runnable SDK task with the given metadata and container resources.

    :param task_function: Function container user code. This will be executed via the SDK's engine.
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Specify whether task is interruptible
    :param Text deprecated:
    :param Text storage_request:
    :param Text cpu_request:
    :param Text gpu_request:
    :param Text memory_request:
    :param Text storage_limit:
    :param Text cpu_limit:
    :param Text gpu_limit:
    :param Text memory_limit:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param dict[Text, Text] environment:
    :param dict[Text, T] custom:
    """
    self._task_function = task_function

    runtime = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        'python',
    )
    metadata = _task_models.TaskMetadata(
        discoverable,
        runtime,
        timeout,
        _literal_models.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )
    container = self._get_container_definition(
        storage_request=storage_request,
        cpu_request=cpu_request,
        gpu_request=gpu_request,
        memory_request=memory_request,
        storage_limit=storage_limit,
        cpu_limit=cpu_limit,
        gpu_limit=gpu_limit,
        memory_limit=memory_limit,
        environment=environment,
    )

    # The interface starts out empty.
    super(SdkRunnableTask, self).__init__(
        task_type,
        metadata,
        _interface.TypedInterface({}, {}),
        custom,
        container=container,
    )

    # Name the task after the module and function that define it.
    self.id._name = "{}.{}".format(self.task_module, self.task_function_name)
def _create_hive_job_node(name, hive_job, metadata):
    """
    Wrap a Hive job in a standalone node with a freshly generated UUID as its id.

    The node reuses the task's timeout but always gets a retry strategy of 0; it has no bindings and no
    upstream nodes.

    :param Text name: Name recorded in the node's metadata.
    :param _qubole.QuboleHiveJob hive_job: Hive job spec
    :param flytekit.models.task.TaskMetadata metadata: This contains information needed at runtime to determine
        behavior such as whether or not outputs are discoverable, timeouts, and retries.
    :rtype: _nodes.SdkNode:
    """
    node_id = _six.text_type(_uuid.uuid4())
    node_metadata = _workflow_model.NodeMetadata(name, metadata.timeout, _literal_models.RetryStrategy(0))
    hive_task = SdkHiveJob(hive_job, metadata)
    return _nodes.SdkNode(
        id=node_id,
        upstream_nodes=[],
        bindings=[],
        metadata=node_metadata,
        sdk_task=hive_task,
    )
def test_task_metadata():
    """Check TaskMetadata field access and that the object equals its IDL round-trip."""
    runtime = task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python")
    metadata = task.TaskMetadata(
        True,
        runtime,
        timedelta(days=1),
        literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    assert metadata.discoverable is True
    assert metadata.retries.retries == 3
    assert metadata.timeout == timedelta(days=1)

    assert metadata.runtime.flavor == "python"
    assert metadata.runtime.type == task.RuntimeMetadata.RuntimeType.FLYTE_SDK
    assert metadata.runtime.version == "1.0.0"

    assert metadata.deprecated_error_message == "This is deprecated!"
    assert metadata.discovery_version == "0.1.1b0"

    restored = task.TaskMetadata.from_flyte_idl(metadata.to_flyte_idl())
    assert metadata == restored
def test_sdk_node_from_lp():
    """Build an SdkNode around a launch plan and check its id, inputs, outputs, metadata and links."""

    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=a)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    lp = test_workflow.create_launch_plan()

    # Bind the launch plan's single input "a" to the literal 3.
    input_binding = _literals.Binding(
        "a",
        _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3),
    )
    node_metadata = _core_workflow_models.NodeMetadata(
        "abc",
        _datetime.timedelta(minutes=15),
        _literals.RetryStrategy(3),
    )
    n1 = _nodes.SdkNode("n1", [], [input_binding], node_metadata, sdk_launch_plan=lp)

    assert n1.id == "n1"

    # Inputs
    assert len(n1.inputs) == 1
    first_input = n1.inputs[0]
    assert first_input.var == "a"
    assert first_input.binding.scalar.primitive.integer == 3

    # Outputs
    assert len(n1.outputs) == 1
    assert "b" in n1.outputs
    assert n1.outputs["b"].node_id == "n1"
    assert n1.outputs["b"].var == "b"
    assert n1.outputs["b"].sdk_node == n1
    assert n1.outputs["b"].sdk_type == _types.Types.Integer

    # Metadata
    assert n1.metadata.name == "abc"
    assert n1.metadata.retries.retries == 3

    # No upstream links or aliases were configured.
    assert len(n1.upstream_nodes) == 0
    assert len(n1.upstream_node_ids) == 0
    assert len(n1.output_aliases) == 0
def __call__(self, *args, **input_map):
    """
    Create a node that runs this workflow as a sub-workflow of another workflow.

    :param list[T] args: Must be empty; only keyword inputs are supported.
    :param dict[Text,T] input_map: Map of input name to the value or reference to bind.
    :raises FlyteAssertion: If any positional argument is supplied.
    :return: The newly created node, with no id assigned yet.
    """
    if args:
        raise _user_exceptions.FlyteAssertion(
            "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only. We "
            "detected {} positional args.".format(len(args))
        )

    bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

    placeholder_metadata = _workflow_models.NodeMetadata(
        "placeholder",
        _datetime.timedelta(),
        _literal_models.RetryStrategy(0),
    )
    return _nodes.SdkNode(
        id=None,
        metadata=placeholder_metadata,
        upstream_nodes=upstream_nodes,
        bindings=sorted(bindings, key=lambda b: b.var),
        sdk_workflow=self,
    )
def test_future_task_document(task):
    """Check that a DynamicJobSpec with one task node keeps the same text form across an IDL round-trip."""
    node_metadata = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), _literals.RetryStrategy(0))
    node = _workflow.Node(
        id="id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )
    # The node must serialize on its own before it is embedded in the spec.
    node.to_flyte_idl()

    spec = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[node],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[],
    )

    serialized = text_format.MessageToString(spec.to_flyte_idl())
    restored = _dynamic_job.DynamicJobSpec.from_flyte_idl(spec.to_flyte_idl())
    assert serialized == text_format.MessageToString(restored.to_flyte_idl())
def _op_to_task(self, dag_id, image, op, node_map):
    """
    Generate a Flyte task, and a node invoking it, from an Airflow operator.

    The container command is the Airflow task-instance command line for ``op``, and the operator's
    declared resources (when present) become the container's resource requests.

    :param dag_id: Forwarded to the task-instance command line as ``pickle_id``.
    :param image: Container image the generated task runs in.
    :param airflow.models.BaseOperator op: Operator to convert.
    :param dict(Text, SdkNode) node_map: Currently unused.
    :rtype: Tuple(base_tasks.SdkTask, SdkNode)
    """
    # NOTE: operator params and outputs are not translated into interface variables or bindings,
    # so the generated task has an empty interface and its node is created with no inputs.
    interface_inputs = {}
    interface_outputs = {}
    input_mappings = {}

    requests = []
    if op.resources:
        resource_names = task_model.Resources.ResourceName
        requests = [
            task_model.Resources.ResourceEntry(resource_names.Cpu, op.resources.cpus),
            task_model.Resources.ResourceEntry(resource_names.Memory, op.resources.ram),
            task_model.Resources.ResourceEntry(resource_names.Gpu, op.resources.gpus),
            task_model.Resources.ResourceEntry(resource_names.Storage, op.resources.disk),
        ]

    # The task instance is built only to render the command line that runs this operator.
    task_instance = TaskInstance(op, datetime.datetime.now())
    command = task_instance.command_as_list(
        local=True,
        mark_success=False,
        ignore_all_deps=True,
        ignore_depends_on_past=True,
        ignore_task_deps=True,
        ignore_ti_state=True,
        pool=task_instance.pool,
        pickle_id=dag_id,
        cfg_path=None,
    )

    metadata = task_model.TaskMetadata(
        False,
        task_model.RuntimeMetadata(
            type=task_model.RuntimeMetadata.RuntimeType.Other,
            version=airflow.version.version,
            flavor='airflow',
        ),
        datetime.timedelta(seconds=0),
        literals_model.RetryStrategy(0),
        '1',
        None,
    )
    container = task_model.Container(
        image=image,
        command=command,
        args=[],
        resources=task_model.Resources(limits=[], requests=requests),
        env={},
        config={},
    )
    task = base_tasks.SdkTask(
        op.task_id,
        SingleStepTask,
        "airflow_op",
        metadata,
        interface_common.TypedInterface(inputs=interface_inputs, outputs=interface_outputs),
        custom=None,
        container=container,
    )

    # The node reuses the operator's task id as its node id.
    return task, task(**input_mappings).assign_id_and_return(op.task_id)
def __init__(
    self,
    notebook_path,
    inputs,
    outputs,
    task_type,
    discovery_version,
    retries,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    discoverable,
    timeout,
    environment,
    custom,
):
    """
    Create a task backed by a notebook file.

    A relative ``notebook_path`` is resolved against the directory of the module reported by
    ``_find_instance_module()``. ``inputs`` and ``outputs``, when not None, are called with this task
    to populate its interface, and an ``output_notebook`` Blob output is always added afterwards.

    :param notebook_path: Absolute path, or a path relative to the defining module's directory.
    :param inputs: Optional callable applied to this task to add inputs.
    :param outputs: Optional callable applied to this task to add outputs.
    :param task_type: Task type passed to the base task.
    :param discovery_version: Passed to the task metadata.
    :param retries: Number of retries, wrapped in a RetryStrategy.
    :param deprecated: Passed to the task metadata.
    :param discoverable: Passed to the task metadata.
    :param timeout: Passed to the task metadata.
    :param environment: Passed to the container definition.
    :param custom: Custom payload passed to the base task.

    The ``*_request`` and ``*_limit`` parameters are forwarded unchanged to the container definition.
    """
    if not _os.path.isabs(notebook_path):
        # Find absolute path for the notebook.
        defining_module = _importlib.import_module(_find_instance_module())
        module_dir = _os.path.dirname(defining_module.__file__)
        notebook_path = _os.path.normpath(_os.path.join(module_dir, notebook_path))
    self._notebook_path = notebook_path

    runtime = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        "notebook",
    )
    # NOTE(review): the fifth positional argument is hard-coded to False; presumably this is the
    # "interruptible" flag - confirm against TaskMetadata's signature.
    metadata = _task_models.TaskMetadata(
        discoverable,
        runtime,
        timeout,
        _literal_models.RetryStrategy(retries),
        False,
        discovery_version,
        deprecated,
    )
    container = self._get_container_definition(
        storage_request=storage_request,
        cpu_request=cpu_request,
        gpu_request=gpu_request,
        memory_request=memory_request,
        storage_limit=storage_limit,
        cpu_limit=cpu_limit,
        gpu_limit=gpu_limit,
        memory_limit=memory_limit,
        environment=environment,
    )
    super(SdkNotebookTask, self).__init__(
        task_type,
        metadata,
        _interface2.TypedInterface({}, {}),
        custom,
        container=container,
    )

    # Add Inputs
    if inputs is not None:
        inputs(self)

    # Add outputs
    if outputs is not None:
        outputs(self)

    # Add a Notebook output as a Blob.
    notebook_output = _interface.Variable(_Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)
    self.interface.outputs.update(output_notebook=notebook_output)
def __init__(
    self,
    inputs: Dict[str, FlyteSdkType],
    image: str,
    outputs: Dict[str, FlyteSdkType] = None,
    input_data_dir: str = None,
    output_data_dir: str = None,
    metadata_format: int = METADATA_FORMAT_JSON,
    io_strategy: _task_models.IOStrategy = None,
    command: List[str] = None,
    args: List[str] = None,
    storage_request: str = None,
    cpu_request: str = None,
    gpu_request: str = None,
    memory_request: str = None,
    storage_limit: str = None,
    cpu_limit: str = None,
    gpu_limit: str = None,
    memory_limit: str = None,
    environment: Dict[str, str] = None,
    interruptible: bool = False,
    discoverable: bool = False,
    discovery_version: str = None,
    retries: int = 1,
    timeout: _datetime.timedelta = None,
):
    """
    Create a task that runs an arbitrary container image with data loading enabled.

    :param inputs: Input name to SDK type; converted into the task interface.
    :param image: Container image to run.
    :param outputs: Output name to SDK type; converted into the task interface.
    :param input_data_dir: This is the directory where data will be downloaded to
    :param output_data_dir: This is the directory where data will be uploaded from
    :param metadata_format: Format in which the metadata will be available for the script
    :param io_strategy: Passed to the data loading config.
    :param command: Container command.
    :param args: Container arguments.
    :param storage_request:
    :param cpu_request:
    :param gpu_request:
    :param memory_request:
    :param storage_limit:
    :param cpu_limit:
    :param gpu_limit:
    :param memory_limit:
    :param environment:
    :param interruptible:
    :param discoverable:
    :param discovery_version:
    :param retries: Number of retries, wrapped in a RetryStrategy.
    :param timeout: Task timeout; a falsy value is replaced by a zero-second duration.
    """
    # Kept on the instance and also handed to the container definition below.
    self._data_loading_config = _task_models.DataLoadingConfig(
        input_path=input_data_dir,
        output_path=output_data_dir,
        format=metadata_format,
        enabled=True,
        io_strategy=io_strategy,
    )

    # This needs to have the proper version reflected in it
    runtime = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        flytekit.__version__,
        "python",
    )
    metadata = _task_models.TaskMetadata(
        discoverable,
        runtime,
        timeout or _datetime.timedelta(seconds=0),
        _literals.RetryStrategy(retries),
        interruptible,
        discovery_version,
        None,
    )

    # The interface is defined using the inputs and outputs
    typed_interface = _interface.TypedInterface(
        inputs=types_to_variable(inputs),
        outputs=types_to_variable(outputs),
    )

    container = _get_container_definition(
        image=image,
        args=args,
        command=command,
        data_loading_config=self._data_loading_config,
        storage_request=storage_request,
        cpu_request=cpu_request,
        gpu_request=gpu_request,
        memory_request=memory_request,
        storage_limit=storage_limit,
        cpu_limit=cpu_limit,
        gpu_limit=gpu_limit,
        memory_limit=memory_limit,
        environment=environment,
    )

    # This sets the base SDKTask with container etc
    super(SdkRawContainerTask, self).__init__(
        _constants.SdkTaskType.RAW_CONTAINER_TASK,
        metadata,
        typed_interface,
        None,
        container=container,
    )
def __init__(
    self,
    max_number_of_training_jobs: int,
    max_parallel_training_jobs: int,
    training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
    tunable_parameters: typing.List[str] = None,
):
    """
    Create a SageMaker hyperparameter tuning job task around an existing training job.

    :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
    :param int max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job in parallel
    :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference
        to the training job definition
    :param int retries: Number of retries to attempt
    :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
    :param str cache_version: String describing the caching version for task discovery purposes
    :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
        algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
        parameters. E.g. refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for
        the list of hyperparameters for the Image Classification built-in algorithm. If you are passing a
        custom training job, the list of tunable parameters must be a strict subset of the list of inputs
        defined on that job. Refer to
        https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html for the
        list of supported hyperparameter types.
    """
    # Use the training job model as a measure of type checking
    hpo_job = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=max_number_of_training_jobs,
        max_parallel_training_jobs=max_parallel_training_jobs,
        training_job=training_job.training_job_model,
    ).to_flyte_idl()

    # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
    # the underlying training job
    # TODO: Discuss whether this is a viable interface or contract
    timeout = _datetime.timedelta(seconds=0)

    # Start from the training job's inputs, then add the tuning config and one ParameterRange
    # input per tunable parameter. Later entries replace earlier ones with the same name.
    inputs = {}
    inputs.update(training_job.interface.inputs)
    inputs["hyperparameter_tuning_job_config"] = _interface_model.Variable(
        HyperparameterTuningJobConfig.to_flyte_literal_type(),
        "",
    )
    if tunable_parameters:
        for param in tunable_parameters:
            inputs[param] = _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "")

    metadata = _task_models.TaskMetadata(
        runtime=_task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor="sagemaker",
        ),
        discoverable=cacheable,
        timeout=timeout,
        retries=_literal_models.RetryStrategy(retries=retries),
        interruptible=False,
        discovery_version=cache_version,
        deprecated_error_message="",
    )

    # The only output is a single-blob "model".
    model_output = _interface_model.Variable(
        type=_types_models.LiteralType(
            blob=_core_types.BlobType(
                format="",
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        ),
        description="",
    )

    super().__init__(
        type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
        metadata=metadata,
        interface=_interface.TypedInterface(inputs=inputs, outputs={"model": model_output}),
        custom=MessageToDict(hpo_job),
    )
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    This method is used to generate a node with bindings. This is not used in the execution path.

    Every input declared on ``interface`` must be supplied in ``kwargs`` and no other keys may be
    present. The new node is registered on ``ctx.compilation_state``.

    :param ctx: Context whose compilation state receives the new node.
    :param entity: Entity the node wraps; its name and module are used for the node metadata.
    :param interface: Interface describing the entity's inputs and outputs.
    :param timeout: Node timeout; a falsy value becomes a zero duration.
    :param retry_strategy: Node retry strategy; a falsy value becomes ``RetryStrategy(0)``.
    :param kwargs: Input name to value (or Promise) to bind.
    :raises FlyteAssertion: If not compiling, if an input is missing, or if extra inputs are given.
    :raises AssertionError: If any input value is a tuple.
    :return: A VoidPromise when the entity has no outputs, otherwise the result of
        ``create_task_output`` over one Promise per output.
    """
    if ctx.compilation_state is None:
        raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

    typed_interface = flyte_interface.transform_interface_to_typed_interface(interface)

    used_inputs = set()
    bindings = []
    for k in sorted(interface.inputs):
        var = typed_interface.inputs[k]
        if k not in kwargs:
            raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
        v = kwargs[k]
        # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
        # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
        # into the function.
        if isinstance(v, tuple):
            raise AssertionError(
                f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}."
                f" Check if the predecessor function returning more than one value?"
            )
        binding = binding_from_python_std(
            ctx,
            var_name=k,
            expected_literal_type=var.type,
            t_value=v,
            t_value_type=interface.inputs[k],
        )
        bindings.append(binding)
        used_inputs.add(k)

    # Every used input is a key of kwargs, so whatever remains was not declared on the interface.
    extra_inputs = used_inputs ^ set(kwargs.keys())
    if extra_inputs:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
        )

    # Detect upstream nodes
    # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
    upstream_nodes = list(
        {
            input_val.ref.node
            for input_val in kwargs.values()
            if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        }
    )

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    non_sdk_node = Node(
        # TODO: Better naming, probably a derivative of the function name.
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    ctx.compilation_state.add_node(non_sdk_node)

    if not typed_interface.outputs:
        return VoidPromise(entity.name)

    # Create a node output object for each output, they should all point to this node of course.
    # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
    #  is currently just a static str
    node_outputs = [
        Promise(output_name, NodeOutput(node=non_sdk_node, var=output_name))
        for output_name, _ in typed_interface.outputs.items()
    ]

    # Don't print this, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break
    return create_task_output(node_outputs, interface)
def __init__(
    self,
    region,
    role_arn,
    resource_config,
    algorithm_specification=None,
    stopping_condition=None,
    vpc_config=None,
    enable_spot_training=False,
    interruptible=False,
    retries=0,
    cacheable=False,
    cache_version="",
):
    """
    Create a SageMaker XGBoost hyperparameter optimization task.

    :param Text region: The region in which to run the SageMaker job.
    :param Text role_arn: The ARN of the role to run in the SageMaker job.
    :param dict[Text,T] resource_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
    :param dict[Text,T] algorithm_specification: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
        The given dict is copied, not modified. ``TrainingImage`` falls back to a default XGBoost image when
        missing or falsy, and ``TrainingInputMode`` is always set to ``"File"``.
    :param dict[Text,T] stopping_condition: https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
    :param dict[Text,T] vpc_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
    :param bool enable_spot_training: https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
        Accepted but not used by this constructor.
    :param bool interruptible: Passed through to the task metadata.
    :param int retries: Number of time to retry.
    :param bool cacheable: Whether or not to use Flyte's caching system.
    :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
    """
    # Work on a shallow copy so the defaults written below never leak into the caller's dict.
    algorithm_specification = dict(algorithm_specification or {})
    algorithm_specification["TrainingImage"] = (
        algorithm_specification.get("TrainingImage") or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1"
    )
    algorithm_specification["TrainingInputMode"] = "File"

    job_config = ParseDict(
        {
            "Region": region,
            "ResourceConfig": resource_config,
            "StoppingCondition": stopping_condition,
            "VpcConfig": vpc_config,
            "AlgorithmSpecification": algorithm_specification,
            "RoleArn": role_arn,
        },
        sagemaker_pb2.SagemakerHPOJob(),
    )

    # NOTE(review): this prints the job configuration on every construction; it looks like leftover
    # debugging output - confirm whether it should go through logging instead.
    print(MessageToDict(job_config))

    # TODO: Optionally, pull timeout behavior from stopping condition and pass to Flyte task def.
    timeout = _datetime.timedelta(seconds=0)

    # TODO: The FlyteKit type engine is extensible so we can create a SagemakerInput type with custom
    # TODO:     parsing/casting logic. For now, we will use the Generic type since there is a little that needs
    # TODO:     to be done on Flyte side to unlock this cleanly.
    # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following a
    # TODO:     refactor.
    # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the task,
    # TODO:     but might not affect the outputs (i.e. Running on a bigger machine). These are currently static for
    # TODO:     a given definition of a task, but will be more dynamic in the future. Also, it is possible to
    # TODO:     make it dynamic by using our @dynamic_task.
    # TODO: You might want to inherit the role ARN from the execution at runtime.
    super(SagemakerXgBoostOptimizer, self).__init__(
        type=_TASK_TYPE,
        metadata=_task_models.TaskMetadata(
            discoverable=cacheable,
            runtime=_task_models.RuntimeMetadata(0, "0.1.0b0", "sagemaker"),
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=interruptible,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface({}, {}),
        custom=MessageToDict(job_config),
    )

    # TODO: Add more inputs that we expect to change the outputs of the task.
    # TODO: We can add outputs too!
    # We use helper methods for adding to interface, thus overriding the one set above. This will be simplified post
    # refactor.
    self.add_inputs(
        {
            "static_hyperparameters": _interface_model.Variable(
                _sdk_types.Types.Generic.to_flyte_literal_type(), ""
            ),
            "train": _interface_model.Variable(_sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
            "validation": _interface_model.Variable(_sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
        }
    )
    self.add_outputs({"model": _interface_model.Variable(_sdk_types.Types.Blob.to_flyte_literal_type(), "")})
def __init__(
    self,
    task_function,
    task_type,
    discovery_version,
    retries,
    interruptible,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    discoverable,
    timeout,
    environment,
    custom,
):
    """
    Wrap a user function as a runnable SDK task with the given metadata and container resources.

    :param task_function: Function container user code. This will be executed via the SDK's engine.
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Specify whether task is interruptible
    :param Text deprecated:
    :param Text storage_request:
    :param Text cpu_request:
    :param Text gpu_request:
    :param Text memory_request:
    :param Text storage_limit:
    :param Text cpu_limit:
    :param Text gpu_limit:
    :param Text memory_limit:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param dict[Text, Text] environment:
    :param dict[Text, T] custom:
    """
    # Circular dependency
    from flytekit import __version__

    self._task_function = task_function

    runtime = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        "python",
    )
    metadata = _task_models.TaskMetadata(
        discoverable,
        runtime,
        timeout,
        _literal_models.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )
    container = self._get_container_definition(
        storage_request=storage_request,
        cpu_request=cpu_request,
        gpu_request=gpu_request,
        memory_request=memory_request,
        storage_limit=storage_limit,
        cpu_limit=cpu_limit,
        gpu_limit=gpu_limit,
        memory_limit=memory_limit,
        environment=environment,
    )

    super(SdkRunnableTask, self).__init__(
        task_type,
        metadata,
        # TODO: If we end up using SdkRunnableTask for the new code, make sure this is set correctly.
        _interface.TypedInterface({}, {}),
        custom,
        container=container,
    )

    # Name the task after the module and function that define it.
    self.id._name = "{}.{}".format(self.task_module, self.task_function_name)
    self._has_fast_registered = False

    # TODO: Remove this in the future, I don't think we'll be using this.
    self._task_style = SdkRunnableTaskStyle.V0
LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES] LIST_OF_RESOURCES = [ task.Resources(request, limit) for request, limit in product( LIST_OF_RESOURCE_ENTRY_LISTS, LIST_OF_RESOURCE_ENTRY_LISTS) ] LIST_OF_RUNTIME_METADATA = [ task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python"), task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0", "golang") ] LIST_OF_RETRY_POLICIES = [ literals.RetryStrategy(retries=i) for i in [0, 1, 3, 100] ] LIST_OF_INTERRUPTIBLE = [None, True, False] LIST_OF_TASK_METADATA = [ task.TaskMetadata(discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated) for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated in product([True, False], LIST_OF_RUNTIME_METADATA, [timedelta(days=i) for i in range(3)], LIST_OF_RETRY_POLICIES, LIST_OF_INTERRUPTIBLE, ["1.0"], ["deprecated"]) ] LIST_OF_TASK_TEMPLATES = [
def __init__(
    self,
    statement,
    output_schema,
    routing_group=None,
    catalog=None,
    schema=None,
    task_inputs=None,
    interruptible=False,
    discoverable=False,
    discovery_version=None,
    retries=1,
    timeout=None,
    deprecated=None,
):
    """
    Create a task that runs a Presto query.

    :param Text statement: Presto query specification
    :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
    :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
    :param Text catalog: The catalog to set for the given Presto query
    :param Text schema: The schema to set for the given Presto query
    :param task_inputs: Optional callable that is applied to this task to add user-provided inputs.
        When None, no additional inputs are added.
    :param bool interruptible: Passed through to the task metadata.
    :param bool discoverable:
    :param Text discovery_version: String describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param datetime.timedelta timeout: A falsy value is replaced by a zero-second duration.
    :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will
        receive deprecation warnings.
    """
    # Set as class fields which are used down below to configure implicit
    # parameters
    self._routing_group = routing_group or ""
    self._catalog = catalog or ""
    self._schema = schema or ""

    metadata = _task_model.TaskMetadata(
        discoverable,
        # This needs to have the proper version reflected in it
        _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
        timeout or _datetime.timedelta(seconds=0),
        _literals.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )

    presto_query = _presto_models.PrestoQuery(
        routing_group=routing_group or "",
        catalog=catalog or "",
        schema=schema or "",
        statement=statement,
    )

    # Here we set the routing_group, catalog, and schema as implicit
    # parameters for caching purposes
    i = _interface.TypedInterface(
        {
            "__implicit_routing_group": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The routing group set as an implicit input",
            ),
            "__implicit_catalog": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The catalog set as an implicit input",
            ),
            "__implicit_schema": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The schema set as an implicit input",
            ),
        },
        {
            # Set the schema for the Presto query as an output
            "results": _interface_model.Variable(
                type=_types.LiteralType(schema=output_schema.schema_type),
                description="The schema for the Presto query",
            )
        },
    )

    super(SdkPrestoTask, self).__init__(
        _constants.SdkTaskType.PRESTO_TASK,
        metadata,
        i,
        _MessageToDict(presto_query.to_flyte_idl()),
    )

    # Set user provided inputs. The parameter defaults to None, so it must be guarded: calling it
    # unconditionally raised TypeError whenever the caller omitted task_inputs.
    if task_inputs is not None:
        task_inputs(self)
def _get_sample_node_metadata():
    """Return a fixture NodeMetadata named "node1" with a 10 second timeout and a zero-retry strategy."""
    ten_seconds = timedelta(seconds=10)
    no_retries = _literals.RetryStrategy(0)
    return _workflow.NodeMetadata(
        name="node1",
        timeout=ten_seconds,
        retries=no_retries,
    )
def __init__(
    self,
    max_number_of_training_jobs: int,
    max_parallel_training_jobs: int,
    training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
):
    """
    Wrap an existing training job task in a hyperparameter tuning job task.

    :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
    :param max_parallel_training_jobs: The maximum number of training jobs that can launched by this
        hyperparameter tuning job in parallel
    :param training_job: The reference to the training job definition
    :param retries: Number of retries to attempt
    :param cacheable: The flag to set if the user wants the output of the task execution to be cached
    :param cache_version: String describing the caching version for task discovery purposes
    """
    # Use the training job model as a measure of type checking
    tuning_job_model = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=max_number_of_training_jobs,
        max_parallel_training_jobs=max_parallel_training_jobs,
        training_job=training_job.training_job_model,
    )
    hpo_job = tuning_job_model.to_flyte_idl()

    # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
    #   the underlying training job
    # TODO: Discuss whether this is a viable interface or contract
    timeout = _datetime.timedelta(seconds=0)

    # The tuning config input comes first; the wrapped training job's own
    # inputs are merged on top of it.
    config_literal_type = _sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type()
    inputs = {
        "hyperparameter_tuning_job_config": _interface_model.Variable(config_literal_type, ""),
    }
    inputs.update(training_job.interface.inputs)

    runtime_metadata = _task_models.RuntimeMetadata(
        type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        version=__version__,
        flavor="sagemaker",
    )
    task_metadata = _task_models.TaskMetadata(
        runtime=runtime_metadata,
        discoverable=cacheable,
        timeout=timeout,
        retries=_literal_models.RetryStrategy(retries=retries),
        interruptible=False,
        discovery_version=cache_version,
        deprecated_error_message="",
    )

    # Single output: the resulting model as a single-part blob with no format.
    model_blob_type = _core_types.BlobType(
        format="",
        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
    )
    model_output = _interface_model.Variable(
        type=_types_models.LiteralType(blob=model_blob_type),
        description="",
    )
    task_interface = _interface.TypedInterface(
        inputs=inputs,
        outputs={"model": model_output},
    )

    super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
        type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
        metadata=task_metadata,
        interface=task_interface,
        custom=MessageToDict(hpo_job),
    )
def __init__(
    self,
    task_type,
    discovery_version,
    retries,
    interruptible,
    task_inputs,
    deprecated,
    discoverable,
    timeout,
    jar_file,
    main_class,
    args,
    flink_properties,
    environment,
):
    """
    Create a generic Flink task from a jar, a main class and job arguments.

    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not task is interruptible
    :param task_inputs: optional callable invoked as ``task_inputs(self)``; skipped when None
    :param Text deprecated:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param Text jar_file: fat jar file
    :param Text main_class: Main class to execute for Scala/Java jobs
    :param args: forwarded unchanged to the FlinkJob model
    :param dict[Text,Text] flink_properties:
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    job_model = _task_models.FlinkJob(
        flink_properties=flink_properties,
        jar_file=jar_file,
        main_class=main_class,
        args=args,
    )
    flink_job_idl = job_model.to_flyte_idl()

    runtime_metadata = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        "flink",
    )
    task_metadata = _task_models.TaskMetadata(
        discoverable,
        runtime_metadata,
        timeout,
        _literal_models.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )
    # The task starts with an empty interface; inputs are attached below.
    empty_interface = _interface.TypedInterface({}, {})
    custom_payload = _MessageToDict(flink_job_idl)

    super(SdkGenericFlinkTask, self).__init__(
        task_type,
        task_metadata,
        empty_interface,
        custom_payload,
    )

    # Add Inputs
    if task_inputs is not None:
        task_inputs(self)

    # Container after the Inputs have been updated.
    self._container = self._get_container_definition(environment=environment)
def test_workflow_closure():
    """Round-trip a WorkflowClosure (one task, one node) through its IDL form and check equality."""
    integer_type = _types.LiteralType(_types.SimpleType.INTEGER)
    interface_inputs = {"a": _interface.Variable(integer_type, "description1")}
    interface_outputs = {
        "b": _interface.Variable(integer_type, "description2"),
        "c": _interface.Variable(integer_type, "description3"),
    }
    typed_interface = _interface.TypedInterface(interface_inputs, interface_outputs)

    # Node input: the literal integer 5 bound to 'a'.
    five = _literals.Scalar(primitive=_literals.Primitive(integer=5))
    input_binding = _literals.Binding("a", _literals.BindingData(scalar=five))

    # Workflow outputs: promises on my_node's 'b' and 'c'.
    # NOTE(review): both bindings use the variable name 'b' although the
    # interface declares outputs 'b' and 'c' -- confirm this is intended.
    output_binding_b = _literals.Binding(
        "b", _literals.BindingData(promise=_types.OutputReference("my_node", "b"))
    )
    output_binding_c = _literals.Binding(
        "b", _literals.BindingData(promise=_types.OutputReference("my_node", "c"))
    )

    node_metadata = _workflow.NodeMetadata(
        name="node1",
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )

    # NOTE(review): TaskMetadata receives six positional arguments here --
    # verify against the constructor signature in use.
    runtime = _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python")
    task_metadata = _task.TaskMetadata(
        True,
        runtime,
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    one_cpu = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[one_cpu], limits=[one_cpu])

    task_identifier = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    custom = {"a": 1, "b": {"c": 2, "d": 3}}
    container = _task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {},
        {},
    )
    task = _task.TaskTemplate(
        task_identifier,
        "python",
        task_metadata,
        typed_interface,
        custom,
        container=container,
    )

    node = _workflow.Node(
        id="my_node",
        metadata=node_metadata,
        inputs=[input_binding],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )

    workflow_identifier = _identifier.Identifier(
        _identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"
    )
    template = _workflow.WorkflowTemplate(
        id=workflow_identifier,
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[output_binding_b, output_binding_c],
    )

    closure = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(closure.tasks) == 1

    round_tripped = _workflow_closure.WorkflowClosure.from_flyte_idl(closure.to_flyte_idl())
    assert closure == round_tripped
def test_sdk_node_from_task():
    """
    Build four SdkNodes around the same Python task and check their ids, inputs, outputs, metadata and
    upstream bookkeeping, including dependencies declared afterwards with the ``>>`` and ``<<`` operators.
    """

    # Minimal task with one integer input 'a' and one integer output 'b'; the body is never executed here.
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    # n: no upstream nodes, input 'a' bound to the static value 3.
    n = _nodes.SdkNode(
        "n",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n.id == "n"
    assert len(n.inputs) == 1
    assert n.inputs[0].var == "a"
    assert n.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n.outputs) == 1
    assert "b" in n.outputs
    assert n.outputs["b"].node_id == "n"
    assert n.outputs["b"].var == "b"
    assert n.outputs["b"].sdk_node == n
    assert n.outputs["b"].sdk_type == _types.Types.Integer
    assert n.metadata.name == "abc"
    assert n.metadata.retries.retries == 3
    # interruptible was not passed to NodeMetadata above, so it is expected to read back as None.
    assert n.metadata.interruptible is None
    assert len(n.upstream_nodes) == 0
    assert len(n.upstream_node_ids) == 0
    assert len(n.output_aliases) == 0

    # n2: declared with n as its upstream node, and input 'a' bound to n's output 'b' (a promise, not a literal).
    n2 = _nodes.SdkNode(
        "n2",
        [n],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b),
            )
        ],
        _core_workflow_models.NodeMetadata("abc2", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    assert n2.id == "n2"
    assert len(n2.inputs) == 1
    assert n2.inputs[0].var == "a"
    assert n2.inputs[0].binding.promise.var == "b"
    assert n2.inputs[0].binding.promise.node_id == "n"
    assert len(n2.outputs) == 1
    assert "b" in n2.outputs
    assert n2.outputs["b"].node_id == "n2"
    assert n2.outputs["b"].var == "b"
    assert n2.outputs["b"].sdk_node == n2
    assert n2.outputs["b"].sdk_type == _types.Types.Integer
    assert n2.metadata.name == "abc2"
    assert n2.metadata.retries.retries == 3
    assert "n" in n2.upstream_node_ids
    assert n in n2.upstream_nodes
    assert len(n2.upstream_nodes) == 1
    assert len(n2.upstream_node_ids) == 1
    assert len(n2.output_aliases) == 0

    # Test right shift operator and late binding
    # n3 is created with no upstream nodes; its dependency on n2 is added afterwards via the operators.
    n3 = _nodes.SdkNode(
        "n3",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc3", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )
    # The same n2 -> n3 edge is declared several times, alone and inside chains; the assertions below
    # expect n3 to end up with exactly one upstream node (n2).
    n2 >> n3
    n >> n2 >> n3
    n3 << n2
    n3 << n2 << n

    assert n3.id == "n3"
    assert len(n3.inputs) == 1
    assert n3.inputs[0].var == "a"
    assert n3.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n3.outputs) == 1
    assert "b" in n3.outputs
    assert n3.outputs["b"].node_id == "n3"
    assert n3.outputs["b"].var == "b"
    assert n3.outputs["b"].sdk_node == n3
    assert n3.outputs["b"].sdk_type == _types.Types.Integer
    assert n3.metadata.name == "abc3"
    assert n3.metadata.retries.retries == 3
    assert "n2" in n3.upstream_node_ids
    assert n2 in n3.upstream_nodes
    assert len(n3.upstream_nodes) == 1
    assert len(n3.upstream_node_ids) == 1
    assert len(n3.output_aliases) == 0

    # Test left shift operator and late binding
    n4 = _nodes.SdkNode(
        "n4",
        [],
        [
            _literals.Binding(
                "a",
                _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3),
            )
        ],
        _core_workflow_models.NodeMetadata("abc4", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)),
        sdk_task=testy_test,
        sdk_workflow=None,
        sdk_launch_plan=None,
        sdk_branch=None,
    )

    n4 << n3

    # Test that implicit dependencies don't cause direct dependencies
    # (n2 and n appear further along the chains, yet only n3 is expected as n4's upstream node below).
    n4 << n3 << n2 << n
    n >> n2 >> n3 >> n4

    assert n4.id == "n4"
    assert len(n4.inputs) == 1
    assert n4.inputs[0].var == "a"
    assert n4.inputs[0].binding.scalar.primitive.integer == 3
    assert len(n4.outputs) == 1
    assert "b" in n4.outputs
    assert n4.outputs["b"].node_id == "n4"
    assert n4.outputs["b"].var == "b"
    assert n4.outputs["b"].sdk_node == n4
    assert n4.outputs["b"].sdk_type == _types.Types.Integer
    assert n4.metadata.name == "abc4"
    assert n4.metadata.retries.retries == 3
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 1
    assert len(n4.upstream_node_ids) == 1
    assert len(n4.output_aliases) == 0

    # Add another dependency
    # Declaring n2 directly on n4 is expected to add a second upstream node next to n3.
    n4 << n2
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert "n2" in n4.upstream_node_ids
    assert n2 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 2
    assert len(n4.upstream_node_ids) == 2
def retry_strategy(self) -> _literal_models.RetryStrategy:
    """Wrap this object's ``retries`` value in a new ``RetryStrategy`` model."""
    strategy = _literal_models.RetryStrategy(self.retries)
    return strategy
def __init__(
    self,
    task_type,
    discovery_version,
    retries,
    interruptible,
    task_inputs,
    deprecated,
    discoverable,
    timeout,
    spark_type,
    main_class,
    main_application_file,
    spark_conf,
    hadoop_conf,
    environment,
):
    """
    Create a generic Spark task from an application file and main class.

    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not task is interruptible
    :param task_inputs: optional callable invoked as ``task_inputs(self)``; skipped when None
    :param Text deprecated:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param Text spark_type: Type of Spark Job: Scala/Java
    :param Text main_class: Main class to execute for Scala/Java jobs
    :param Text main_application_file: Main application file
    :param dict[Text,Text] spark_conf:
    :param dict[Text,Text] hadoop_conf:
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    # The executor path is the interpreter running this code.
    job_model = _task_models.SparkJob(
        spark_conf=spark_conf,
        hadoop_conf=hadoop_conf,
        spark_type=spark_type,
        application_file=main_application_file,
        main_class=main_class,
        executor_path=_sys.executable,
    )
    spark_job_idl = job_model.to_flyte_idl()

    runtime_metadata = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        'spark',
    )
    task_metadata = _task_models.TaskMetadata(
        discoverable,
        runtime_metadata,
        timeout,
        _literal_models.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )
    # The task starts with an empty interface; inputs are attached below.
    empty_interface = _interface.TypedInterface({}, {})
    custom_payload = _MessageToDict(spark_job_idl)

    super(SdkGenericSparkTask, self).__init__(
        task_type,
        task_metadata,
        empty_interface,
        custom_payload,
    )

    # Add Inputs
    if task_inputs is not None:
        task_inputs(self)

    # Container after the Inputs have been updated.
    self._container = self._get_container_definition(environment=environment)