def _build_importer_spec(self) -> ImporterSpec:
  """Builds the ImporterSpec for the Importer node held in `self._node`.

  Metadata and the type schema are taken from the artifact found in the
  node's import-result output channel; the re-import flag and the artifact
  URI are taken from `self._exec_properties`.

  Returns:
    The populated ImporterSpec.

  Raises:
    AssertionError: If `self._node` is not an `importer.Importer`.
    IndexError: If the output channel holds no artifact.
  """
  assert isinstance(self._node, importer.Importer)
  output_channel = self._node.outputs[importer.IMPORT_RESULT_KEY]
  result = ImporterSpec()

  # Materialize the channel once; both the metadata and the type schema below
  # are derived from the same artifact list.
  artifacts = list(output_channel.get())

  # Importer's output channel contains one artifact instance with
  # additional properties.
  artifact_instance = artifacts[0]
  struct_proto = compiler_utils.pack_artifact_properties(artifact_instance)
  if struct_proto:
    result.metadata.CopyFrom(struct_proto)

  result.reimport = bool(self._exec_properties[importer.REIMPORT_OPTION_KEY])
  result.artifact_uri.CopyFrom(
      compiler_utils.value_converter(
          self._exec_properties[importer.SOURCE_URI_KEY]))

  single_artifact = artifact_utils.get_single_instance(artifacts)
  result.type_schema.CopyFrom(
      pipeline_pb2.ArtifactTypeSchema(
          instance_schema=compiler_utils.get_artifact_schema(
              single_artifact)))
  return result
def _build_importer_spec(self) -> ImporterSpec:
  """Assembles the ImporterSpec for the Importer node in `self._node`.

  Returns:
    An ImporterSpec carrying the channel's additional (custom) properties as
    metadata, the re-import flag, the artifact URI and the type schema.
  """
  assert isinstance(self._node, importer.Importer)
  import_channel = self._node.outputs[importer.IMPORT_RESULT_KEY]
  spec = ImporterSpec()

  # The importer's output channel may carry additional properties; fold the
  # regular ones in first, then the custom ones.
  for extra_properties in (import_channel.additional_properties,
                           import_channel.additional_custom_properties):
    if extra_properties:
      spec.metadata.update(extra_properties)

  spec.reimport = bool(self._exec_properties[importer.REIMPORT_OPTION_KEY])

  # The Importer node's 'artifact_uri' property is meant to be a string, but
  # only the pytype hint enforces that, so a RuntimeParameter can slip
  # through. When it does, the runtime parameter is named 'artifact_uri'
  # rather than whatever name the user gave their runtime parameter.
  source_uri = self._exec_properties[importer.SOURCE_URI_KEY]
  if not isinstance(source_uri, data_types.RuntimeParameter):
    spec.artifact_uri.CopyFrom(compiler_utils.value_converter(source_uri))
  else:
    spec.artifact_uri.runtime_parameter = importer.SOURCE_URI_KEY

  instance_schema = compiler_utils.get_artifact_schema(import_channel.type)
  spec.type_schema.CopyFrom(
      pipeline_pb2.ArtifactTypeSchema(instance_schema=instance_schema))
  return spec
def test_build_importer_component_spec(self):
  """Checks the ComponentSpec built for an importer of a generic artifact."""
  wanted_spec = json_format.ParseDict(
      {
          'inputDefinitions': {
              'parameters': {'uri': {'type': 'STRING'}},
          },
          'outputDefinitions': {
              'artifacts': {
                  'artifact': {
                      'artifactType': {'schemaTitle': 'system.Artifact'},
                  },
              },
          },
          'executorLabel': 'exec-importer-1',
      },
      pb.ComponentSpec())

  artifact_schema = pb.ArtifactTypeSchema(schema_title='system.Artifact')
  actual_spec = importer_node._build_importer_component_spec(
      importer_base_name='importer-1', artifact_type_schema=artifact_schema)

  # Show the full diff on mismatch.
  self.maxDiff = None
  self.assertEqual(wanted_spec, actual_spec)
def __init__(self, instance_schema: Optional[str] = None):
  """Constructs an instance of Artifact.

  Sets up self._metadata_fields to perform type checking and initializes
  the backing RuntimeArtifact.

  Args:
    instance_schema: Schema string for the artifact type. Must be provided
      when instantiating `Artifact` directly and must be omitted when
      instantiating a subclass.

  Raises:
    ValueError: If `instance_schema` is missing for `Artifact` itself, or is
      passed when constructing a subclass.
  """
  if self.__class__ == Artifact:
    if not instance_schema:
      raise ValueError(
          'The "instance_schema" argument must be set for Artifact.')
    self._instance_schema = instance_schema
  else:
    # NOTE(review): `_instance_schema` is not assigned on this path;
    # presumably subclasses define it themselves -- confirm.
    if instance_schema:
      # Implicit concatenation instead of a backslash continuation, which
      # leaked stray whitespace into the message.
      raise ValueError(
          'The "instance_schema" argument must not be passed for Artifact '
          'subclass: {}'.format(self.__class__))

  # setup self._metadata_fields
  self.TYPE_NAME, self._metadata_fields = artifact_utils.parse_schema(
      self._instance_schema)

  # Instantiate a RuntimeArtifact pb message as the POD data structure.
  self._artifact = pipeline_spec_pb2.RuntimeArtifact()

  # Stores the metadata for the Artifact.
  self.metadata = {}

  self._artifact.type.CopyFrom(
      pipeline_spec_pb2.ArtifactTypeSchema(
          instance_schema=self._instance_schema))

  self._initialized = True
