Ejemplo n.º 1
0
  def _build_importer_spec(self) -> ImporterSpec:
    """Builds the ImporterSpec for the wrapped Importer node.

    The spec is populated from three sources:
      * the artifact in the importer's output channel: packed additional
        properties (as metadata) and the artifact's instance schema;
      * the `REIMPORT_OPTION_KEY` exec property: the reimport flag;
      * the `SOURCE_URI_KEY` exec property: the artifact URI, converted with
        `compiler_utils.value_converter`.

    Returns:
      The populated ImporterSpec message.
    """
    assert isinstance(self._node, importer.Importer)
    output_channel = self._node.outputs[importer.IMPORT_RESULT_KEY]
    result = ImporterSpec()

    # Importer's output channel contains one artifact instance with additional
    # properties. Fetch it once through `get_single_instance` and reuse it for
    # both the metadata and the type schema, instead of first indexing `[0]`
    # without any check (IndexError on an empty channel) and then re-reading
    # the channel. NOTE(review): get_single_instance presumably rejects
    # channels that do not hold exactly one artifact; confirm.
    single_artifact = artifact_utils.get_single_instance(
        list(output_channel.get()))

    struct_proto = compiler_utils.pack_artifact_properties(single_artifact)
    if struct_proto:
      result.metadata.CopyFrom(struct_proto)

    result.reimport = bool(self._exec_properties[importer.REIMPORT_OPTION_KEY])
    result.artifact_uri.CopyFrom(
        compiler_utils.value_converter(
            self._exec_properties[importer.SOURCE_URI_KEY]))
    result.type_schema.CopyFrom(
        pipeline_pb2.ArtifactTypeSchema(
            instance_schema=compiler_utils.get_artifact_schema(
                single_artifact)))

    return result
Ejemplo n.º 2
0
  def _build_importer_spec(self) -> ImporterSpec:
    """Assembles the ImporterSpec describing this Importer node.

    Metadata comes from the output channel's additional (custom) properties,
    the reimport flag and artifact URI come from the node's exec properties,
    and the type schema is derived from the output channel's artifact type.

    Returns:
      The populated ImporterSpec message.
    """
    assert isinstance(self._node, importer.Importer)
    output_channel = self._node.outputs[importer.IMPORT_RESULT_KEY]
    result = ImporterSpec()

    # Merge the channel's extra properties into the metadata struct; custom
    # properties are merged second, so they win on key collisions.
    for extra_properties in (output_channel.additional_properties,
                             output_channel.additional_custom_properties):
      if extra_properties:
        result.metadata.update(extra_properties)

    result.reimport = bool(self._exec_properties[importer.REIMPORT_OPTION_KEY])

    # The source URI is meant to be a string, but only a pytype hint enforces
    # that, so a RuntimeParameter can slip through. In that case the runtime
    # parameter is always named 'artifact_uri' (SOURCE_URI_KEY) rather than
    # the name the user gave their parameter.
    source_uri = self._exec_properties[importer.SOURCE_URI_KEY]
    if not isinstance(source_uri, data_types.RuntimeParameter):
      result.artifact_uri.CopyFrom(compiler_utils.value_converter(source_uri))
    else:
      result.artifact_uri.runtime_parameter = importer.SOURCE_URI_KEY

    artifact_schema = compiler_utils.get_artifact_schema(output_channel.type)
    result.type_schema.CopyFrom(
        pipeline_pb2.ArtifactTypeSchema(instance_schema=artifact_schema))
    return result
Ejemplo n.º 3
0
  def test_build_importer_component_spec(self):
    """The importer ComponentSpec has a string 'uri' input and one artifact."""
    expected_comp_spec = pb.ComponentSpec()
    json_format.ParseDict(
        {
            'inputDefinitions': {
                'parameters': {
                    'uri': {'type': 'STRING'}
                }
            },
            'outputDefinitions': {
                'artifacts': {
                    'artifact': {
                        'artifactType': {'schemaTitle': 'system.Artifact'}
                    }
                }
            },
            'executorLabel': 'exec-importer-1',
        },
        expected_comp_spec,
    )

    system_artifact = pb.ArtifactTypeSchema(schema_title='system.Artifact')
    actual_comp_spec = importer_node._build_importer_component_spec(
        importer_base_name='importer-1', artifact_type_schema=system_artifact)

    self.maxDiff = None
    self.assertEqual(expected_comp_spec, actual_comp_spec)
Ejemplo n.º 4
0
    def __init__(self, instance_schema: Optional[str] = None):
        """Constructs an instance of Artifact.

        Sets up self._metadata_fields to perform type checking and initializes
        the underlying RuntimeArtifact proto.

        Args:
          instance_schema: The instance schema string. It must be non-empty
            when the base ``Artifact`` class is instantiated directly, and must
            not be passed when an ``Artifact`` subclass is instantiated.

        Raises:
          ValueError: If ``instance_schema`` is missing for ``Artifact`` itself,
            or is passed for an ``Artifact`` subclass.
        """
        if self.__class__ == Artifact:
            if not instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must be set for Artifact.')
            self._instance_schema = instance_schema
        else:
            if instance_schema:
                # Implicit string concatenation instead of a backslash line
                # continuation inside the literal: the continuation copied the
                # next source line's leading indentation into the message.
                raise ValueError(
                    'The "instance_schema" argument must not be passed for '
                    'Artifact subclass: {}'.format(self.__class__))
            # NOTE(review): this branch does not set self._instance_schema,
            # yet it is read below; subclasses presumably define it as a class
            # attribute. Confirm.

        # setup self._metadata_fields
        self.TYPE_NAME, self._metadata_fields = artifact_utils.parse_schema(
            self._instance_schema)

        # Instantiate a RuntimeArtifact pb message as the POD data structure.
        self._artifact = pipeline_spec_pb2.RuntimeArtifact()

        # Stores the metadata for the Artifact.
        self.metadata = {}

        self._artifact.type.CopyFrom(
            pipeline_spec_pb2.ArtifactTypeSchema(
                instance_schema=self._instance_schema))

        # Set last; presumably guards custom __getattr__/__setattr__ logic
        # defined elsewhere in the class (TODO confirm).
        self._initialized = True
