def test_literal_types():
    """LiteralType round-trips through flyte IDL for both simple and schema variants."""
    lt = _types.LiteralType(simple=_types.SimpleType.INTEGER)
    assert lt.simple == _types.SimpleType.INTEGER
    assert lt.schema is None
    assert lt.collection_type is None
    assert lt.map_value_type is None
    assert lt == _types.LiteralType.from_flyte_idl(lt.to_flyte_idl())

    # One column of every supported column type.
    col = _types.SchemaType.SchemaColumn
    column_types = [
        col.SchemaColumnType.INTEGER,
        col.SchemaColumnType.FLOAT,
        col.SchemaColumnType.STRING,
        col.SchemaColumnType.DATETIME,
        col.SchemaColumnType.DURATION,
        col.SchemaColumnType.BOOLEAN,
    ]
    schema = _types.SchemaType(
        [col(name, ctype) for name, ctype in zip("abcdef", column_types)])

    lt = _types.LiteralType(schema=schema)
    assert lt.simple is None
    assert lt.schema == schema
    assert lt.collection_type is None
    assert lt.map_value_type is None
    assert lt == _types.LiteralType.from_flyte_idl(lt.to_flyte_idl())
def test_default_python_task():
    """The spark task decorator wires up interface, metadata, and spark config."""
    assert isinstance(default_task, _spark_task.SdkSparkTask)
    assert isinstance(default_task, _sdk_runnable.SdkRunnableTask)

    inputs = default_task.interface.inputs
    outputs = default_task.interface.outputs
    assert inputs["in1"].description == ""
    assert inputs["in1"].type == _type_models.LiteralType(
        simple=_type_models.SimpleType.INTEGER)
    assert outputs["out1"].description == ""
    assert outputs["out1"].type == _type_models.LiteralType(
        simple=_type_models.SimpleType.STRING)

    assert default_task.type == _common_constants.SdkTaskType.SPARK_TASK
    assert default_task.task_function_name == "default_task"
    assert default_task.task_module == __name__

    meta = default_task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ""
    assert meta.discoverable is False
    assert meta.discovery_version == ""
    assert meta.retries.retries == 0

    resources = default_task.container.resources
    assert len(resources.limits) == 0
    assert len(resources.requests) == 0

    assert default_task.custom["sparkConf"]["A"] == "B"
    assert default_task.custom["hadoopConf"]["C"] == "D"
    assert default_task.hadoop_conf["C"] == "D"
    assert default_task.spark_conf["A"] == "B"
    # NOTE(review): last char of the entrypoint path is dropped, presumably so
    # the check matches either a .py or .pyc file — confirm.
    assert _os.path.abspath(_entrypoint.__file__)[:-1] in default_task.custom[
        "mainApplicationFile"]
    assert default_task.custom["executorPath"] == _sys.executable

    idl = default_task.to_flyte_idl()
    assert idl.custom["sparkConf"]["A"] == "B"
    assert idl.custom["hadoopConf"]["C"] == "D"
def test_simple_pytorch_task():
    """The pytorch task decorator wires up interface, metadata, and worker count."""
    assert isinstance(simple_pytorch_task, _pytorch_task.SdkPyTorchTask)
    assert isinstance(simple_pytorch_task, _sdk_runnable.SdkRunnableTask)

    inputs = simple_pytorch_task.interface.inputs
    outputs = simple_pytorch_task.interface.outputs
    assert inputs["in1"].description == ""
    assert inputs["in1"].type == _type_models.LiteralType(
        simple=_type_models.SimpleType.INTEGER)
    assert outputs["out1"].description == ""
    assert outputs["out1"].type == _type_models.LiteralType(
        simple=_type_models.SimpleType.STRING)

    assert simple_pytorch_task.type == _common_constants.SdkTaskType.PYTORCH_TASK
    assert simple_pytorch_task.task_function_name == "simple_pytorch_task"
    assert simple_pytorch_task.task_module == __name__

    meta = simple_pytorch_task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ""
    assert meta.discoverable is False
    assert meta.discovery_version == ""
    assert meta.retries.retries == 0

    resources = simple_pytorch_task.container.resources
    assert len(resources.limits) == 0
    assert len(resources.requests) == 0

    assert simple_pytorch_task.custom["workers"] == 1
    # Should strip out the venv component of the args.
    assert simple_pytorch_task._get_container_definition().args[0] == "pyflyte-execute"

    idl = simple_pytorch_task.to_flyte_idl()
    assert idl.custom["workers"] == 1