def __init__(self, instance_schema: Optional[str] = None):
  """Constructs an instance of Artifact.

  When `Artifact` itself is instantiated, the YAML `instance_schema` supplies
  the type name (its `title`) and the property definitions (its
  `properties`). Subclasses obtain their schema from
  `self._get_artifact_type()` instead.

  Args:
    instance_schema: YAML schema string for the artifact type. Required when
      instantiating `Artifact` directly; must be omitted for subclasses.

  Raises:
    ValueError: If `instance_schema` is missing or lacks a `properties`
      entry for `Artifact` itself, or is passed when constructing a subclass.
  """
  if self.__class__ == Artifact:
    if not instance_schema:
      raise ValueError(
          'The "instance_schema" argument must be passed to specify a '
          'type for this Artifact.')
    # Parse once; the same document provides both the properties and title.
    schema_yaml = yaml.safe_load(instance_schema)
    if 'properties' not in schema_yaml:
      raise ValueError(
          'Invalid instance_schema, properties must be present. '
          'Got %s' % instance_schema)
    schema = schema_yaml['properties']
    self.TYPE_NAME = schema_yaml['title']
    self.PROPERTIES = {}
    for k, v in schema.items():
      self.PROPERTIES[k] = Property.from_dict(v)
  else:
    if instance_schema:
      # The message names the actual parameter ("instance_schema").
      raise ValueError(
          'The "instance_schema" argument must not be passed for '
          'Artifact subclass %s.' % self.__class__)
    instance_schema = self._get_artifact_type()

  # MLMD artifact type schema string.
  self._type_schema = instance_schema
  # Instantiate a RuntimeArtifact pb message as the POD data structure.
  self._artifact = pipeline_spec_pb2.RuntimeArtifact()
  self._artifact.type.CopyFrom(
      pipeline_spec_pb2.ArtifactTypeSchema(instance_schema=instance_schema))
  # Initialization flag to prevent recursive getattr / setattr errors.
  self._initialized = True
def test_build_importer_spec_with_invalid_inputs_should_fail(self):
  """Expects an AssertionError for both-or-neither URI sources."""
  expected_error = (
      'importer spec should be built using either pipeline_param_name or '
      'constant_value')

  # Both a pipeline parameter name and a constant value are supplied.
  with self.assertRaisesRegex(AssertionError, expected_error):
    importer_node.build_importer_spec(
        input_type_schema=pb.ArtifactTypeSchema(
            schema_title='system.Artifact'),
        pipeline_param_name='param1',
        constant_value='some_uri')

  # Neither a pipeline parameter name nor a constant value is supplied.
  with self.assertRaisesRegex(AssertionError, expected_error):
    importer_node.build_importer_spec(
        input_type_schema=pb.ArtifactTypeSchema(
            schema_title='system.Artifact'))
def build_input_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentInputsSpec.ArtifactSpec:
  """Creates the input ArtifactSpec describing a channel's artifact type.

  Args:
    channel_spec: The input channel whose `type` determines the schema.

  Returns:
    An ArtifactSpec whose artifact type holds the channel type's schema.
  """
  artifact_type = channel_spec.type
  spec = pipeline_pb2.ComponentInputsSpec.ArtifactSpec()
  type_schema = pipeline_pb2.ArtifactTypeSchema(
      instance_schema=get_artifact_schema(artifact_type))
  spec.artifact_type.CopyFrom(type_schema)
  # Check the resulting schema against the properties declared on the type.
  _validate_properties_schema(
      instance_schema=spec.artifact_type.instance_schema,
      properties=artifact_type.PROPERTIES)
  return spec
def get_artifact_type_schema(
    artifact_class_or_type_name: Optional[Union[str,
                                                Type[artifact_types.Artifact]]]
) -> pipeline_spec_pb2.ArtifactTypeSchema:
  """Returns the IR artifact type message for a ComponentSpec I/O type.

  Args:
    artifact_class_or_type_name: A type name string, an `Artifact` subclass,
      or anything else (which falls back to the generic `Artifact`).

  Returns:
    An ArtifactTypeSchema with the schema title and version filled in.
  """
  requested = artifact_class_or_type_name

  if isinstance(requested, str):
    # Names matching the Google types pattern are passed through verbatim.
    if re.match(_GOOGLE_TYPES_PATTERN, requested):
      return pipeline_spec_pb2.ArtifactTypeSchema(
          schema_title=requested,
          schema_version=_GOOGLE_TYPES_VERSION,
      )
    # Case-insensitive lookup; unknown names map to the generic Artifact.
    resolved_class = _ARTIFACT_CLASSES_MAPPING.get(requested.lower(),
                                                   artifact_types.Artifact)
  elif inspect.isclass(requested) and issubclass(requested,
                                                 artifact_types.Artifact):
    resolved_class = requested
  else:
    resolved_class = artifact_types.Artifact

  return pipeline_spec_pb2.ArtifactTypeSchema(
      schema_title=resolved_class.TYPE_NAME,
      schema_version=resolved_class.VERSION)