Ejemplo n.º 5
0
    def __init__(self, instance_schema: Optional[str] = None):
        """Constructs an instance of Artifact.

        Args:
          instance_schema: YAML instance schema string.
            * When the base ``Artifact`` class is instantiated directly, it is
              required and must be a mapping with ``title`` and ``properties``
              keys.
            * When an ``Artifact`` subclass is instantiated, it must not be
              passed; the schema comes from ``self._get_artifact_type()``.

        Raises:
          ValueError: If the schema is missing or malformed for ``Artifact``,
            or if it is passed for an ``Artifact`` subclass.
        """
        if self.__class__ == Artifact:
            if not instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must be passed to specify a '
                    'type for this Artifact.')
            # Parse once; both the properties and the title come from it.
            schema_yaml = yaml.safe_load(instance_schema)
            # A non-mapping document (e.g. a bare scalar) used to fail with a
            # TypeError on the membership test; report it as invalid instead.
            if (not isinstance(schema_yaml, dict) or
                    'properties' not in schema_yaml):
                raise ValueError(
                    'Invalid instance_schema, properties must be present. '
                    'Got %s' % instance_schema)
            if 'title' not in schema_yaml:
                raise ValueError(
                    'Invalid instance_schema, title must be present. '
                    'Got %s' % instance_schema)
            self.TYPE_NAME = schema_yaml['title']
            self.PROPERTIES = {
                name: Property.from_dict(spec)
                for name, spec in schema_yaml['properties'].items()
            }
        else:
            if instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must not be passed for '
                    'Artifact subclass %s.' % self.__class__)
            instance_schema = self._get_artifact_type()

        # MLMD artifact type schema string.
        self._type_schema = instance_schema
        # Instantiate a RuntimeArtifact pb message as the POD data structure.
        self._artifact = pipeline_spec_pb2.RuntimeArtifact()
        self._artifact.type.CopyFrom(
            pipeline_spec_pb2.ArtifactTypeSchema(
                instance_schema=instance_schema))
        # Initialization flag to prevent recursive getattr / setattr errors.
        self._initialized = True
Ejemplo n.º 6
0
    def test_build_importer_spec_with_invalid_inputs_should_fail(self):
        """Passing both URI sources, or neither of them, is rejected."""
        expected_error = ('importer spec should be built using either '
                          'pipeline_param_name or constant_value')
        invalid_uri_kwargs = (
            # Both sources at once.
            {'pipeline_param_name': 'param1', 'constant_value': 'some_uri'},
            # No source at all.
            {},
        )
        for uri_kwargs in invalid_uri_kwargs:
            with self.assertRaisesRegex(AssertionError, expected_error):
                importer_node.build_importer_spec(
                    input_type_schema=pb.ArtifactTypeSchema(
                        schema_title='system.Artifact'),
                    **uri_kwargs)
Ejemplo n.º 7
0
def build_input_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentInputsSpec.ArtifactSpec:
  """Returns the input ArtifactSpec for the artifact type of `channel_spec`.

  The spec's instance schema is derived from the channel's artifact type and
  then checked against that type's PROPERTIES via
  `_validate_properties_schema`.
  """
  artifact_type = channel_spec.type
  spec = pipeline_pb2.ComponentInputsSpec.ArtifactSpec()
  type_schema = pipeline_pb2.ArtifactTypeSchema(
      instance_schema=get_artifact_schema(artifact_type))
  spec.artifact_type.CopyFrom(type_schema)
  _validate_properties_schema(
      instance_schema=spec.artifact_type.instance_schema,
      properties=artifact_type.PROPERTIES)
  return spec
Ejemplo n.º 8
0
def get_artifact_type_schema(
    artifact_class_or_type_name: Optional[Union[str,
                                                Type[artifact_types.Artifact]]]
) -> pipeline_spec_pb2.ArtifactTypeSchema:
    """Gets the IR I/O artifact type msg for the given ComponentSpec I/O
    type.

    The argument is resolved as follows:
      * a string matching `_GOOGLE_TYPES_PATTERN` is passed through as the
        schema title, with `_GOOGLE_TYPES_VERSION`;
      * any other string is looked up case-insensitively in
        `_ARTIFACT_CLASSES_MAPPING`;
      * an `artifact_types.Artifact` subclass is used as-is;
      * anything else, including None and unknown names, falls back to
        `artifact_types.Artifact`.
    """
    name_or_class = artifact_class_or_type_name
    if isinstance(name_or_class, str):
        if re.match(_GOOGLE_TYPES_PATTERN, name_or_class):
            return pipeline_spec_pb2.ArtifactTypeSchema(
                schema_title=name_or_class,
                schema_version=_GOOGLE_TYPES_VERSION,
            )
        resolved_class = _ARTIFACT_CLASSES_MAPPING.get(
            name_or_class.lower(), artifact_types.Artifact)
    elif (inspect.isclass(name_or_class) and
          issubclass(name_or_class, artifact_types.Artifact)):
        resolved_class = name_or_class
    else:
        resolved_class = artifact_types.Artifact

    return pipeline_spec_pb2.ArtifactTypeSchema(
        schema_title=resolved_class.TYPE_NAME,
        schema_version=resolved_class.VERSION)
