def test_typed_interface(literal_type):
    """Build a TypedInterface of ``literal_type`` variables, check it, and check it survives an IDL round trip."""

    def _assert_matches_spec(iface):
        # One input ("a") and two outputs ("b", "c"), each of the parametrized literal type.
        assert iface.inputs["a"].type == literal_type
        assert iface.outputs["b"].type == literal_type
        assert iface.outputs["c"].type == literal_type
        assert iface.inputs["a"].description == "description1"
        assert iface.outputs["b"].description == "description2"
        assert iface.outputs["c"].description == "description3"
        assert len(iface.inputs) == 1
        assert len(iface.outputs) == 2

    original = interface.TypedInterface(
        {"a": interface.Variable(literal_type, "description1")},
        {
            "b": interface.Variable(literal_type, "description2"),
            "c": interface.Variable(literal_type, "description3"),
        },
    )
    _assert_matches_spec(original)

    round_tripped = interface.TypedInterface.from_flyte_idl(original.to_flyte_idl())
    assert original == round_tripped
    _assert_matches_spec(round_tripped)
def test_workflow_template_with_queuing_budget():
    """Round-trip a single-node WorkflowTemplate whose metadata sets a 10 second queuing budget."""
    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node_metadata = _get_sample_node_metadata()
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    workflow_metadata = _workflow.WorkflowMetadata(queuing_budget=timedelta(seconds=10))
    metadata_defaults = _workflow.WorkflowMetadataDefaults()

    inputs = {"a": _interface.Variable(int_type, "description1")}
    outputs = {
        "b": _interface.Variable(int_type, "description2"),
        "c": _interface.Variable(int_type, "description3"),
    }
    typed_interface = _interface.TypedInterface(inputs, outputs)

    node = _workflow.Node(
        id="some:node:id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=workflow_metadata,
        metadata_defaults=metadata_defaults,
        interface=typed_interface,
        nodes=[node],
        outputs=[],
    )

    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