def get_artifact_type_schema_message(
    type_name: str) -> pipeline_spec_pb2.ArtifactTypeSchema:
  """Returns the IR artifact type message for a ComponentSpec I/O type.

  Args:
    type_name: The type name to resolve. Non-string values yield the generic
      `Artifact` IR type.

  Returns:
    The ArtifactTypeSchema for the resolved artifact class.
  """
  if not isinstance(type_name, str):
    return artifact.Artifact.get_ir_type()

  # Case-insensitive lookup; unknown names map to the generic Artifact.
  artifact_class = _ARTIFACT_CLASSES_MAPPING.get(type_name.lower(),
                                                 artifact.Artifact)
  # TODO: migrate all types to system. namespace.
  if not artifact_class.TYPE_NAME.startswith('system.'):
    return artifact_class.get_ir_type()
  return pipeline_spec_pb2.ArtifactTypeSchema(
      schema_title=artifact_class.TYPE_NAME)
def get_artifact_type_schema(
    artifact_class_or_type_name: Optional[Union[str, Type[io_types.Artifact]]]
) -> pipeline_spec_pb2.ArtifactTypeSchema:
  """Returns the IR artifact type message for a ComponentSpec I/O type.

  Args:
    artifact_class_or_type_name: A type name string, an `Artifact` subclass,
      or anything else (which falls back to the generic `Artifact`).

  Returns:
    An ArtifactTypeSchema whose schema title is the resolved class's
    TYPE_NAME.
  """
  requested = artifact_class_or_type_name

  if isinstance(requested, str):
    # Case-insensitive lookup; unknown names map to the generic Artifact.
    resolved_class = _ARTIFACT_CLASSES_MAPPING.get(requested.lower(),
                                                   io_types.Artifact)
  elif inspect.isclass(requested) and issubclass(requested,
                                                 io_types.Artifact):
    resolved_class = requested
  else:
    resolved_class = io_types.Artifact

  return pipeline_spec_pb2.ArtifactTypeSchema(
      schema_title=resolved_class.TYPE_NAME)
def test_build_importer_spec_from_pipeline_param(self):
  """Checks the ImporterSpec built when the URI is a pipeline parameter."""
  wanted_spec = json_format.ParseDict(
      {
          'artifactUri': {'runtimeParameter': 'param1'},
          'typeSchema': {'schemaTitle': 'system.Artifact'},
      },
      pb.PipelineDeploymentConfig.ImporterSpec())

  artifact_schema = pb.ArtifactTypeSchema(schema_title='system.Artifact')
  actual_spec = importer_node.build_importer_spec(
      input_type_schema=artifact_schema, pipeline_param_name='param1')

  # Show the full diff on mismatch.
  self.maxDiff = None
  self.assertEqual(wanted_spec, actual_spec)
def test_build_importer_spec_from_constant_value(self):
  """Checks the ImporterSpec built when the URI is a constant string."""
  wanted_spec = json_format.ParseDict(
      {
          'artifactUri': {'constantValue': {'stringValue': 'some_uri'}},
          'typeSchema': {'schemaTitle': 'system.Artifact'},
      },
      pb.PipelineDeploymentConfig.ImporterSpec())

  artifact_schema = pb.ArtifactTypeSchema(schema_title='system.Artifact')
  actual_spec = importer_node.build_importer_spec(
      input_type_schema=artifact_schema, constant_value='some_uri')

  # Show the full diff on mismatch.
  self.maxDiff = None
  self.assertEqual(wanted_spec, actual_spec)
def build_output_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentOutputsSpec.ArtifactSpec:
  """Creates the output ArtifactSpec describing a channel's artifact type.

  Args:
    channel_spec: The output channel. Its `type` determines the schema and
      its additional (custom) properties, if any, become the spec metadata.

  Returns:
    The populated output ArtifactSpec.
  """
  artifact_type = channel_spec.type
  spec = pipeline_pb2.ComponentOutputsSpec.ArtifactSpec()
  type_schema = pipeline_pb2.ArtifactTypeSchema(
      instance_schema=get_artifact_schema(artifact_type))
  spec.artifact_type.CopyFrom(type_schema)
  # Check the resulting schema against the properties declared on the type.
  _validate_properties_schema(
      instance_schema=spec.artifact_type.instance_schema,
      properties=artifact_type.PROPERTIES)

  # Regular additional properties are merged first, then the custom ones.
  for extra_properties in (channel_spec.additional_properties,
                           channel_spec.additional_custom_properties):
    if extra_properties:
      spec.metadata.update(extra_properties)
  return spec