Ejemplo n.º 9
0
def get_artifact_type_schema_message(
        type_name: str) -> pipeline_spec_pb2.ArtifactTypeSchema:
    """Gets the IR I/O artifact type msg for the given ComponentSpec I/O type.

    Non-string inputs map to the base `artifact.Artifact` IR type. Strings are
    looked up case-insensitively in `_ARTIFACT_CLASSES_MAPPING`, defaulting to
    `artifact.Artifact`. Classes whose TYPE_NAME is in the `system.` namespace
    get a title-only schema; all others use their own `get_ir_type()`.
    """
    if not isinstance(type_name, str):
        return artifact.Artifact.get_ir_type()

    resolved_class = _ARTIFACT_CLASSES_MAPPING.get(type_name.lower(),
                                                   artifact.Artifact)
    # TODO: migrate all types to system. namespace.
    if not resolved_class.TYPE_NAME.startswith('system.'):
        return resolved_class.get_ir_type()
    return pipeline_spec_pb2.ArtifactTypeSchema(
        schema_title=resolved_class.TYPE_NAME)
Ejemplo n.º 10
0
def get_artifact_type_schema(
    artifact_class_or_type_name: Optional[Union[str, Type[io_types.Artifact]]]
) -> pipeline_spec_pb2.ArtifactTypeSchema:
    """Gets the IR I/O artifact type msg for the given ComponentSpec I/O type.

    Strings are looked up case-insensitively in `_ARTIFACT_CLASSES_MAPPING`.
    `io_types.Artifact` subclasses are used directly. Everything else
    (None, unrelated classes, unknown names) falls back to `io_types.Artifact`.
    """
    name_or_class = artifact_class_or_type_name
    if isinstance(name_or_class, str):
        resolved_class = _ARTIFACT_CLASSES_MAPPING.get(
            name_or_class.lower(), io_types.Artifact)
    elif (inspect.isclass(name_or_class) and
          issubclass(name_or_class, io_types.Artifact)):
        resolved_class = name_or_class
    else:
        resolved_class = io_types.Artifact

    return pipeline_spec_pb2.ArtifactTypeSchema(
        schema_title=resolved_class.TYPE_NAME)
Ejemplo n.º 11
0
    def test_build_importer_spec_from_pipeline_param(self):
        """A pipeline param name becomes the spec's runtime parameter."""
        expected_spec = pb.PipelineDeploymentConfig.ImporterSpec()
        json_format.ParseDict(
            {
                'artifactUri': {'runtimeParameter': 'param1'},
                'typeSchema': {'schemaTitle': 'system.Artifact'},
            },
            expected_spec,
        )

        system_artifact = pb.ArtifactTypeSchema(schema_title='system.Artifact')
        actual_spec = importer_node.build_importer_spec(
            input_type_schema=system_artifact, pipeline_param_name='param1')

        self.maxDiff = None
        self.assertEqual(expected_spec, actual_spec)
Ejemplo n.º 12
0
    def test_build_importer_spec_from_constant_value(self):
        """A constant URI becomes the spec's constant string value."""
        expected_spec = pb.PipelineDeploymentConfig.ImporterSpec()
        json_format.ParseDict(
            {
                'artifactUri': {
                    'constantValue': {'stringValue': 'some_uri'}
                },
                'typeSchema': {'schemaTitle': 'system.Artifact'},
            },
            expected_spec,
        )

        system_artifact = pb.ArtifactTypeSchema(schema_title='system.Artifact')
        actual_spec = importer_node.build_importer_spec(
            input_type_schema=system_artifact, constant_value='some_uri')

        self.maxDiff = None
        self.assertEqual(expected_spec, actual_spec)
Ejemplo n.º 13
0
def build_output_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentOutputsSpec.ArtifactSpec:
  """Builds artifact type spec for an output channel.

  The type schema is derived from the channel's artifact *type*; no artifact
  instance is read from the channel. The schema is checked against that type's
  PROPERTIES via `_validate_properties_schema`. The channel's additional
  properties, then its additional custom properties, are merged into the
  spec's metadata, so custom properties win on key collisions.

  Args:
    channel_spec: The output channel to describe.

  Returns:
    The populated ComponentOutputsSpec.ArtifactSpec.
  """
  result = pipeline_pb2.ComponentOutputsSpec.ArtifactSpec()
  result.artifact_type.CopyFrom(
      pipeline_pb2.ArtifactTypeSchema(
          instance_schema=get_artifact_schema(channel_spec.type)))

  _validate_properties_schema(
      instance_schema=result.artifact_type.instance_schema,
      properties=channel_spec.type.PROPERTIES)

  if channel_spec.additional_properties:
    result.metadata.update(channel_spec.additional_properties)
  if channel_spec.additional_custom_properties:
    result.metadata.update(channel_spec.additional_custom_properties)
  return result