def test_simple_hpo_job_task():
    """Verify the HPO job task embeds the training job's interface and spec.

    Fix: the original text ended with a stray, unterminated triple-quote
    (``\"\"\"``) after the final assertion, which is a syntax error; it has been
    removed.
    """
    assert isinstance(simple_xgboost_hpo_job_task,
                      SdkSimpleHyperparameterTuningJobTask)
    assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask)

    inputs = simple_xgboost_hpo_job_task.interface.inputs
    # The multipart-CSV blob type both train/validation inputs must carry.
    multipart_csv = _idl_types.LiteralType(blob=_core_types.BlobType(
        format="csv",
        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
    ))

    # Checking if the input of the underlying SdkTrainingJobTask has been embedded
    assert inputs["train"].description == ""
    assert inputs["train"].type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
    assert inputs["train"].type == multipart_csv

    assert inputs["validation"].description == ""
    assert inputs["validation"].type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
    assert inputs["validation"].type == multipart_csv

    assert inputs["static_hyperparameters"].description == ""
    assert inputs["static_hyperparameters"].type == \
        _sdk_types.Types.Generic.to_flyte_literal_type()

    # Checking if the hpo-specific input is defined
    assert inputs["hyperparameter_tuning_job_config"].description == ""
    assert inputs["hyperparameter_tuning_job_config"].type == \
        _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type()

    outputs = simple_xgboost_hpo_job_task.interface.outputs
    assert outputs["model"].description == ""
    assert outputs["model"].type == _sdk_types.Types.Blob.to_flyte_literal_type()

    assert simple_xgboost_hpo_job_task.type == \
        _common_constants.SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK

    # Checking if the spec of the TrainingJob is embedded into the custom field
    # of this SdkSimpleHyperparameterTuningJobTask
    assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == (
        builtin_algorithm_training_job_task2.to_flyte_idl().custom)

    meta = simple_xgboost_hpo_job_task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.discoverable is True
    assert meta.discovery_version == "1"
    assert meta.retries.retries == 2
    assert meta.deprecated_error_message == ""

    algo_spec = simple_xgboost_hpo_job_task.custom["trainingJob"][
        "algorithmSpecification"]
    assert "metricDefinitions" in algo_spec.keys()
    assert len(algo_spec["metricDefinitions"]) == 1
def test_guessing_containers():
    """guess_python_type maps collection/map literal types to typing generics."""
    bool_lt = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
    list_lt = model_types.LiteralType(collection_type=bool_lt)
    assert TypeEngine.guess_python_type(list_lt) == typing.List[bool]

    duration_lt = model_types.LiteralType(simple=model_types.SimpleType.DURATION)
    map_lt = model_types.LiteralType(map_value_type=duration_lt)
    assert TypeEngine.guess_python_type(map_lt) == typing.Dict[str, timedelta]
def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType:
    """Translate a dataset column's python type into its LiteralType.

    Supported scalars come straight from the lookup table; List/Dict are
    handled recursively. Anything else is rejected.
    """
    if t in self._SUPPORTED_TYPES:
        return self._SUPPORTED_TYPES[t]
    origin = getattr(t, "__origin__", None)
    if origin == list:
        # Univariate list: the element type is the first type argument.
        element = self._get_dataset_column_literal_type(t.__args__[0])
        return type_models.LiteralType(collection_type=element)
    if origin == dict:
        # Map: the value type is the second type argument (the key type is
        # not inspected here — presumably str; verify against callers).
        value = self._get_dataset_column_literal_type(t.__args__[1])
        return type_models.LiteralType(map_value_type=value)
    raise AssertionError(
        f"type {t} is currently not supported by StructuredDataset")