def test_basic_unit_test():
    """Run a Python SdkRunnableTask through ``unit_test`` and check that it increments its input."""

    def add_one(wf_params, value_in, value_out):
        value_out.set(value_in + 1)

    runnable = sdk_runnable.SdkRunnableTask(
        add_one,
        _common_constants.SdkTaskType.PYTHON_TASK,
        "1",
        1,
        # nine unused optional settings
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        False,
        None,
        {},
        None,
    )
    runnable.add_inputs({"value_in": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")})
    runnable.add_outputs({"value_out": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")})

    result = runnable.unit_test(value_in=1)
    assert result["value_out"] == 2
def test_basic_unit_test():
    """
    Run a Python SdkRunnableTask through ``unit_test`` and check that it increments its input.

    Calling the task directly without arguments must raise a FlyteAssertion that names the missing
    INTEGER input ``value_in``.
    """

    def add_one(wf_params, value_in, value_out):
        value_out.set(value_in + 1)

    runnable = sdk_runnable.SdkRunnableTask(
        add_one,
        _common_constants.SdkTaskType.PYTHON_TASK,
        "1",
        1,
        # ten unused optional settings
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        False,
        None,
        {},
        None,
    )
    runnable.add_inputs({"value_in": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")})
    runnable.add_outputs({"value_out": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")})

    result = runnable.unit_test(value_in=1)
    assert result["value_out"] == 2

    with _pytest.raises(_user_exceptions.FlyteAssertion) as excinfo:
        runnable()
    message = str(excinfo.value)
    assert "value_in" in message
    assert "INTEGER" in message
def test_typed_interface(literal_type):
    """Check a TypedInterface's variables before and after serializing it to IDL and back."""
    expected_inputs = {"a": "description1"}
    expected_outputs = {"b": "description2", "c": "description3"}

    typed_interface = interface.TypedInterface(
        {name: interface.Variable(literal_type, desc) for name, desc in expected_inputs.items()},
        {name: interface.Variable(literal_type, desc) for name, desc in expected_outputs.items()},
    )

    def _verify(candidate):
        for name, desc in expected_inputs.items():
            assert candidate.inputs[name].type == literal_type
            assert candidate.inputs[name].description == desc
        for name, desc in expected_outputs.items():
            assert candidate.outputs[name].type == literal_type
            assert candidate.outputs[name].description == desc
        assert len(candidate.inputs) == 1
        assert len(candidate.outputs) == 2

    _verify(typed_interface)

    deserialized = interface.TypedInterface.from_flyte_idl(typed_interface.to_flyte_idl())
    assert typed_interface == deserialized
    _verify(deserialized)
def test_workflow_template():
    """Round-trip a single-node WorkflowTemplate with default metadata through its IDL form."""
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_ref = _workflow.TaskNode(reference_id=_generic_id)
    sample_metadata = _get_sample_node_metadata()

    iface = _interface.TypedInterface(
        {"a": _interface.Variable(int_type, "description1")},
        {
            "b": _interface.Variable(int_type, "description2"),
            "c": _interface.Variable(int_type, "description3"),
        },
    )
    only_node = _workflow.Node(
        id="some:node:id",
        metadata=sample_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_ref,
    )
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=iface,
        nodes=[only_node],
        outputs=[],
    )

    assert _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl()) == template
def test_task_template__k8s_pod_target():
    """Build a TaskTemplate targeting a K8s pod, check its fields, and check its IDL text survives a round trip."""
    int_type = types.LiteralType(types.SimpleType.INTEGER)

    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    typed_interface = interface_models.TypedInterface(
        # inputs
        {"a": interface_models.Variable(int_type, "description1")},
        # outputs
        {
            "b": interface_models.Variable(int_type, "description2"),
            "c": interface_models.Variable(int_type, "description3"),
        },
    )
    custom = {"a": 1, "b": {"c": 2, "d": 3}}
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
        pod_spec={"str": "val", "int": 1},
    )

    obj = task.TaskTemplate(task_id, "python", metadata, typed_interface, custom, config={"a": "b"}, k8s_pod=pod)

    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.k8s_pod.metadata == task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"})
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}

    as_text = text_format.MessageToString(obj.to_flyte_idl())
    reparsed_text = text_format.MessageToString(task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl())
    assert as_text == reparsed_text
    assert obj.config == {"a": "b"}
def test_basic_workflow_promote(mock_task_fetch):
    """
    Promote a workflow template back into an SdkWorkflow and compare it with a locally defined workflow.

    The test defines a one-task workflow in code, then promotes the template returned by
    ``get_workflow_template()`` and checks that the inputs, outputs and first node binding match.

    :param mock_task_fetch: a mock whose ``return_value`` is set to the SdkTask promoted from the sample
        TaskTemplate. Presumably it patches the task fetch performed during workflow promotion; the patch
        decorator is not visible here, so confirm.
    """
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        # Attribute names are compared by key against the promoted workflow below.
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b, sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, sdk_type=_Types.Integer)

    # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        # inputs
        {'a': _interface.Variable(int_type, "description1")},
        # outputs
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Since the promotion of a workflow requires retrieving the task from Admin, we mock the task fetch so that it
    # returns the SdkTask promoted from this template.
    # The task name is the fully qualified name of demo_task_for_promote.
    task_template = _task_model.TaskTemplate(
        _identifier.Identifier(
            _identifier.ResourceType.TASK,
            "project",
            "domain",
            "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
            "version"),
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container())
    sdk_promoted_task = _task.SdkTask.promote_from_model(task_template)
    mock_task_fetch.return_value = sdk_promoted_task
    workflow_template = get_workflow_template()
    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(
        workflow_template)

    # Interface, node count and first-node input binding must match the locally defined workflow.
    assert promoted_wf.interface.inputs[
        "wf_input"] == TestPromoteExampleWf.interface.inputs["wf_input"]
    assert promoted_wf.interface.outputs[
        "wf_output_b"] == TestPromoteExampleWf.interface.outputs["wf_output_b"]
    assert promoted_wf.interface.outputs[
        "wf_output_c"] == TestPromoteExampleWf.interface.outputs["wf_output_c"]
    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[
        0].inputs[0]
def test_variable_map():
    """Round-trip a single-entry VariableMap through its IDL form."""
    boolean_type = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    variable = interface.Variable(boolean_type, "asdf asdf asdf")
    original = interface.VariableMap({"vvv": variable})
    restored = interface.VariableMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
def test_launch_plan_spec():
    """Round-trip LaunchPlanSpec with both a prefixed and an empty raw output data config."""
    identifier_model = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    metadata_model = launch_plan.LaunchPlanMetadata(schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    parameter_map = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})
    fixed_inputs = literals.LiteralMap(
        {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))}
    )
    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})
    auth_role_model = common.AuthRole(assumable_iam_role="my:iam:role")
    max_parallelism = 100

    output_configs = (
        common.RawOutputDataConfig("s3://bucket"),  # with a prefix
        common.RawOutputDataConfig(""),  # without a prefix
    )
    for output_config in output_configs:
        spec = launch_plan.LaunchPlanSpec(
            identifier_model,
            metadata_model,
            parameter_map,
            fixed_inputs,
            labels_model,
            annotations_model,
            auth_role_model,
            output_config,
            max_parallelism,
        )
        restored = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert restored == spec
def named_tuple_to_variable_map(
        cls, t: typing.NamedTuple) -> _interface_models.VariableMap:
    """
    Convert the fields of a ``typing.NamedTuple`` class into a VariableMap.

    Each field becomes a Variable whose type is ``cls.to_literal_type(<field type>)``. Its description is
    the field's zero-based position, rendered as a string.

    :param t: A ``typing.NamedTuple`` subclass (the class, not an instance).
    :return: VariableMap keyed by field name, in field declaration order.
    """
    # ``NamedTuple._field_types`` was deprecated in Python 3.8 and removed in 3.9; ``__annotations__`` holds
    # the same ordered field -> type mapping. Prefer ``_field_types`` when present so older interpreters
    # behave exactly as before.
    field_types = getattr(t, "_field_types", None) or t.__annotations__
    variables = {}
    for idx, (var_name, var_type) in enumerate(field_types.items()):
        literal_type = cls.to_literal_type(var_type)
        variables[var_name] = _interface_models.Variable(
            type=literal_type, description=f"{idx}")
    return _interface_models.VariableMap(variables=variables)