def setUp(self):
  """Builds the ExecutorInput fixture and loads the testdata JSON files.

  The testdata files are read before the working directory is switched to
  the test's temporary directory.
  """
  super().setUp()

  self._executor_invocation = pipeline_pb2.ExecutorInput()
  self._executor_invocation.outputs.output_file = _TEST_OUTPUT_METADATA_JSON
  self._executor_invocation.inputs.parameters[
      'input_base'].string_value = _TEST_INPUT_DIR
  self._executor_invocation.inputs.parameters[
      'output_config'].string_value = '{}'
  self._executor_invocation.inputs.parameters[
      'input_config'].string_value = json_format.MessageToJson(
          example_gen_pb2.Input(splits=[
              example_gen_pb2.Input.Split(
                  name='s1', pattern='span{SPAN}/split1/*'),
              example_gen_pb2.Input.Split(
                  name='s2', pattern='span{SPAN}/split2/*')
          ]))
  self._executor_invocation.outputs.artifacts['examples'].artifacts.append(
      pipeline_pb2.RuntimeArtifact(
          type=pipeline_pb2.ArtifactTypeSchema(
              instance_schema=compiler_utils.get_artifact_schema(
                  standard_artifacts.Examples))))

  testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')

  # Use context managers so the file handles are closed deterministically.
  with fileio.open(
      os.path.join(testdata_dir, 'executor_invocation.json'), 'r') as f:
    self._executor_invocation_from_file = f.read()
  logging.debug('Executor invocation under test: %s',
                self._executor_invocation_from_file)

  with fileio.open(
      os.path.join(testdata_dir, 'expected_output_metadata.json'), 'r') as f:
    self._expected_result_from_file = f.read()
  logging.debug('Expecting output metadata JSON: %s',
                self._expected_result_from_file)

  # Change working directory after all the testdata files have been read.
  self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))

  fileio.makedirs(os.path.dirname(_TEST_INPUT_DIR))
def build_output_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentOutputsSpec.ArtifactSpec:
  """Creates the output ArtifactSpec describing a channel's artifact.

  Args:
    channel_spec: The output channel to describe.

  Returns:
    An ArtifactSpec holding the artifact's schema and, when available, its
    packed properties as metadata.
  """
  # Prefer the first artifact already present in the channel; otherwise
  # instantiate one from the channel's type.
  existing_artifacts = list(channel_spec.get())
  if existing_artifacts:
    instance = existing_artifacts[0]
  else:
    instance = channel_spec.type()

  spec = pipeline_pb2.ComponentOutputsSpec.ArtifactSpec()
  type_schema = pipeline_pb2.ArtifactTypeSchema(
      instance_schema=get_artifact_schema(instance))
  spec.artifact_type.CopyFrom(type_schema)
  # Check the resulting schema against the properties declared on the type.
  _validate_properties_schema(
      instance_schema=spec.artifact_type.instance_schema,
      properties=channel_spec.type.PROPERTIES)

  packed_properties = pack_artifact_properties(instance)
  if packed_properties:
    spec.metadata.CopyFrom(packed_properties)
  return spec
class ImporterNodeTest(parameterized.TestCase):
  """Unit tests for the spec builders and `importer` in `importer_node`."""

  @parameterized.parameters(
      {
          # artifact_uri is a constant value
          'input_uri': 'gs://artifact',
          'artifact_type_schema':
              pb.ArtifactTypeSchema(schema_title='system.Dataset'),
          'expected_result': {
              'artifactUri': {
                  'constantValue': {'stringValue': 'gs://artifact'},
              },
              'typeSchema': {'schemaTitle': 'system.Dataset'},
          },
      },
      {
          # artifact_uri is from PipelineParam
          'input_uri': _pipeline_param.PipelineParam(name='uri_to_import'),
          'artifact_type_schema':
              pb.ArtifactTypeSchema(schema_title='system.Model'),
          'expected_result': {
              'artifactUri': {'runtimeParameter': 'uri'},
              'typeSchema': {'schemaTitle': 'system.Model'},
          },
      })
  def test_build_importer_spec(self, input_uri, artifact_type_schema,
                               expected_result):
    """Compares the built ImporterSpec with its expected JSON form."""
    wanted_spec = json_format.ParseDict(
        expected_result, pb.PipelineDeploymentConfig.ImporterSpec())
    actual_spec = importer_node._build_importer_spec(
        artifact_uri=input_uri, artifact_type_schema=artifact_type_schema)

    # Show the full diff on mismatch.
    self.maxDiff = None
    self.assertEqual(wanted_spec, actual_spec)

  @parameterized.parameters(
      {
          # artifact_uri is a constant value
          'importer_name': 'importer-1',
          'input_uri': 'gs://artifact',
          'expected_result': {
              'taskInfo': {'name': 'importer-1'},
              'inputs': {
                  'parameters': {
                      'uri': {
                          'runtimeValue': {
                              'constantValue': {
                                  'stringValue': 'gs://artifact',
                              },
                          },
                      },
                  },
              },
              'componentRef': {'name': 'comp-importer-1'},
          },
      },
      {
          # artifact_uri is from PipelineParam
          'importer_name': 'importer-2',
          'input_uri': _pipeline_param.PipelineParam(name='uri_to_import'),
          'expected_result': {
              'taskInfo': {'name': 'importer-2'},
              'inputs': {
                  'parameters': {
                      'uri': {'componentInputParameter': 'uri_to_import'},
                  },
              },
              'componentRef': {'name': 'comp-importer-2'},
          },
      })
  def test_build_importer_task_spec(self, importer_name, input_uri,
                                    expected_result):
    """Compares the built PipelineTaskSpec with its expected JSON form."""
    wanted_spec = json_format.ParseDict(expected_result,
                                        pb.PipelineTaskSpec())
    actual_spec = importer_node._build_importer_task_spec(
        importer_base_name=importer_name, artifact_uri=input_uri)

    # Show the full diff on mismatch.
    self.maxDiff = None
    self.assertEqual(wanted_spec, actual_spec)

  def test_build_importer_component_spec(self):
    """Checks the ComponentSpec built for an importer of a generic artifact."""
    wanted_spec = json_format.ParseDict(
        {
            'inputDefinitions': {
                'parameters': {'uri': {'type': 'STRING'}},
            },
            'outputDefinitions': {
                'artifacts': {
                    'artifact': {
                        'artifactType': {'schemaTitle': 'system.Artifact'},
                    },
                },
            },
            'executorLabel': 'exec-importer-1',
        },
        pb.ComponentSpec())

    artifact_schema = pb.ArtifactTypeSchema(schema_title='system.Artifact')
    actual_spec = importer_node._build_importer_component_spec(
        importer_base_name='importer-1', artifact_type_schema=artifact_schema)

    # Show the full diff on mismatch.
    self.maxDiff = None
    self.assertEqual(wanted_spec, actual_spec)

  def test_import_with_invalid_artifact_uri_value_should_fail(self):
    """Expects a ValueError when artifact_uri is an int."""
    from kfp.dsl.io_types import Dataset

    expected_error = (
        "Importer got unexpected artifact_uri: 123 of type: <class 'int'>.")
    with self.assertRaisesRegex(ValueError, expected_error):
      importer_node.importer(artifact_uri=123, artifact_class=Dataset)