def test_default_python_task():
    """A plain python task gets the expected interface and default metadata."""
    assert isinstance(default_task, _sdk_runnable.SdkRunnableTask)

    inputs = default_task.interface.inputs
    outputs = default_task.interface.outputs
    assert inputs['in1'].description == ''
    assert inputs['in1'].type == \
        _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    assert outputs['out1'].description == ''
    assert outputs['out1'].type == \
        _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK
    assert default_task.task_function_name == 'default_task'
    assert default_task.task_module == __name__

    meta = default_task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ''
    assert meta.discoverable is False
    assert meta.discovery_version == ''
    assert meta.retries.retries == 0
def test_workflow_template():
    """A WorkflowTemplate survives a round trip through its flyte IDL form."""
    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node_metadata = _get_sample_node_metadata()
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)

    typed_interface = _interface.TypedInterface(
        {"a": _interface.Variable(int_type, "description1")},
        {
            "b": _interface.Variable(int_type, "description2"),
            "c": _interface.Variable(int_type, "description3"),
        },
    )
    node = _workflow.Node(
        id="some:node:id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=typed_interface,
        nodes=[node],
        outputs=[],
    )

    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
def test_workflow_template_with_queuing_budget():
    """Round trip of a WorkflowTemplate whose metadata carries a queuing budget."""
    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node_metadata = _get_sample_node_metadata()
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)

    metadata = _workflow.WorkflowMetadata(
        queuing_budget=timedelta(seconds=10))
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")},
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })
    node = _workflow.Node(
        id='some:node:id',
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node)
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=metadata,
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=typed_interface,
        nodes=[node],
        outputs=[])

    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
def to_flyte_literal_type(cls):
    """
    Return the literal type for a single (format-less) blob.

    :rtype: flytekit.models.types.LiteralType
    """
    blob = _core_types.BlobType(
        format="",
        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE)
    return _idl_types.LiteralType(blob=blob)
def test_literal_collections(literal_type):
    """A collection LiteralType keeps only collection_type set and round-trips."""
    collection = _types.LiteralType(collection_type=literal_type)
    assert collection.collection_type == literal_type
    assert collection.simple is None
    assert collection.schema is None
    assert collection.map_value_type is None
    restored = _types.LiteralType.from_flyte_idl(collection.to_flyte_idl())
    assert collection == restored
def test_variable_map():
    """VariableMap survives a round trip through its flyte IDL form."""
    var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf')
    original = interface.VariableMap({'vvv': var})
    restored = interface.VariableMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
def to_flyte_literal_type(cls):
    """
    Return the literal type for this protobuf-backed type: a BINARY simple
    type whose metadata records the message descriptor under PB_FIELD_KEY.

    :rtype: flytekit.models.types.LiteralType
    """
    metadata = {cls.PB_FIELD_KEY: cls.descriptor}
    return _idl_types.LiteralType(
        simple=_idl_types.SimpleType.BINARY, metadata=metadata)