def test_parameter_map():
    """Round-trip a single-entry ParameterMap through its IDL form."""
    parameter = interface.Parameter(
        var=interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    )
    original = interface.ParameterMap({"ppp": parameter})
    assert original == interface.ParameterMap.from_flyte_idl(original.to_flyte_idl())
def test_parameter():
    """Check a Parameter exposes its Variable and that both survive an IDL round trip."""
    variable = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    param = interface.Parameter(var=variable)
    assert param.var == variable

    restored = interface.Parameter.from_flyte_idl(param.to_flyte_idl())
    assert param == restored
    assert restored.var == variable
def test_lp_closure():
    """Round-trip an ACTIVE LaunchPlanClosure and check its expected inputs and outputs are preserved."""
    bool_var = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    parameter_map = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})
    # The result is discarded; this only checks that the map serializes without raising.
    parameter_map.to_flyte_idl()
    variable_map = interface.VariableMap({"vvv": bool_var})

    closure = launch_plan.LaunchPlanClosure(
        state=launch_plan.LaunchPlanState.ACTIVE,
        expected_inputs=parameter_map,
        expected_outputs=variable_map,
    )
    assert closure.expected_inputs == parameter_map
    assert closure.expected_outputs == variable_map

    restored = launch_plan.LaunchPlanClosure.from_flyte_idl(closure.to_flyte_idl())
    assert closure == restored
    assert restored.expected_inputs == parameter_map
    assert restored.expected_outputs == variable_map
def __init__(self, name, sdk_type, help=None, **kwargs):
    """
    Declare a workflow input.

    :param Text name: Name of the input.
    :param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: SDK type that values for this input
        must have.
    :param Text help: Optional help string shown to users; it is also used as the Variable description.
    :param bool required: (keyword-only, optional) Whether the input must be supplied. If neither
        ``required`` nor ``default`` is given, the input is required.
    :param T default: (keyword-only, optional) Value used when the input is not supplied. Supplying a
        default makes the input optional.
    :raises FlyteAssertion: If ``default`` is given together with a truthy ``required``.
    """
    param_default = None
    default = None
    if "default" in kwargs:
        if kwargs.get("required", False):
            # A default value contradicts required=True.
            raise _user_exceptions.FlyteAssertion(
                "Default cannot be set when required is True")
        required = None
        default = kwargs["default"]
        param_default = sdk_type.from_python_std(default)
    elif "required" in kwargs:
        required = kwargs["required"]
        if not required:
            # Optional input without a default: the parameter default is the SDK encoding of None.
            param_default = sdk_type.from_python_std(default)
            required = None
    else:
        # Neither flag was given, so the input is required.
        required = True

    self._sdk_required = required or False
    self._sdk_default = default
    self._help = help
    self._sdk_type = sdk_type
    # Downstream nodes reference this input through the global input node.
    self._promise = _type_models.OutputReference(
        _constants.GLOBAL_INPUT_NODE_ID, name)
    self._name = name

    variable = _interface_models.Variable(
        type=sdk_type.to_flyte_literal_type(), description=help or "")
    super(Input, self).__init__(variable, required=required, default=param_default)
def __init__(self, name, value, sdk_type=None, help=None):
    """
    Declare a workflow output bound to ``value``.

    :param Text name: Name of the output.
    :param T value: The value (typically a node output promise) that this workflow output is bound to.
    :param U sdk_type: If given, ``value`` must cast to this type. It is normally an instance of
        flytekit.common.types.base_sdk_types.FlyteSdkType, but may also be a container form such as
        ``list[FlyteSdkType]``, ``dict[FlyteSdkType, FlyteSdkType]`` or ``(FlyteSdkType, FlyteSdkType, ...)``.
        If omitted, the type is inferred from ``value``.
    :param Text help: Optional description of the output.
    """
    # Use an explicit None check rather than `sdk_type or Output._infer_type(value)`; the original author
    # noted that the `or` form did not work.
    if sdk_type is not None:
        resolved_type = sdk_type
    else:
        resolved_type = Output._infer_type(value)
    resolved_type = _type_helpers.python_std_to_sdk_type(resolved_type)

    self._binding_data = _interface.BindingData.from_python_std(
        resolved_type.to_flyte_literal_type(), value)
    self._var = _interface_models.Variable(
        resolved_type.to_flyte_literal_type(), help or "")
    self._name = name