Ejemplo n.º 14
0
    def setUp(self):
        """Builds the executor invocation under test and loads the fixtures.

        Reads `executor_invocation.json` and `expected_output_metadata.json`
        from the `testdata` directory next to this file, then switches the
        working directory to `self.tmp_dir` and creates the parent directory
        of `_TEST_INPUT_DIR`.
        """
        super().setUp()

        self._executor_invocation = pipeline_pb2.ExecutorInput()
        self._executor_invocation.outputs.output_file = _TEST_OUTPUT_METADATA_JSON
        self._executor_invocation.inputs.parameters[
            'input_base'].string_value = _TEST_INPUT_DIR
        self._executor_invocation.inputs.parameters[
            'output_config'].string_value = '{}'
        self._executor_invocation.inputs.parameters[
            'input_config'].string_value = json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='s1',
                                                pattern='span{SPAN}/split1/*'),
                    example_gen_pb2.Input.Split(name='s2',
                                                pattern='span{SPAN}/split2/*')
                ]))
        self._executor_invocation.outputs.artifacts[
            'examples'].artifacts.append(
                pipeline_pb2.RuntimeArtifact(
                    type=pipeline_pb2.ArtifactTypeSchema(
                        instance_schema=compiler_utils.get_artifact_schema(
                            standard_artifacts.Examples))))

        # Resolved before the working directory changes below.
        testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')

        # Read the fixtures inside `with` blocks so the file handles are closed
        # promptly instead of leaking until garbage collection.
        with fileio.open(
                os.path.join(testdata_dir, 'executor_invocation.json'),
                'r') as invocation_file:
            self._executor_invocation_from_file = invocation_file.read()

        logging.debug('Executor invocation under test: %s',
                      self._executor_invocation_from_file)
        with fileio.open(
                os.path.join(testdata_dir, 'expected_output_metadata.json'),
                'r') as expected_file:
            self._expected_result_from_file = expected_file.read()
        logging.debug('Expecting output metadata JSON: %s',
                      self._expected_result_from_file)

        # Change working directory after all the testdata files have been read.
        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))

        fileio.makedirs(os.path.dirname(_TEST_INPUT_DIR))
Ejemplo n.º 15
0
def build_output_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentOutputsSpec.ArtifactSpec:
    """Builds artifact type spec for an output channel.

    The first artifact currently in the channel is used as the source of the
    schema and metadata. If the channel is empty, a fresh instance of the
    channel's artifact type is created instead.
    """
    existing_artifacts = list(channel_spec.get())
    if existing_artifacts:
        source_artifact = existing_artifacts[0]
    else:
        source_artifact = channel_spec.type()

    spec = pipeline_pb2.ComponentOutputsSpec.ArtifactSpec()
    type_schema = pipeline_pb2.ArtifactTypeSchema(
        instance_schema=get_artifact_schema(source_artifact))
    spec.artifact_type.CopyFrom(type_schema)

    _validate_properties_schema(
        instance_schema=spec.artifact_type.instance_schema,
        properties=channel_spec.type.PROPERTIES)

    packed_properties = pack_artifact_properties(source_artifact)
    if packed_properties:
        spec.metadata.CopyFrom(packed_properties)
    return spec
Ejemplo n.º 16
0
class ImporterNodeTest(parameterized.TestCase):
  """Tests for the importer_node spec-building helpers."""

  @parameterized.parameters(
      # artifact_uri is a constant value
      dict(
          input_uri='gs://artifact',
          artifact_type_schema=pb.ArtifactTypeSchema(
              schema_title='system.Dataset'),
          expected_result={
              'artifactUri': {
                  'constantValue': {'stringValue': 'gs://artifact'}
              },
              'typeSchema': {'schemaTitle': 'system.Dataset'},
          },
      ),
      # artifact_uri is from PipelineParam
      dict(
          input_uri=_pipeline_param.PipelineParam(name='uri_to_import'),
          artifact_type_schema=pb.ArtifactTypeSchema(
              schema_title='system.Model'),
          expected_result={
              'artifactUri': {'runtimeParameter': 'uri'},
              'typeSchema': {'schemaTitle': 'system.Model'},
          },
      ),
  )
  def test_build_importer_spec(self, input_uri, artifact_type_schema,
                               expected_result):
    expected_spec = pb.PipelineDeploymentConfig.ImporterSpec()
    json_format.ParseDict(expected_result, expected_spec)

    actual_spec = importer_node._build_importer_spec(
        artifact_uri=input_uri, artifact_type_schema=artifact_type_schema)

    self.maxDiff = None
    self.assertEqual(expected_spec, actual_spec)

  @parameterized.parameters(
      # artifact_uri is a constant value
      dict(
          importer_name='importer-1',
          input_uri='gs://artifact',
          expected_result={
              'taskInfo': {'name': 'importer-1'},
              'inputs': {
                  'parameters': {
                      'uri': {
                          'runtimeValue': {
                              'constantValue': {'stringValue': 'gs://artifact'}
                          }
                      }
                  }
              },
              'componentRef': {'name': 'comp-importer-1'},
          },
      ),
      # artifact_uri is from PipelineParam
      dict(
          importer_name='importer-2',
          input_uri=_pipeline_param.PipelineParam(name='uri_to_import'),
          expected_result={
              'taskInfo': {'name': 'importer-2'},
              'inputs': {
                  'parameters': {
                      'uri': {'componentInputParameter': 'uri_to_import'}
                  }
              },
              'componentRef': {'name': 'comp-importer-2'},
          },
      ),
  )
  def test_build_importer_task_spec(self, importer_name, input_uri,
                                    expected_result):
    expected_task_spec = pb.PipelineTaskSpec()
    json_format.ParseDict(expected_result, expected_task_spec)

    actual_task_spec = importer_node._build_importer_task_spec(
        importer_base_name=importer_name, artifact_uri=input_uri)

    self.maxDiff = None
    self.assertEqual(expected_task_spec, actual_task_spec)

  def test_build_importer_component_spec(self):
    expected_comp_spec = pb.ComponentSpec()
    json_format.ParseDict(
        {
            'inputDefinitions': {
                'parameters': {'uri': {'type': 'STRING'}}
            },
            'outputDefinitions': {
                'artifacts': {
                    'artifact': {
                        'artifactType': {'schemaTitle': 'system.Artifact'}
                    }
                }
            },
            'executorLabel': 'exec-importer-1',
        },
        expected_comp_spec,
    )

    system_artifact = pb.ArtifactTypeSchema(schema_title='system.Artifact')
    actual_comp_spec = importer_node._build_importer_component_spec(
        importer_base_name='importer-1', artifact_type_schema=system_artifact)

    self.maxDiff = None
    self.assertEqual(expected_comp_spec, actual_comp_spec)

  def test_import_with_invalid_artifact_uri_value_should_fail(self):
    from kfp.dsl.io_types import Dataset
    expected_message = (
        "Importer got unexpected artifact_uri: 123 of type: <class 'int'>.")
    with self.assertRaisesRegex(ValueError, expected_message):
      importer_node.importer(artifact_uri=123, artifact_class=Dataset)