def test_launch_plan_spec():
    """LaunchPlanSpec round-trips with and without a raw-output prefix."""
    wf_id = identifier.Identifier(identifier.ResourceType.TASK, "project",
                                  "domain", "name", "version")
    lp_metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])
    var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    params = interface.ParameterMap({"ppp": interface.Parameter(var=var)})
    fixed_inputs = literals.LiteralMap({
        "a": literals.Literal(
            scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    })
    labels = common.Labels({})
    annotations = common.Annotations({"my": "annotation"})
    auth = common.AuthRole(assumable_iam_role="my:iam:role")
    max_parallelism = 100

    # Both a populated and an empty raw-output prefix must round-trip.
    for raw_output in (common.RawOutputDataConfig("s3://bucket"),
                       common.RawOutputDataConfig("")):
        spec = launch_plan.LaunchPlanSpec(
            wf_id,
            lp_metadata,
            params,
            fixed_inputs,
            labels,
            annotations,
            auth,
            raw_output,
            max_parallelism,
        )
        restored = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert restored == spec
def test_annotated_literal_types():
    """A LiteralType carrying a TypeAnnotation preserves it across IDL round trip."""
    lt = _types.LiteralType(
        simple=_types.SimpleType.INTEGER,
        annotation=TypeAnnotation(annotations={"foo": "bar"}))
    assert lt.simple == _types.SimpleType.INTEGER
    assert lt.schema is None
    assert lt.collection_type is None
    assert lt.map_value_type is None
    assert lt.annotation.annotations == {"foo": "bar"}
    assert lt == _types.LiteralType.from_flyte_idl(lt.to_flyte_idl())
def test_unloadable_proto_from_literal_type():
    """A BINARY literal type naming a nonexistent proto class must raise."""
    bad_type = _type_models.LiteralType(
        simple=_type_models.SimpleType.BINARY,
        metadata={
            _proto.Protobuf.PB_FIELD_KEY:
                "flyteidl.core.errors_pb2.ContainerErrorNoExist"
        },
    )
    with pytest.raises(_user_exceptions.FlyteAssertion):
        _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type(
            bad_type)
def test_parameter_map():
    """ParameterMap survives a round trip through its flyte IDL form."""
    var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf')
    original = interface.ParameterMap({'ppp': interface.Parameter(var=var)})
    restored = interface.ParameterMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
def test_guessing_basic():
    """Each simple literal type maps to exactly one python type (NONE -> None)."""
    expectations = [
        (model_types.SimpleType.BOOLEAN, bool),
        (model_types.SimpleType.INTEGER, int),
        (model_types.SimpleType.STRING, str),
        (model_types.SimpleType.DURATION, timedelta),
        (model_types.SimpleType.DATETIME, datetime.datetime),
        (model_types.SimpleType.FLOAT, float),
        (model_types.SimpleType.NONE, None),
    ]
    for simple, expected in expectations:
        lt = model_types.LiteralType(simple=simple)
        assert TypeEngine.guess_python_type(lt) is expected
def get_literal_type(self, t: Type[T]) -> LiteralType:
    """
    Only univariate Lists are supported in Flyte.

    Resolves the element type via the engine and wraps it as a collection;
    any failure is surfaced as a ValueError.
    """
    try:
        element_type = TypeEngine.to_literal_type(self.get_sub_type(t))
        return _type_models.LiteralType(collection_type=element_type)
    except Exception as e:
        raise ValueError(f"Type of Generic List type is not supported, {e}")
def test_parameter():
    """Parameter keeps its Variable and survives an IDL round trip."""
    var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf')
    param = interface.Parameter(var=var)
    assert param.var == var

    restored = interface.Parameter.from_flyte_idl(param.to_flyte_idl())
    assert param == restored
    assert restored.var == var
def test_task_template__k8s_pod_target():
    """TaskTemplate with a k8s_pod target preserves pod data and config through IDL."""
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(labels={"label": "foo"},
                                        annotations={"anno": "bar"}),
        pod_spec={
            "str": "val",
            "int": 1
        },
    )
    template = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project",
                              "domain", "name", "version"),
        "python",
        task.TaskMetadata(
            False,
            task.RuntimeMetadata(1, "v", "f"),
            timedelta(days=1),
            literal_models.RetryStrategy(5),
            False,
            "1.0",
            "deprecated",
            False,
        ),
        interface_models.TypedInterface(
            # inputs
            {"a": interface_models.Variable(int_type, "description1")},
            # outputs
            {
                "b": interface_models.Variable(int_type, "description2"),
                "c": interface_models.Variable(int_type, "description3"),
            },
        ),
        {
            "a": 1,
            "b": {
                "c": 2,
                "d": 3
            }
        },
        config={"a": "b"},
        k8s_pod=pod,
    )

    assert template.id.resource_type == identifier.ResourceType.TASK
    assert template.id.project == "project"
    assert template.id.domain == "domain"
    assert template.id.name == "name"
    assert template.id.version == "version"
    assert template.type == "python"
    assert template.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert template.k8s_pod.metadata == task.K8sObjectMetadata(
        labels={"label": "foo"}, annotations={"anno": "bar"})
    assert template.k8s_pod.pod_spec == {"str": "val", "int": 1}

    # Compare text renderings so the round trip is checked field-for-field.
    round_tripped = task.TaskTemplate.from_flyte_idl(template.to_flyte_idl())
    assert text_format.MessageToString(
        template.to_flyte_idl()) == text_format.MessageToString(
            round_tripped.to_flyte_idl())
    assert template.config == {"a": "b"}