def test_old_style_role():
    """
    Parse a LaunchPlanSpec IDL message that uses the legacy ``auth`` field and check the resulting auth role.

    NOTE(review): the legacy ``kubernetes_service_account`` value is expected to surface as
    ``assumable_iam_role``. Confirm this mapping in ``LaunchPlanSpec.from_flyte_idl`` is intended.
    """
    identifier_model = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    metadata_model = launch_plan.LaunchPlanMetadata(schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])
    bool_var = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    parameter_map = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})
    fixed_inputs = literals.LiteralMap(
        {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))}
    )
    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})
    raw_output_config = common.RawOutputDataConfig("s3://bucket")

    legacy_auth = _launch_plan_idl.Auth(kubernetes_service_account="my:service:account")
    legacy_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=identifier_model.to_flyte_idl(),
        entity_metadata=metadata_model.to_flyte_idl(),
        default_inputs=parameter_map.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=labels_model.to_flyte_idl(),
        annotations=annotations_model.to_flyte_idl(),
        raw_output_data_config=raw_output_config.to_flyte_idl(),
        auth=legacy_auth,
    )

    parsed = launch_plan.LaunchPlanSpec.from_flyte_idl(legacy_spec)
    assert parsed.auth_role.assumable_iam_role == "my:service:account"
def apply_outputs_wrapper(task):
    """
    Attach the outputs declared in the enclosing decorator's ``kwargs`` to ``task``.

    Each ``kwargs`` value is converted with ``python_std_to_sdk_type`` into an interface Variable with an
    empty description. The resulting mapping is passed to ``task.add_outputs``.

    :param task: The object being decorated; it must be an SdkRunnableTask.
    :return: The same ``task``, so the wrapper can be used as a decorator.
    :raises FlyteTypeException: If ``task`` is not an SdkRunnableTask, e.g. when the task decorator was
        forgotten.
    """
    if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask):
        # Use getattr for both names: arbitrary objects (e.g. ints) have no __module__, and a crash here
        # would hide the intended error message.
        additional_msg = \
            "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format(
                getattr(task, "__module__", "<unknown>"),
                getattr(task, "__name__", "<unknown>"),
            )
        raise _user_exceptions.FlyteTypeException(
            expected_type=_sdk_runnable_tasks.SdkRunnableTask,
            received_type=type(task),
            received_value=task,
            additional_msg=additional_msg)

    # Build a new mapping instead of overwriting ``kwargs`` in place. ``kwargs`` belongs to the enclosing
    # decorator call, so mutating it would make a second application of the same decorator try to convert
    # already-built Variables.
    outputs = {
        k: _interface_model.Variable(
            _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(),
            '')  # TODO: Support descriptions
        for k, v in _six.iteritems(kwargs)
    }
    task.add_outputs(outputs)
    return task