Ejemplo n.º 17
0
 def get_ir_type(cls) -> pipeline_spec_pb2.ArtifactTypeSchema:
     """Wraps this artifact class's type schema in an ArtifactTypeSchema."""
     type_schema = cls.get_artifact_type()
     return pipeline_spec_pb2.ArtifactTypeSchema(instance_schema=type_schema)
Ejemplo n.º 18
0
class TypeUtilsTest(parameterized.TestCase):
    """Tests for kfp type_utils helpers."""

    @parameterized.parameters(
        [(item, True) for item in _PARAMETER_TYPES] +
        [(item, False)
         for item in _KNOWN_ARTIFACT_TYPES + _UNKNOWN_ARTIFACT_TYPES])
    def test_is_parameter_type_true(self, type_name, expected_result):
        self.assertEqual(expected_result,
                         type_utils.is_parameter_type(type_name))

    # Rows are (class or type name, expected schema_title, expected
    # schema_version), in the same order as the original explicit dicts.
    @parameterized.parameters([
        {
            'artifact_class_or_type_name': class_or_type_name,
            'expected_result': pb.ArtifactTypeSchema(
                schema_title=schema_title, schema_version=schema_version),
        } for class_or_type_name, schema_title, schema_version in [
            ('Model', 'system.Model', '0.0.1'),
            (artifact_types.Model, 'system.Model', '0.0.1'),
            ('Dataset', 'system.Dataset', '0.0.1'),
            (artifact_types.Dataset, 'system.Dataset', '0.0.1'),
            ('Metrics', 'system.Metrics', '0.0.1'),
            (artifact_types.Metrics, 'system.Metrics', '0.0.1'),
            ('ClassificationMetrics', 'system.ClassificationMetrics',
             '0.0.1'),
            (artifact_types.ClassificationMetrics,
             'system.ClassificationMetrics', '0.0.1'),
            ('SlicedClassificationMetrics',
             'system.SlicedClassificationMetrics', '0.0.1'),
            (artifact_types.SlicedClassificationMetrics,
             'system.SlicedClassificationMetrics', '0.0.1'),
            ('arbitrary name', 'system.Artifact', '0.0.1'),
            (_ArbitraryClass, 'system.Artifact', '0.0.1'),
            (artifact_types.HTML, 'system.HTML', '0.0.1'),
            (artifact_types.Markdown, 'system.Markdown', '0.0.1'),
            ('some-google-type', 'system.Artifact', '0.0.1'),
            ('google.VertexModel', 'google.VertexModel', '0.0.1'),
            (_VertexDummy, 'google.VertexDummy', '0.0.2'),
        ]
    ])
    def test_get_artifact_type_schema(self, artifact_class_or_type_name,
                                      expected_result):
        self.assertEqual(
            expected_result,
            type_utils.get_artifact_type_schema(artifact_class_or_type_name))

    # Rows are (given type, expected pb.ParameterType), in the original order.
    @parameterized.parameters([
        {'given_type': given_type, 'expected_type': expected_type}
        for given_type, expected_type in [
            ('Int', pb.ParameterType.NUMBER_INTEGER),
            ('Integer', pb.ParameterType.NUMBER_INTEGER),
            (int, pb.ParameterType.NUMBER_INTEGER),
            ('Double', pb.ParameterType.NUMBER_DOUBLE),
            ('Float', pb.ParameterType.NUMBER_DOUBLE),
            (float, pb.ParameterType.NUMBER_DOUBLE),
            ('String', pb.ParameterType.STRING),
            ('Text', pb.ParameterType.STRING),
            (str, pb.ParameterType.STRING),
            ('Boolean', pb.ParameterType.BOOLEAN),
            (bool, pb.ParameterType.BOOLEAN),
            ('Dict', pb.ParameterType.STRUCT),
            (dict, pb.ParameterType.STRUCT),
            ('List', pb.ParameterType.LIST),
            (list, pb.ParameterType.LIST),
            (Dict[str, int], pb.ParameterType.STRUCT),
            (List[Any], pb.ParameterType.LIST),
            ({
                'JsonObject': {
                    'data_type': 'proto:tfx.components.trainer.TrainArgs'
                }
            }, pb.ParameterType.STRUCT),
        ]
    ])
    def test_get_parameter_type(self, given_type, expected_type):
        self.assertEqual(expected_type,
                         type_utils.get_parameter_type(given_type))

        # Test get parameter by Python type.
        self.assertEqual(pb.ParameterType.NUMBER_INTEGER,
                         type_utils.get_parameter_type(int))

    def test_get_parameter_type_invalid(self):
        with self.assertRaises(AttributeError):
            type_utils.get_parameter_type_schema(None)

    def test_get_input_artifact_type_schema(self):
        input_specs = [
            v1_structures.InputSpec(name='input1', type='String'),
            v1_structures.InputSpec(name='input2', type='Model'),
            v1_structures.InputSpec(name='input3', type=None),
        ]
        # The message checks must sit *after* the `with` blocks: statements
        # following the raising call inside the block never run. The raised
        # error is `cm.exception`; `str(cm)` is only the context manager's
        # repr.

        # input not found.
        with self.assertRaises(AssertionError) as cm:
            type_utils.get_input_artifact_type_schema('input0', input_specs)
        self.assertEqual('Input not found.', str(cm.exception))

        # input found, but it doesn't map to an artifact type.
        with self.assertRaises(AssertionError) as cm:
            type_utils.get_input_artifact_type_schema('input1', input_specs)
        self.assertEqual('Input is not an artifact type.', str(cm.exception))

        # input found, and a matching artifact type schema returned.
        self.assertEqual(
            'system.Model',
            type_utils.get_input_artifact_type_schema(
                'input2', input_specs).schema_title)

        # input found, and the default artifact type schema returned.
        self.assertEqual(
            'system.Artifact',
            type_utils.get_input_artifact_type_schema(
                'input3', input_specs).schema_title)

    @parameterized.parameters(
        {
            'given_type': 'String',
            'expected_type': 'String',
            'is_compatible': True,
        },
        {
            'given_type': 'String',
            'expected_type': 'Integer',
            'is_compatible': False,
        },
        {
            'given_type': {
                'type_a': {
                    'property': 'property_b',
                }
            },
            'expected_type': {
                'type_a': {
                    'property': 'property_b',
                }
            },
            'is_compatible': True,
        },
        {
            'given_type': {
                'type_a': {
                    'property': 'property_b',
                }
            },
            'expected_type': {
                'type_a': {
                    'property': 'property_c',
                }
            },
            'is_compatible': False,
        },
        {
            'given_type': 'Artifact',
            'expected_type': 'Model',
            'is_compatible': True,
        },
        {
            'given_type': 'Metrics',
            'expected_type': 'Artifact',
            'is_compatible': True,
        },
    )
    def test_verify_type_compatibility(
        self,
        given_type: Union[str, dict],
        expected_type: Union[str, dict],
        is_compatible: bool,
    ):
        if is_compatible:
            self.assertTrue(
                type_utils.verify_type_compatibility(
                    given_type=given_type,
                    expected_type=expected_type,
                    error_message_prefix='',
                ))
        else:
            with self.assertRaises(InconsistentTypeException):
                type_utils.verify_type_compatibility(
                    given_type=given_type,
                    expected_type=expected_type,
                    error_message_prefix='',
                )

    @parameterized.parameters(
        {
            'given_type': str,
            'expected_type_name': 'String',
        },
        {
            'given_type': int,
            'expected_type_name': 'Integer',
        },
        {
            'given_type': float,
            'expected_type_name': 'Float',
        },
        {
            'given_type': bool,
            'expected_type_name': 'Boolean',
        },
        {
            'given_type': list,
            'expected_type_name': 'List',
        },
        {
            'given_type': dict,
            'expected_type_name': 'Dict',
        },
        {
            'given_type': Any,
            'expected_type_name': None,
        },
    )
    def test_get_canonical_type_name_for_type(
        self,
        given_type,
        expected_type_name,
    ):
        self.assertEqual(
            expected_type_name,
            type_utils.get_canonical_type_name_for_type(given_type))

    @parameterized.parameters(
        {
            'given_type': 'PipelineTaskFinalStatus',
            'expected_result': True,
        },
        {
            'given_type': 'pipelineTaskFinalstatus',
            'expected_result': False,
        },
        {
            'given_type': int,
            'expected_result': False,
        },
    )
    def test_is_task_final_statu_type(self, given_type, expected_result):
        self.assertEqual(expected_result,
                         type_utils.is_task_final_status_type(given_type))