def test_generic_proto_from_literal_type():
    """A STRUCT literal type naming a real proto class resolves to that class."""
    struct_type = _type_models.LiteralType(
        simple=_type_models.SimpleType.STRUCT,
        metadata={
            _proto.Protobuf.PB_FIELD_KEY:
                "flyteidl.core.errors_pb2.ContainerError"
        },
    )
    engine = _flyte_engine.FlyteDefaultTypeEngine()
    sdk_type = engine.get_sdk_type_from_literal_type(struct_type)
    assert sdk_type.pb_type == _errors_pb2.ContainerError
def get_literal_type(self, t: Type[dict]) -> LiteralType:
    """Convert a python dict type into a Flyte LiteralType.

    A ``Dict[str, X]`` becomes a map literal type whose values have the
    literal type of ``X``. Any other dict (untyped, or with non-``str``
    keys) falls back to the Generic (struct) literal type.

    :raises ValueError: if the value type cannot be converted.
    """
    tp = self.get_dict_types(t)
    if tp and tp[0] == str:
        try:
            sub_type = TypeEngine.to_literal_type(tp[1])
            return _type_models.LiteralType(map_value_type=sub_type)
        except Exception as e:
            # Fix: message previously said "Generic List" — a copy-paste
            # from the list transformer; this is the dict transformer.
            raise ValueError(f"Type of Generic Dict type is not supported, {e}")
    return _primitives.Generic.to_flyte_literal_type()
def test_basic_workflow_promote(mock_task_fetch):
    """Promoting an Admin-side workflow template reproduces the user-defined one."""
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b,
                                           sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c,
                                           sdk_type=_Types.Integer)

    # This section uses the TaskTemplate stored in Admin to promote back to
    # an Sdk Workflow
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        # inputs
        {'a': _interface.Variable(int_type, "description1")},
        # outputs
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Since the promotion of a workflow requires retrieving the task from
    # Admin, we mock the SdkTask to return
    task_template = _task_model.TaskTemplate(
        _identifier.Identifier(
            _identifier.ResourceType.TASK, "project", "domain",
            "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
            "version"),
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container())
    mock_task_fetch.return_value = _task.SdkTask.promote_from_model(
        task_template)

    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(
        get_workflow_template())

    expected = TestPromoteExampleWf.interface
    assert promoted_wf.interface.inputs["wf_input"] == expected.inputs["wf_input"]
    assert promoted_wf.interface.outputs["wf_output_b"] == \
        expected.outputs["wf_output_b"]
    assert promoted_wf.interface.outputs["wf_output_c"] == \
        expected.outputs["wf_output_c"]

    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == \
        TestPromoteExampleWf.nodes[0].inputs[0]
def test_construct_literal_map_from_variable_map():
    """Text input '15' is parsed into an integer literal via the variable map."""
    var = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER),
                   description="some description")
    literal_map = helpers.construct_literal_map_from_variable_map(
        {'inputa': var}, {'inputa': '15'})
    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert literal_map.literals['inputa'].value == expected