def __init__(
    self,
    max_number_of_training_jobs: int,
    max_parallel_training_jobs: int,
    training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
    tunable_parameters: typing.List[str] = None,
):
    """
    Define a SageMaker hyperparameter tuning job task that wraps ``training_job``.

    :param int max_number_of_training_jobs: The maximum number of training jobs this hyperparameter tuning
        job can launch in total.
    :param int max_parallel_training_jobs: The maximum number of training jobs this hyperparameter tuning
        job can run in parallel.
    :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The
        training job definition to tune.
    :param int retries: Number of retries to attempt.
    :param bool cacheable: Whether the output of the task execution should be cached.
    :param str cache_version: String describing the caching version for task discovery purposes.
    :param typing.List[str] tunable_parameters: Names of the parameters to tune. Each one is added as a
        ParameterRange input of this task. If you are tuning a built-in algorithm, refer to the algorithm's
        documentation for the possible tunable parameters, e.g.
        https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the Image Classification
        built-in algorithm. If you are passing a custom training job, the list of tunable parameters must be
        a strict subset of the inputs defined on that job; this is not validated here. Refer to
        https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html for the
        supported hyperparameter types.
    """
    # Building the model also serves as a type check on the arguments.
    hpo_job_idl = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=max_number_of_training_jobs,
        max_parallel_training_jobs=max_parallel_training_jobs,
        training_job=training_job.training_job_model,
    ).to_flyte_idl()

    # A Flyte-level timeout of 0 lets SageMaker enforce the StoppingCondition of the underlying training job.
    # TODO: Discuss whether this is a viable interface or contract
    no_timeout = _datetime.timedelta(seconds=0)

    # Start from the wrapped job's inputs, then add the tuning config and one ParameterRange per tunable
    # parameter. Later entries replace earlier ones with the same name.
    task_inputs = dict(training_job.interface.inputs)
    task_inputs["hyperparameter_tuning_job_config"] = _interface_model.Variable(
        HyperparameterTuningJobConfig.to_flyte_literal_type(),
        "",
    )
    for param_name in tunable_parameters or ():
        task_inputs[param_name] = _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "")

    runtime = _task_models.RuntimeMetadata(
        type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        version=__version__,
        flavor="sagemaker",
    )
    metadata = _task_models.TaskMetadata(
        runtime=runtime,
        discoverable=cacheable,
        timeout=no_timeout,
        retries=_literal_models.RetryStrategy(retries=retries),
        interruptible=False,
        discovery_version=cache_version,
        deprecated_error_message="",
    )
    model_output = _interface_model.Variable(
        type=_types_models.LiteralType(
            blob=_core_types.BlobType(
                format="",
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        ),
        description="",
    )

    super().__init__(
        type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
        metadata=metadata,
        interface=_interface.TypedInterface(inputs=task_inputs, outputs={"model": model_output}),
        custom=MessageToDict(hpo_job_idl),
    )
def __init__(
    self,
    training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
    algorithm_specification: _training_job_models.AlgorithmSpecification,
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
):
    """
    Define a SageMaker training job task for a built-in algorithm.

    The task takes ``static_hyperparameters`` (a STRUCT) plus ``train`` and ``validation`` multipart
    blobs whose format follows ``algorithm_specification.input_content_type``. It produces a single-blob
    ``model`` output.

    :param training_job_resource_config: The options used to configure the training job's resources.
    :param algorithm_specification: The options used to configure the target algorithm of the training.
    :param retries: Number of retries to attempt.
    :param cacheable: Whether the output of the task execution should be cached.
    :param cache_version: String describing the caching version for task discovery purposes.
    """
    # Building the model also serves as a type check on the arguments.
    self._training_job_model = _training_job_models.TrainingJob(
        algorithm_specification=algorithm_specification,
        training_job_resource_config=training_job_resource_config,
    )

    # A Flyte-level timeout of 0 leaves it to SageMaker to apply the StoppingCondition and end the job gracefully.
    no_timeout = _datetime.timedelta(seconds=0)

    def _multipart_dataset_input():
        # Both data channels share this definition; the blob format tracks the algorithm's input content type.
        return _interface_model.Variable(
            type=_idl_types.LiteralType(
                blob=_core_types.BlobType(
                    format=_content_type_to_blob_format(algorithm_specification.input_content_type),
                    dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                ),
            ),
            description="",
        )

    metadata = _task_models.TaskMetadata(
        runtime=_task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor="sagemaker",
        ),
        discoverable=cacheable,
        timeout=no_timeout,
        retries=_literal_models.RetryStrategy(retries=retries),
        interruptible=False,
        discovery_version=cache_version,
        deprecated_error_message="",
    )
    task_inputs = {
        "static_hyperparameters": _interface_model.Variable(
            type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
            description="",
        ),
        "train": _multipart_dataset_input(),
        "validation": _multipart_dataset_input(),
    }
    task_outputs = {
        "model": _interface_model.Variable(
            type=_idl_types.LiteralType(
                blob=_core_types.BlobType(
                    format="",
                    dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                )
            ),
            description="",
        )
    }

    super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
        type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
        metadata=metadata,
        interface=_interface.TypedInterface(inputs=task_inputs, outputs=task_outputs),
        custom=MessageToDict(self._training_job_model.to_flyte_idl()),
    )