Ejemplo n.º 19
0
class TypeUtilsTest(parameterized.TestCase):
    """Tests for the helpers in ``type_utils``."""

    def test_is_parameter_type(self):
        """Parameter type names are accepted; artifact type names are not."""
        for type_name in _PARAMETER_TYPES:
            self.assertTrue(type_utils.is_parameter_type(type_name))
        for type_name in _KNOWN_ARTIFACT_TYPES + _UNKNOWN_ARTIFACT_TYPES:
            self.assertFalse(type_utils.is_parameter_type(type_name))

    @parameterized.parameters(
        {
            'artifact_class_or_type_name': 'Model',
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Model')
        },
        {
            'artifact_class_or_type_name': io_types.Model,
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Model')
        },
        {
            'artifact_class_or_type_name': 'Dataset',
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Dataset')
        },
        {
            'artifact_class_or_type_name': io_types.Dataset,
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Dataset')
        },
        {
            'artifact_class_or_type_name': 'Metrics',
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Metrics')
        },
        {
            'artifact_class_or_type_name': io_types.Metrics,
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Metrics')
        },
        {
            'artifact_class_or_type_name':
            'ClassificationMetrics',
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.ClassificationMetrics')
        },
        {
            'artifact_class_or_type_name':
            io_types.ClassificationMetrics,
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.ClassificationMetrics')
        },
        {
            'artifact_class_or_type_name':
            'SlicedClassificationMetrics',
            'expected_result':
            pb.ArtifactTypeSchema(
                schema_title='system.SlicedClassificationMetrics')
        },
        {
            'artifact_class_or_type_name':
            io_types.SlicedClassificationMetrics,
            'expected_result':
            pb.ArtifactTypeSchema(
                schema_title='system.SlicedClassificationMetrics')
        },
        {
            'artifact_class_or_type_name':
            'arbitrary name',
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Artifact')
        },
        {
            'artifact_class_or_type_name':
            _ArbitraryClass,
            'expected_result':
            pb.ArtifactTypeSchema(schema_title='system.Artifact')
        },
    )
    def test_get_artifact_type_schema(self, artifact_class_or_type_name,
                                      expected_result):
        """Checks the schema produced for artifact type names and classes.

        Unrecognized names and classes are expected to fall back to
        ``system.Artifact``.
        """
        self.assertEqual(
            expected_result,
            type_utils.get_artifact_type_schema(artifact_class_or_type_name))

    @parameterized.parameters(
        {
            'given_type': 'Int',
            'expected_type': pb.PrimitiveType.INT,
        },
        {
            'given_type': 'Integer',
            'expected_type': pb.PrimitiveType.INT,
        },
        {
            'given_type': int,
            'expected_type': pb.PrimitiveType.INT,
        },
        {
            'given_type': 'Double',
            'expected_type': pb.PrimitiveType.DOUBLE,
        },
        {
            'given_type': 'Float',
            'expected_type': pb.PrimitiveType.DOUBLE,
        },
        {
            'given_type': float,
            'expected_type': pb.PrimitiveType.DOUBLE,
        },
        {
            'given_type': 'String',
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': 'Text',
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': str,
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': 'Boolean',
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': bool,
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': 'Dict',
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': dict,
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': 'List',
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': list,
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': Dict[str, int],
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': List[Any],
            'expected_type': pb.PrimitiveType.STRING,
        },
        {
            'given_type': {
                'JsonObject': {
                    'data_type': 'proto:tfx.components.trainer.TrainArgs'
                }
            },
            'expected_type': pb.PrimitiveType.STRING,
        },
    )
    def test_get_parameter_type(self, given_type, expected_type):
        """Checks the primitive type chosen for type names and Python types.

        Booleans, dicts, lists, generic aliases and JsonObject specs are all
        expected to map to STRING.
        """
        # The former trailing `get_parameter_type(int)` check duplicated the
        # `int` parameter case and ran once per case, so it was dropped.
        self.assertEqual(expected_type,
                         type_utils.get_parameter_type(given_type))

    def test_get_parameter_type_invalid(self):
        """``None`` is rejected by the schema helper with AttributeError."""
        with self.assertRaises(AttributeError):
            type_utils.get_parameter_type_schema(None)

    def test_get_input_artifact_type_schema(self):
        """Checks input-artifact schema lookup, including its failure modes."""
        input_specs = [
            structures.InputSpec(name='input1', type='String'),
            structures.InputSpec(name='input2', type='Model'),
            structures.InputSpec(name='input3', type=None),
        ]
        # Message checks live outside the `with` blocks: any statement after
        # the raising call inside the block would never run. The message is
        # read from `cm.exception`; `str(cm)` is the context manager's repr.

        # input not found.
        with self.assertRaises(AssertionError) as cm:
            type_utils.get_input_artifact_type_schema('input0', input_specs)
        self.assertEqual('Input not found.', str(cm.exception))

        # input found, but it doesn't map to an artifact type.
        with self.assertRaises(AssertionError) as cm:
            type_utils.get_input_artifact_type_schema('input1', input_specs)
        self.assertEqual('Input is not an artifact type.', str(cm.exception))

        # input found, and a matching artifact type schema returned.
        self.assertEqual(
            'system.Model',
            type_utils.get_input_artifact_type_schema(
                'input2', input_specs).schema_title)

        # input found, and the default artifact type schema returned.
        self.assertEqual(
            'system.Artifact',
            type_utils.get_input_artifact_type_schema(
                'input3', input_specs).schema_title)

    def test_get_parameter_type_field_name(self):
        """Checks the value field name used for each parameter type name."""
        self.assertEqual('string_value',
                         type_utils.get_parameter_type_field_name('String'))
        self.assertEqual('int_value',
                         type_utils.get_parameter_type_field_name('Integer'))
        self.assertEqual('double_value',
                         type_utils.get_parameter_type_field_name('Float'))