def get_ir_type(cls) -> pipeline_spec_pb2.ArtifactTypeSchema:
  """Wraps this class's artifact type in an ArtifactTypeSchema message.

  Returns:
    An ArtifactTypeSchema whose `instance_schema` is the value returned by
    `cls.get_artifact_type()`.
  """
  type_schema = cls.get_artifact_type()
  return pipeline_spec_pb2.ArtifactTypeSchema(instance_schema=type_schema)
class TypeUtilsTest(parameterized.TestCase):
  """Unit tests for the helpers in `type_utils`."""

  @parameterized.parameters(
      [(item, True) for item in _PARAMETER_TYPES] +
      [(item, False)
       for item in _KNOWN_ARTIFACT_TYPES + _UNKNOWN_ARTIFACT_TYPES])
  def test_is_parameter_type_true(self, type_name, expected_result):
    """Parameter types are recognized; artifact types are not."""
    self.assertEqual(expected_result, type_utils.is_parameter_type(type_name))

  @parameterized.parameters(
      {
          'artifact_class_or_type_name':
              'Model',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Model', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.Model,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Model', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'Dataset',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Dataset', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.Dataset,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Dataset', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'Metrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Metrics', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.Metrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Metrics', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'ClassificationMetrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.ClassificationMetrics',
                  schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.ClassificationMetrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.ClassificationMetrics',
                  schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'SlicedClassificationMetrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.SlicedClassificationMetrics',
                  schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.SlicedClassificationMetrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.SlicedClassificationMetrics',
                  schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'arbitrary name',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Artifact', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              _ArbitraryClass,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Artifact', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.HTML,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.HTML', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              artifact_types.Markdown,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Markdown', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'some-google-type',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.Artifact', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              'google.VertexModel',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='google.VertexModel', schema_version='0.0.1')
      },
      {
          'artifact_class_or_type_name':
              _VertexDummy,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='google.VertexDummy', schema_version='0.0.2')
      },
  )
  def test_get_artifact_type_schema(self, artifact_class_or_type_name,
                                    expected_result):
    """Type names and classes resolve to the expected ArtifactTypeSchema."""
    self.assertEqual(
        expected_result,
        type_utils.get_artifact_type_schema(artifact_class_or_type_name))

  @parameterized.parameters(
      {
          'given_type': 'Int',
          'expected_type': pb.ParameterType.NUMBER_INTEGER,
      },
      {
          'given_type': 'Integer',
          'expected_type': pb.ParameterType.NUMBER_INTEGER,
      },
      {
          'given_type': int,
          'expected_type': pb.ParameterType.NUMBER_INTEGER,
      },
      {
          'given_type': 'Double',
          'expected_type': pb.ParameterType.NUMBER_DOUBLE,
      },
      {
          'given_type': 'Float',
          'expected_type': pb.ParameterType.NUMBER_DOUBLE,
      },
      {
          'given_type': float,
          'expected_type': pb.ParameterType.NUMBER_DOUBLE,
      },
      {
          'given_type': 'String',
          'expected_type': pb.ParameterType.STRING,
      },
      {
          'given_type': 'Text',
          'expected_type': pb.ParameterType.STRING,
      },
      {
          'given_type': str,
          'expected_type': pb.ParameterType.STRING,
      },
      {
          'given_type': 'Boolean',
          'expected_type': pb.ParameterType.BOOLEAN,
      },
      {
          'given_type': bool,
          'expected_type': pb.ParameterType.BOOLEAN,
      },
      {
          'given_type': 'Dict',
          'expected_type': pb.ParameterType.STRUCT,
      },
      {
          'given_type': dict,
          'expected_type': pb.ParameterType.STRUCT,
      },
      {
          'given_type': 'List',
          'expected_type': pb.ParameterType.LIST,
      },
      {
          'given_type': list,
          'expected_type': pb.ParameterType.LIST,
      },
      {
          'given_type': Dict[str, int],
          'expected_type': pb.ParameterType.STRUCT,
      },
      {
          'given_type': List[Any],
          'expected_type': pb.ParameterType.LIST,
      },
      {
          'given_type': {
              'JsonObject': {
                  'data_type': 'proto:tfx.components.trainer.TrainArgs'
              }
          },
          'expected_type': pb.ParameterType.STRUCT,
      },
  )
  def test_get_parameter_type(self, given_type, expected_type):
    """Type names and Python types map to the expected ParameterType."""
    self.assertEqual(expected_type, type_utils.get_parameter_type(given_type))

    # Test get parameter by Python type.
    self.assertEqual(pb.ParameterType.NUMBER_INTEGER,
                     type_utils.get_parameter_type(int))

  def test_get_parameter_type_invalid(self):
    """Expects an AttributeError from the call below."""
    with self.assertRaises(AttributeError):
      type_utils.get_parameter_type_schema(None)

  def test_get_input_artifact_type_schema(self):
    """Covers missing, non-artifact, typed and untyped inputs."""
    input_specs = [
        v1_structures.InputSpec(name='input1', type='String'),
        v1_structures.InputSpec(name='input2', type='Model'),
        v1_structures.InputSpec(name='input3', type=None),
    ]
    # input not found.
    # The message check sits outside the `with` block (statements after the
    # raising call never run) and inspects the captured exception.
    with self.assertRaises(AssertionError) as cm:
      type_utils.get_input_artifact_type_schema('input0', input_specs)
    self.assertEqual('Input not found.', str(cm.exception))

    # input found, but it doesn't map to an artifact type.
    with self.assertRaises(AssertionError) as cm:
      type_utils.get_input_artifact_type_schema('input1', input_specs)
    self.assertEqual('Input is not an artifact type.', str(cm.exception))

    # input found, and a matching artifact type schema returned.
    self.assertEqual(
        'system.Model',
        type_utils.get_input_artifact_type_schema('input2',
                                                  input_specs).schema_title)

    # input found, and the default artifact type schema returned.
    self.assertEqual(
        'system.Artifact',
        type_utils.get_input_artifact_type_schema('input3',
                                                  input_specs).schema_title)

  @parameterized.parameters(
      {
          'given_type': 'String',
          'expected_type': 'String',
          'is_compatible': True,
      },
      {
          'given_type': 'String',
          'expected_type': 'Integer',
          'is_compatible': False,
      },
      {
          'given_type': {
              'type_a': {
                  'property': 'property_b',
              }
          },
          'expected_type': {
              'type_a': {
                  'property': 'property_b',
              }
          },
          'is_compatible': True,
      },
      {
          'given_type': {
              'type_a': {
                  'property': 'property_b',
              }
          },
          'expected_type': {
              'type_a': {
                  'property': 'property_c',
              }
          },
          'is_compatible': False,
      },
      {
          'given_type': 'Artifact',
          'expected_type': 'Model',
          'is_compatible': True,
      },
      {
          'given_type': 'Metrics',
          'expected_type': 'Artifact',
          'is_compatible': True,
      },
  )
  def test_verify_type_compatibility(
      self,
      given_type: Union[str, dict],
      expected_type: Union[str, dict],
      is_compatible: bool,
  ):
    """Compatible pairs return truthy; incompatible pairs raise."""
    if is_compatible:
      self.assertTrue(
          type_utils.verify_type_compatibility(
              given_type=given_type,
              expected_type=expected_type,
              error_message_prefix='',
          ))
    else:
      with self.assertRaises(InconsistentTypeException):
        type_utils.verify_type_compatibility(
            given_type=given_type,
            expected_type=expected_type,
            error_message_prefix='',
        )

  @parameterized.parameters(
      {
          'given_type': str,
          'expected_type_name': 'String',
      },
      {
          'given_type': int,
          'expected_type_name': 'Integer',
      },
      {
          'given_type': float,
          'expected_type_name': 'Float',
      },
      {
          'given_type': bool,
          'expected_type_name': 'Boolean',
      },
      {
          'given_type': list,
          'expected_type_name': 'List',
      },
      {
          'given_type': dict,
          'expected_type_name': 'Dict',
      },
      {
          'given_type': Any,
          'expected_type_name': None,
      },
  )
  def test_get_canonical_type_name_for_type(
      self,
      given_type,
      expected_type_name,
  ):
    """Python types map to their canonical type names (None for Any)."""
    self.assertEqual(expected_type_name,
                     type_utils.get_canonical_type_name_for_type(given_type))

  @parameterized.parameters(
      {
          'given_type': 'PipelineTaskFinalStatus',
          'expected_result': True,
      },
      {
          'given_type': 'pipelineTaskFinalstatus',
          'expected_result': False,
      },
      {
          'given_type': int,
          'expected_result': False,
      },
  )
  def test_is_task_final_statu_type(self, given_type, expected_result):
    """Only the exact 'PipelineTaskFinalStatus' name is accepted."""
    self.assertEqual(expected_result,
                     type_utils.is_task_final_status_type(given_type))