def __init__(
    self,
    region,
    role_arn,
    resource_config,
    algorithm_specification=None,
    stopping_condition=None,
    vpc_config=None,
    enable_spot_training=False,
    interruptible=False,
    retries=0,
    cacheable=False,
    cache_version="",
):
    """
    Define a SageMaker XGBoost hyperparameter optimization task.

    :param Text region: The region in which to run the SageMaker job.
    :param Text role_arn: The ARN of the role to run in the SageMaker job.
    :param dict[Text,T] algorithm_specification:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html
        The caller's dict is not modified. ``TrainingImage`` defaults to the XGBoost image when missing or
        empty, and ``TrainingInputMode`` is always set to "File".
    :param dict[Text,T] resource_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ResourceConfig.html
    :param dict[Text,T] stopping_condition:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_StoppingCondition.html
    :param dict[Text,T] vpc_config: https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html
    :param bool enable_spot_training:
        https://docs.aws.amazon.com/sagemaker/latest/dg/API_HyperParameterTrainingJobDefinition.html
        NOTE(review): this argument is accepted but not currently used when building the job config.
    :param bool interruptible: Passed through to the task metadata.
    :param int retries: Number of times to retry.
    :param bool cacheable: Whether or not to use Flyte's caching system.
    :param Text cache_version: Update this to notify a behavioral change requiring the cache to be invalidated.
    """
    # Copy before applying defaults so the caller's dict (possibly reused across tasks) is not mutated.
    algorithm_specification = dict(algorithm_specification or {})
    algorithm_specification["TrainingImage"] = (
        algorithm_specification.get("TrainingImage") or "825641698319.dkr.ecr.us-east-2.amazonaws.com/xgboost:1")
    algorithm_specification["TrainingInputMode"] = "File"

    job_config = ParseDict(
        {
            "Region": region,
            "ResourceConfig": resource_config,
            "StoppingCondition": stopping_condition,
            "VpcConfig": vpc_config,
            "AlgorithmSpecification": algorithm_specification,
            "RoleArn": role_arn,
        },
        sagemaker_pb2.SagemakerHPOJob(),
    )

    # TODO: Optionally, pull timeout behavior from stopping condition and pass to Flyte task def.
    timeout = _datetime.timedelta(seconds=0)

    # TODO: The FlyteKit type engine is extensible so we can create a SagemakerInput type with custom
    # TODO: parsing/casting logic. For now, we will use the Generic type since there is a little that needs
    # TODO: to be done on Flyte side to unlock this cleanly.
    # TODO: This call to the super-constructor will be less verbose in future versions of Flytekit following a
    # TODO: refactor.
    # TODO: Add more configurations to the custom dict. These are things that are necessary to execute the task,
    # TODO: but might not affect the outputs (i.e. Running on a bigger machine). These are currently static for
    # TODO: a given definition of a task, but will be more dynamic in the future. Also, it is possible to
    # TODO: make it dynamic by using our @dynamic_task.
    # TODO: You might want to inherit the role ARN from the execution at runtime.
    super(SagemakerXgBoostOptimizer, self).__init__(
        type=_TASK_TYPE,
        metadata=_task_models.TaskMetadata(
            discoverable=cacheable,
            runtime=_task_models.RuntimeMetadata(0, "0.1.0b0", "sagemaker"),
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=interruptible,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface({}, {}),
        custom=MessageToDict(job_config),
    )

    # TODO: Add more inputs that we expect to change the outputs of the task.
    # TODO: We can add outputs too!
    # We use helper methods for adding to interface, thus overriding the one set above. This will be simplified
    # post refactor.
    self.add_inputs({
        "static_hyperparameters": _interface_model.Variable(
            _sdk_types.Types.Generic.to_flyte_literal_type(), ""),
        "train": _interface_model.Variable(
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
        "validation": _interface_model.Variable(
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type(), ""),
    })
    self.add_outputs({
        "model": _interface_model.Variable(
            _sdk_types.Types.Blob.to_flyte_literal_type(), "")
    })
def __init__(
    self,
    max_number_of_training_jobs: int,
    max_parallel_training_jobs: int,
    training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask],
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
):
    """
    Define a SageMaker hyperparameter tuning job task that wraps ``training_job``.

    The task's inputs are the tuning job config followed by all of ``training_job``'s inputs. It produces
    a single-blob ``model`` output.

    :param max_number_of_training_jobs: The maximum number of training jobs this hyperparameter tuning job
        can launch in total.
    :param max_parallel_training_jobs: The maximum number of training jobs this hyperparameter tuning job
        can run in parallel.
    :param training_job: The training job definition to tune.
    :param retries: Number of retries to attempt.
    :param cacheable: Whether the output of the task execution should be cached.
    :param cache_version: String describing the caching version for task discovery purposes.
    """
    # Building the model also serves as a type check on the arguments.
    hpo_job_idl = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=max_number_of_training_jobs,
        max_parallel_training_jobs=max_parallel_training_jobs,
        training_job=training_job.training_job_model,
    ).to_flyte_idl()

    # A Flyte-level timeout of 0 lets SageMaker enforce the StoppingCondition of the underlying training job.
    # TODO: Discuss whether this is a viable interface or contract
    no_timeout = _datetime.timedelta(seconds=0)

    config_type = _sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type()
    task_inputs = {"hyperparameter_tuning_job_config": _interface_model.Variable(config_type, "")}
    # The wrapped job's inputs are merged in afterwards, so they win on a name clash.
    task_inputs.update(training_job.interface.inputs)

    metadata = _task_models.TaskMetadata(
        runtime=_task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor="sagemaker",
        ),
        discoverable=cacheable,
        timeout=no_timeout,
        retries=_literal_models.RetryStrategy(retries=retries),
        interruptible=False,
        discovery_version=cache_version,
        deprecated_error_message="",
    )
    single_blob = _types_models.LiteralType(
        blob=_core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
    )
    task_interface = _interface.TypedInterface(
        inputs=task_inputs,
        outputs={"model": _interface_model.Variable(type=single_blob, description="")},
    )

    super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
        type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
        metadata=metadata,
        interface=task_interface,
        custom=MessageToDict(hpo_job_idl),
    )
def transform_type(x: type, description: str = None) -> _interface_models.Variable:
    """
    Wrap the Flyte literal type of Python type ``x`` in an interface Variable.

    :param x: Python type to convert with ``TypeEngine.to_literal_type``.
    :param description: Description stored on the Variable; passed through unchanged and may be None.
    :return: The resulting Variable.
    """
    literal_type = TypeEngine.to_literal_type(x)
    return _interface_models.Variable(type=literal_type, description=description)