def test_construct_literal_map_from_parameter_map():
    """Text input '15' is parsed via the parameter map; missing input raises."""
    var = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER),
                   description="some description")
    pm = ParameterMap(
        parameters={"inputa": Parameter(var=var, required=True)})

    literal_map = helpers.construct_literal_map_from_parameter_map(
        pm, {"inputa": "15"})
    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert literal_map.literals["inputa"].value == expected

    # A required parameter with no corresponding input must fail.
    with pytest.raises(Exception):
        helpers.construct_literal_map_from_parameter_map(pm, {})
def test_lp_closure():
    """LaunchPlanClosure exposes its maps and survives an IDL round trip."""
    var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf')
    params = interface.ParameterMap({'ppp': interface.Parameter(var=var)})
    params.to_flyte_idl()
    variables = interface.VariableMap({'vvv': var})

    closure = launch_plan.LaunchPlanClosure(
        state=launch_plan.LaunchPlanState.ACTIVE,
        expected_inputs=params,
        expected_outputs=variables)
    assert closure.expected_inputs == params
    assert closure.expected_outputs == variables

    restored = launch_plan.LaunchPlanClosure.from_flyte_idl(
        closure.to_flyte_idl())
    assert closure == restored
    assert restored.expected_inputs == params
    assert restored.expected_outputs == variables
def test_old_style_role():
    """A legacy Auth field on the IDL spec is surfaced via auth_role.

    NOTE(review): a kubernetes_service_account ends up in assumable_iam_role
    here — this mirrors the original assertion; confirm it is intentional.
    """
    wf_id = identifier.Identifier(identifier.ResourceType.TASK, "project",
                                  "domain", "name", "version")
    lp_metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])
    var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    params = interface.ParameterMap({"ppp": interface.Parameter(var=var)})
    fixed_inputs = literals.LiteralMap({
        "a": literals.Literal(
            scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    })

    old_role = _launch_plan_idl.Auth(
        kubernetes_service_account="my:service:account")
    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=wf_id.to_flyte_idl(),
        entity_metadata=lp_metadata.to_flyte_idl(),
        default_inputs=params.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=common.Labels({}).to_flyte_idl(),
        annotations=common.Annotations({"my": "annotation"}).to_flyte_idl(),
        raw_output_data_config=common.RawOutputDataConfig(
            "s3://bucket").to_flyte_idl(),
        auth=old_role,
    )

    lp_spec = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)
    assert lp_spec.auth_role.assumable_iam_role == "my:service:account"
def test_structured_dataset():
    """A structured-dataset Scalar keeps its columns and uri through IDL."""
    int_lt = _types.LiteralType(simple=_types.SimpleType.INTEGER)
    column = _types.StructuredDatasetType.DatasetColumn
    columns = [
        column("a", int_lt),
        column("b", _types.LiteralType(simple=_types.SimpleType.STRING)),
        column("c", _types.LiteralType(collection_type=int_lt)),
        column("d", _types.LiteralType(map_value_type=int_lt)),
    ]
    dataset = literals.StructuredDataset(
        uri="s3://bucket",
        metadata=literals.StructuredDatasetMetadata(
            structured_dataset_type=_types.StructuredDatasetType(
                columns=columns, format="parquet")),
    )
    scalar = literals.Scalar(structured_dataset=dataset)
    assert scalar.error is None

    restored = literals.Scalar.from_flyte_idl(scalar.to_flyte_idl())
    assert scalar == restored

    # Both the original and round-tripped scalars carry only the dataset.
    for obj in (scalar, restored):
        assert obj.blob is None
        assert obj.binary is None
        assert obj.schema is None
        assert obj.none_type is None
        assert obj.structured_dataset is not None
        assert obj.value.uri == "s3://bucket"
        assert len(obj.value.metadata.structured_dataset_type.columns) == 4
def to_flyte_literal_type(cls):
    """
    Return the literal type built from this class's schema definition.

    :rtype: flytekit.models.types.LiteralType
    """
    schema = cls.schema_type
    return _idl_types.LiteralType(schema=schema)