class TypeUtilsTest(parameterized.TestCase):
  """Unit tests for the helpers in `type_utils`."""

  def test_is_parameter_type(self):
    """Parameter types are recognized; artifact types are not."""
    for type_name in _PARAMETER_TYPES:
      self.assertTrue(type_utils.is_parameter_type(type_name))
    for type_name in _KNOWN_ARTIFACT_TYPES + _UNKNOWN_ARTIFACT_TYPES:
      self.assertFalse(type_utils.is_parameter_type(type_name))

  @parameterized.parameters(
      {
          'artifact_class_or_type_name': 'Model',
          'expected_result': pb.ArtifactTypeSchema(schema_title='system.Model')
      },
      {
          'artifact_class_or_type_name': io_types.Model,
          'expected_result': pb.ArtifactTypeSchema(schema_title='system.Model')
      },
      {
          'artifact_class_or_type_name':
              'Dataset',
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Dataset')
      },
      {
          'artifact_class_or_type_name':
              io_types.Dataset,
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Dataset')
      },
      {
          'artifact_class_or_type_name':
              'Metrics',
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Metrics')
      },
      {
          'artifact_class_or_type_name':
              io_types.Metrics,
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Metrics')
      },
      {
          'artifact_class_or_type_name':
              'ClassificationMetrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.ClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              io_types.ClassificationMetrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.ClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              'SlicedClassificationMetrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.SlicedClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              io_types.SlicedClassificationMetrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.SlicedClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              'arbitrary name',
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Artifact')
      },
      {
          'artifact_class_or_type_name':
              _ArbitraryClass,
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Artifact')
      },
  )
  def test_get_artifact_type_schema(self, artifact_class_or_type_name,
                                    expected_result):
    """Type names and classes resolve to the expected ArtifactTypeSchema."""
    self.assertEqual(
        expected_result,
        type_utils.get_artifact_type_schema(artifact_class_or_type_name))

  @parameterized.parameters(
      {
          'given_type': 'Int',
          'expected_type': pb.PrimitiveType.INT,
      },
      {
          'given_type': 'Integer',
          'expected_type': pb.PrimitiveType.INT,
      },
      {
          'given_type': int,
          'expected_type': pb.PrimitiveType.INT,
      },
      {
          'given_type': 'Double',
          'expected_type': pb.PrimitiveType.DOUBLE,
      },
      {
          'given_type': 'Float',
          'expected_type': pb.PrimitiveType.DOUBLE,
      },
      {
          'given_type': float,
          'expected_type': pb.PrimitiveType.DOUBLE,
      },
      {
          'given_type': 'String',
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': 'Text',
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': str,
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': 'Boolean',
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': bool,
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': 'Dict',
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': dict,
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': 'List',
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': list,
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': Dict[str, int],
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': List[Any],
          'expected_type': pb.PrimitiveType.STRING,
      },
      {
          'given_type': {
              'JsonObject': {
                  'data_type': 'proto:tfx.components.trainer.TrainArgs'
              }
          },
          'expected_type': pb.PrimitiveType.STRING,
      },
  )
  def test_get_parameter_type(self, given_type, expected_type):
    """Type names and Python types map to the expected PrimitiveType."""
    self.assertEqual(expected_type, type_utils.get_parameter_type(given_type))

    # Test get parameter by Python type.
    self.assertEqual(pb.PrimitiveType.INT, type_utils.get_parameter_type(int))

  def test_get_parameter_type_invalid(self):
    """Expects an AttributeError from the call below."""
    with self.assertRaises(AttributeError):
      type_utils.get_parameter_type_schema(None)

  def test_get_input_artifact_type_schema(self):
    """Covers missing, non-artifact, typed and untyped inputs."""
    input_specs = [
        structures.InputSpec(name='input1', type='String'),
        structures.InputSpec(name='input2', type='Model'),
        structures.InputSpec(name='input3', type=None),
    ]
    # input not found.
    # The message check sits outside the `with` block (statements after the
    # raising call never run) and inspects the captured exception.
    with self.assertRaises(AssertionError) as cm:
      type_utils.get_input_artifact_type_schema('input0', input_specs)
    self.assertEqual('Input not found.', str(cm.exception))

    # input found, but it doesn't map to an artifact type.
    with self.assertRaises(AssertionError) as cm:
      type_utils.get_input_artifact_type_schema('input1', input_specs)
    self.assertEqual('Input is not an artifact type.', str(cm.exception))

    # input found, and a matching artifact type schema returned.
    self.assertEqual(
        'system.Model',
        type_utils.get_input_artifact_type_schema('input2',
                                                  input_specs).schema_title)

    # input found, and the default artifact type schema returned.
    self.assertEqual(
        'system.Artifact',
        type_utils.get_input_artifact_type_schema('input3',
                                                  input_specs).schema_title)

  def test_get_parameter_type_field_name(self):
    """Type names map to the matching value field name."""
    self.assertEqual('string_value',
                     type_utils.get_parameter_type_field_name('String'))
    self.assertEqual('int_value',
                     type_utils.get_parameter_type_field_name('Integer'))
    self.assertEqual('double_value',
                     type_utils.get_parameter_type_field_name('Float'))