def __init__(
    self,
    notebook_path,
    inputs,
    outputs,
    task_type,
    discovery_version,
    retries,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    discoverable,
    timeout,
    environment,
    custom,
):
    """
    Create a task that executes the notebook at ``notebook_path``.

    A relative ``notebook_path`` is resolved against the directory of the module named by
    ``_find_instance_module()``. When ``inputs`` / ``outputs`` are given, they are called with this task
    to register its interface. An ``output_notebook`` Blob output is always added afterwards.
    """
    if not _os.path.isabs(notebook_path):
        defining_module = _importlib.import_module(_find_instance_module())
        module_dir = _os.path.dirname(defining_module.__file__)
        notebook_path = _os.path.normpath(_os.path.join(module_dir, notebook_path))
    self._notebook_path = notebook_path

    runtime = _task_models.RuntimeMetadata(
        _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        __version__,
        "notebook",
    )
    metadata = _task_models.TaskMetadata(
        discoverable,
        runtime,
        timeout,
        _literal_models.RetryStrategy(retries),
        False,
        discovery_version,
        deprecated,
    )
    # The interface starts empty; the inputs/outputs callables fill it in below.
    empty_interface = _interface2.TypedInterface({}, {})
    container = self._get_container_definition(
        storage_request=storage_request,
        cpu_request=cpu_request,
        gpu_request=gpu_request,
        memory_request=memory_request,
        storage_limit=storage_limit,
        cpu_limit=cpu_limit,
        gpu_limit=gpu_limit,
        memory_limit=memory_limit,
        environment=environment,
    )
    super(SdkNotebookTask, self).__init__(task_type, metadata, empty_interface, custom, container=container)

    for register_interface in (inputs, outputs):
        if register_interface is not None:
            register_interface(self)

    # The executed notebook is always exposed as a Blob output.
    self.interface.outputs.update(
        output_notebook=_interface.Variable(_Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)
    )
types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_SCALAR_LITERAL_TYPES ] LIST_OF_NESTED_COLLECTION_LITERAL_TYPES = [ types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_COLLECTION_LITERAL_TYPES ] LIST_OF_ALL_LITERAL_TYPES = \ LIST_OF_SCALAR_LITERAL_TYPES + \ LIST_OF_COLLECTION_LITERAL_TYPES + \ LIST_OF_NESTED_COLLECTION_LITERAL_TYPES LIST_OF_INTERFACES = [ interface.TypedInterface({'a': interface.Variable(t, "description 1")}, {'b': interface.Variable(t, "description 2")}) for t in LIST_OF_ALL_LITERAL_TYPES ] LIST_OF_RESOURCE_ENTRIES = [ task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1"), task.Resources.ResourceEntry(task.Resources.ResourceName.GPU, "1"), task.Resources.ResourceEntry(task.Resources.ResourceName.MEMORY, "1G"), task.Resources.ResourceEntry(task.Resources.ResourceName.STORAGE, "1G") ] LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES] LIST_OF_RESOURCES = [ task.Resources(request, limit) for request, limit in product(
types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_SCALAR_LITERAL_TYPES ] LIST_OF_NESTED_COLLECTION_LITERAL_TYPES = [ types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_COLLECTION_LITERAL_TYPES ] LIST_OF_ALL_LITERAL_TYPES = (LIST_OF_SCALAR_LITERAL_TYPES + LIST_OF_COLLECTION_LITERAL_TYPES + LIST_OF_NESTED_COLLECTION_LITERAL_TYPES) LIST_OF_INTERFACES = [ interface.TypedInterface( {"a": interface.Variable(t, "description 1")}, {"b": interface.Variable(t, "description 2")}, ) for t in LIST_OF_ALL_LITERAL_TYPES ] LIST_OF_RESOURCE_ENTRIES = [ task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1"), task.Resources.ResourceEntry(task.Resources.ResourceName.GPU, "1"), task.Resources.ResourceEntry(task.Resources.ResourceName.MEMORY, "1G"), task.Resources.ResourceEntry(task.Resources.ResourceName.STORAGE, "1G"), ] LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES] LIST_OF_RESOURCES = [ task.Resources(request, limit) for request, limit in product(
def test_variable_type(literal_type):
    """A Variable keeps its type and description and survives an IDL round trip."""
    expected_description = "abc"
    variable = interface.Variable(type=literal_type, description=expected_description)

    assert variable.type == literal_type
    assert variable.description == expected_description

    round_tripped = interface.Variable.from_flyte_idl(variable.to_flyte_idl())
    assert variable == round_tripped
def __init__(
    self,
    statement,
    output_schema,
    routing_group=None,
    catalog=None,
    schema=None,
    task_inputs=None,
    interruptible=False,
    discoverable=False,
    discovery_version=None,
    retries=1,
    timeout=None,
    deprecated=None,
):
    """
    :param Text statement: Presto query specification
    :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
    :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
    :param Text catalog: The catalog to set for the given Presto query
    :param Text schema: The schema to set for the given Presto query
    :param task_inputs: Optional callable that registers user-provided inputs on this task. It is invoked as
        ``task_inputs(self)`` at the end of construction; when None, no user inputs are registered.
    :param bool interruptible: Passed through to the task metadata as its interruptible flag
    :param bool discoverable:
    :param Text discovery_version: String describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param datetime.timedelta timeout: Task timeout; a zero timedelta is recorded when not provided
    :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will
        receive deprecation warnings.
    """
    # Set as class fields which are used down below to configure implicit
    # parameters. None is normalized to "" so these are always strings.
    self._routing_group = routing_group or ""
    self._catalog = catalog or ""
    self._schema = schema or ""

    metadata = _task_model.TaskMetadata(
        discoverable,
        # This needs to have the proper version reflected in it
        _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
        timeout or _datetime.timedelta(seconds=0),
        _literals.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )

    # Reuse the normalized values stored above rather than re-deriving them.
    presto_query = _presto_models.PrestoQuery(
        routing_group=self._routing_group,
        catalog=self._catalog,
        schema=self._schema,
        statement=statement,
    )

    # Here we set the routing_group, catalog, and schema as implicit
    # parameters for caching purposes
    i = _interface.TypedInterface(
        {
            "__implicit_routing_group": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The routing group set as an implicit input",
            ),
            "__implicit_catalog": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The catalog set as an implicit input",
            ),
            "__implicit_schema": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The schema set as an implicit input",
            ),
        },
        {
            # Set the schema for the Presto query as an output
            "results": _interface_model.Variable(
                type=_types.LiteralType(schema=output_schema.schema_type),
                description="The schema for the Presto query",
            )
        },
    )

    super(SdkPrestoTask, self).__init__(
        _constants.SdkTaskType.PRESTO_TASK,
        metadata,
        i,
        _MessageToDict(presto_query.to_flyte_idl()),
    )

    # Set user provided inputs. task_inputs defaults to None, so guard the
    # call; previously omitting it raised "'NoneType' object is not callable".
    if task_inputs is not None:
        task_inputs(self)