Ejemplo n.º 20
0
class TypeUtilsTest(parameterized.TestCase):
    def test_is_parameter_type(self):
        """Parameter type names are accepted; artifact type names are not."""
        for name in _PARAMETER_TYPES:
            self.assertTrue(type_utils.is_parameter_type(name))
        non_parameter_names = _KNOWN_ARTIFACT_TYPES + _UNKNOWN_ARTIFACT_TYPES
        for name in non_parameter_names:
            self.assertFalse(type_utils.is_parameter_type(name))

    @parameterized.parameters(
        ('Model', pb.ArtifactTypeSchema(schema_title='system.Model')),
        (io_types.Model, pb.ArtifactTypeSchema(schema_title='system.Model')),
        ('Dataset', pb.ArtifactTypeSchema(schema_title='system.Dataset')),
        (io_types.Dataset,
         pb.ArtifactTypeSchema(schema_title='system.Dataset')),
        ('Metrics', pb.ArtifactTypeSchema(schema_title='system.Metrics')),
        (io_types.Metrics,
         pb.ArtifactTypeSchema(schema_title='system.Metrics')),
        ('ClassificationMetrics',
         pb.ArtifactTypeSchema(schema_title='system.ClassificationMetrics')),
        (io_types.ClassificationMetrics,
         pb.ArtifactTypeSchema(schema_title='system.ClassificationMetrics')),
        ('SlicedClassificationMetrics',
         pb.ArtifactTypeSchema(
             schema_title='system.SlicedClassificationMetrics')),
        (io_types.SlicedClassificationMetrics,
         pb.ArtifactTypeSchema(
             schema_title='system.SlicedClassificationMetrics')),
        ('arbitrary name',
         pb.ArtifactTypeSchema(schema_title='system.Artifact')),
        (_ArbitraryClass,
         pb.ArtifactTypeSchema(schema_title='system.Artifact')),
    )
    def test_get_artifact_type_schema(self, artifact_class_or_type_name,
                                      expected_result):
        """Checks the schema produced for artifact type names and classes.

        Each case is ``(name_or_class, expected_schema)``; names and classes
        that are not recognized are expected to yield ``system.Artifact``.
        """
        schema = type_utils.get_artifact_type_schema(
            artifact_class_or_type_name)
        self.assertEqual(expected_result, schema)

    def test_get_parameter_type(self):
        """Checks the primitive type chosen for type names and Python types."""
        # Lookup by type name first, then by Python type.
        expectations = [
            ('Int', pb.PrimitiveType.INT),
            ('Integer', pb.PrimitiveType.INT),
            ('Double', pb.PrimitiveType.DOUBLE),
            ('Float', pb.PrimitiveType.DOUBLE),
            ('String', pb.PrimitiveType.STRING),
            ('Str', pb.PrimitiveType.STRING),
            (int, pb.PrimitiveType.INT),
            (float, pb.PrimitiveType.DOUBLE),
            (str, pb.PrimitiveType.STRING),
        ]
        for given, expected in expectations:
            self.assertEqual(expected, type_utils.get_parameter_type(given))

        # The schema helper is expected to raise AttributeError on None.
        with self.assertRaises(AttributeError):
            type_utils.get_parameter_type_schema(None)

        # bool is expected to be rejected as a parameter type here.
        with self.assertRaisesRegex(TypeError, 'Got illegal parameter type.'):
            type_utils.get_parameter_type(bool)

    def test_get_input_artifact_type_schema(self):
        """Checks input-artifact schema lookup, including its failure modes."""
        input_specs = [
            structures.InputSpec(name='input1', type='String'),
            structures.InputSpec(name='input2', type='Model'),
            structures.InputSpec(name='input3', type=None),
        ]
        # Message checks live outside the `with` blocks: any statement after
        # the raising call inside the block would never run. The message is
        # read from `cm.exception`; `str(cm)` is the context manager's repr.

        # input not found.
        with self.assertRaises(AssertionError) as cm:
            type_utils.get_input_artifact_type_schema('input0', input_specs)
        self.assertEqual('Input not found.', str(cm.exception))

        # input found, but it doesn't map to an artifact type.
        with self.assertRaises(AssertionError) as cm:
            type_utils.get_input_artifact_type_schema('input1', input_specs)
        self.assertEqual('Input is not an artifact type.', str(cm.exception))

        # input found, and a matching artifact type schema returned.
        self.assertEqual(
            'system.Model',
            type_utils.get_input_artifact_type_schema(
                'input2', input_specs).schema_title)

        # input found, and the default artifact type schema returned.
        self.assertEqual(
            'system.Artifact',
            type_utils.get_input_artifact_type_schema(
                'input3', input_specs).schema_title)

    def test_get_parameter_type_field_name(self):
        self.assertEqual('string_value',
                         type_utils.get_parameter_type_field_name('String'))
        self.assertEqual('int_value',
                         type_utils.get_parameter_type_field_name('Integer'))
        self.assertEqual('double_value',
                         type_utils.get_parameter_type_field_name('Float'))