class TypeUtilsTest(parameterized.TestCase):
  """Unit tests for the helpers in `type_utils`."""

  def test_is_parameter_type(self):
    """Parameter types are recognized; artifact types are not."""
    for type_name in _PARAMETER_TYPES:
      self.assertTrue(type_utils.is_parameter_type(type_name))
    for type_name in _KNOWN_ARTIFACT_TYPES + _UNKNOWN_ARTIFACT_TYPES:
      self.assertFalse(type_utils.is_parameter_type(type_name))

  @parameterized.parameters(
      {
          'artifact_class_or_type_name': 'Model',
          'expected_result': pb.ArtifactTypeSchema(schema_title='system.Model')
      },
      {
          'artifact_class_or_type_name': io_types.Model,
          'expected_result': pb.ArtifactTypeSchema(schema_title='system.Model')
      },
      {
          'artifact_class_or_type_name':
              'Dataset',
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Dataset')
      },
      {
          'artifact_class_or_type_name':
              io_types.Dataset,
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Dataset')
      },
      {
          'artifact_class_or_type_name':
              'Metrics',
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Metrics')
      },
      {
          'artifact_class_or_type_name':
              io_types.Metrics,
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Metrics')
      },
      {
          'artifact_class_or_type_name':
              'ClassificationMetrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.ClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              io_types.ClassificationMetrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.ClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              'SlicedClassificationMetrics',
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.SlicedClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              io_types.SlicedClassificationMetrics,
          'expected_result':
              pb.ArtifactTypeSchema(
                  schema_title='system.SlicedClassificationMetrics')
      },
      {
          'artifact_class_or_type_name':
              'arbitrary name',
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Artifact')
      },
      {
          'artifact_class_or_type_name':
              _ArbitraryClass,
          'expected_result':
              pb.ArtifactTypeSchema(schema_title='system.Artifact')
      },
  )
  def test_get_artifact_type_schema(self, artifact_class_or_type_name,
                                    expected_result):
    """Type names and classes resolve to the expected ArtifactTypeSchema."""
    self.assertEqual(
        expected_result,
        type_utils.get_artifact_type_schema(artifact_class_or_type_name))

  def test_get_parameter_type(self):
    """Type names and Python types map to the expected PrimitiveType."""
    # Test get parameter type by name.
    self.assertEqual(pb.PrimitiveType.INT,
                     type_utils.get_parameter_type('Int'))
    self.assertEqual(pb.PrimitiveType.INT,
                     type_utils.get_parameter_type('Integer'))
    self.assertEqual(pb.PrimitiveType.DOUBLE,
                     type_utils.get_parameter_type('Double'))
    self.assertEqual(pb.PrimitiveType.DOUBLE,
                     type_utils.get_parameter_type('Float'))
    self.assertEqual(pb.PrimitiveType.STRING,
                     type_utils.get_parameter_type('String'))
    self.assertEqual(pb.PrimitiveType.STRING,
                     type_utils.get_parameter_type('Str'))

    # Test get parameter by Python type.
    self.assertEqual(pb.PrimitiveType.INT, type_utils.get_parameter_type(int))
    self.assertEqual(pb.PrimitiveType.DOUBLE,
                     type_utils.get_parameter_type(float))
    self.assertEqual(pb.PrimitiveType.STRING,
                     type_utils.get_parameter_type(str))

    with self.assertRaises(AttributeError):
      type_utils.get_parameter_type_schema(None)

    with self.assertRaisesRegex(TypeError, 'Got illegal parameter type.'):
      type_utils.get_parameter_type(bool)

  def test_get_input_artifact_type_schema(self):
    """Covers missing, non-artifact, typed and untyped inputs."""
    input_specs = [
        structures.InputSpec(name='input1', type='String'),
        structures.InputSpec(name='input2', type='Model'),
        structures.InputSpec(name='input3', type=None),
    ]
    # input not found.
    # The message check sits outside the `with` block (statements after the
    # raising call never run) and inspects the captured exception.
    with self.assertRaises(AssertionError) as cm:
      type_utils.get_input_artifact_type_schema('input0', input_specs)
    self.assertEqual('Input not found.', str(cm.exception))

    # input found, but it doesn't map to an artifact type.
    with self.assertRaises(AssertionError) as cm:
      type_utils.get_input_artifact_type_schema('input1', input_specs)
    self.assertEqual('Input is not an artifact type.', str(cm.exception))

    # input found, and a matching artifact type schema returned.
    self.assertEqual(
        'system.Model',
        type_utils.get_input_artifact_type_schema('input2',
                                                  input_specs).schema_title)

    # input found, and the default artifact type schema returned.
    self.assertEqual(
        'system.Artifact',
        type_utils.get_input_artifact_type_schema('input3',
                                                  input_specs).schema_title)

  def test_get_parameter_type_field_name(self):
    """Type names map to the matching value field name."""
    self.assertEqual('string_value',
                     type_utils.get_parameter_type_field_name('String'))
    self.assertEqual('int_value',
                     type_utils.get_parameter_type_field_name('Integer'))
    self.assertEqual('double_value',
                     type_utils.get_parameter_type_field_name('Float'))