def test_workflow_closure():
    """Build a one-task WorkflowClosure and check it survives an IDL round trip.

    The workflow interface declares outputs 'b' and 'c'. The workflow outputs
    bind each of them to the same-named output of the single task node.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")},
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Node input 'a' is bound to the literal integer 5.
    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    # Workflow outputs: 'b' <- my_node.b and 'c' <- my_node.c.
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    # Previously this bound 'b' a second time, leaving output 'c' unbound.
    b2 = _literals.Binding(
        'c',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource],
                                limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    # Each declared output is bound exactly once, in declaration order.
    assert [binding.var for binding in obj2.workflow.outputs] == ['b', 'c']
def test_interface():
    """Exercise LiteralsResolver lookups, type guessing, value reuse and type-hint overrides."""
    ctx = FlyteContextManager.current_context()
    frame_literal_type = TypeEngine.to_literal_type(pd.DataFrame)
    frame = pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})

    annotated_sd_type = Annotated[StructuredDataset, kwtypes(name=str, age=int)]
    df_literal_type = TypeEngine.to_literal_type(annotated_sd_type)
    sd_type = df_literal_type.structured_dataset_type
    assert sd_type is not None
    assert len(sd_type.columns) == 2
    for column, expected_name in zip(sd_type.columns, ("name", "age")):
        assert column.name == expected_name
        assert column.literal_type.simple is not None

    sd = annotated_sd_type(frame)
    sd_literal = TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=frame_literal_type)

    def _str_literal(value):
        return Literal(scalar=Scalar(primitive=Primitive(string_value=value)))

    def _int_literal(value):
        return Literal(scalar=Scalar(primitive=Primitive(integer=value)))

    def _variable(literal_type):
        return interface_models.Variable(type=literal_type, description="")

    literal_map = {
        "my_map": Literal(
            map=LiteralMap(literals={"k1": _str_literal("v1"), "k2": _str_literal("2")}),
        ),
        "my_list": Literal(
            collection=LiteralCollection(literals=[_int_literal(v) for v in (1, 2, 3)]),
        ),
        "val_a": _int_literal(21828),
        "my_df": sd_literal,
    }
    variable_map = {
        "my_map": _variable(TypeEngine.to_literal_type(typing.Dict[str, str])),
        "my_list": _variable(TypeEngine.to_literal_type(typing.List[int])),
        "val_a": _variable(TypeEngine.to_literal_type(int)),
        "my_df": _variable(df_literal_type),
    }

    resolver = LiteralsResolver(literal_map, variable_map=variable_map, ctx=ctx)
    assert resolver._ctx is ctx

    # Unknown keys are rejected by both access paths.
    with pytest.raises(ValueError):
        resolver["not"]  # noqa
    with pytest.raises(ValueError):
        resolver.get_literal("not")

    # Indexing with [] guesses the Python type from the Flyte type.
    assert resolver["my_list"] == [1, 2, 3]
    # get() also guesses from the Flyte type.
    assert resolver.get("my_map") == {"k1": "v1", "k2": "2"}
    # get_literal() returns the stored Literal object itself.
    assert resolver.get_literal("my_df") is sd_literal

    first_df = resolver["my_df"]
    # The type was guessed, so no column information is available.
    assert len(first_df.metadata.structured_dataset_type.columns) == 0
    # A second lookup returns the very same object.
    assert resolver["my_df"] is first_df

    # Supply the annotated type, then drop the stored native value so the next get() converts again.
    resolver.update_type_hints({"my_df": annotated_sd_type})
    del resolver._native_values["my_df"]
    hinted_df = resolver.get("my_df")
    # With the user-specified type the column count is correct.
    assert len(hinted_df.metadata.structured_dataset_type.columns) == 2