# Example 1
    def __init__(self, instance_schema: Optional[str] = None):
        """Constructs an instance of Artifact.

        Parses the schema to populate self.TYPE_NAME and
        self._metadata_fields, then builds the backing RuntimeArtifact
        proto message.

        Args:
            instance_schema: Schema string. Required when instantiating the
                Artifact base class directly; must be omitted for subclasses.

        Raises:
            ValueError: If instance_schema is missing for the base class or
                supplied for a subclass.
        """
        if self.__class__ == Artifact:
            # Base class: the caller must supply a schema explicitly.
            if not instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must be set for Artifact.')
            self._instance_schema = instance_schema
        elif instance_schema:
            # Subclasses carry their own schema; passing one is an error.
            raise ValueError(
                'The "instance_schema" argument must not be passed for Artifact \
               subclass: {}'.format(self.__class__))
        # NOTE(review): subclasses are assumed to provide _instance_schema
        # elsewhere (e.g. as a class attribute) — confirm, otherwise the
        # parse_schema call below raises AttributeError for subclasses.

        # Derive the type name and metadata field spec from the schema.
        self.TYPE_NAME, self._metadata_fields = artifact_utils.parse_schema(
            self._instance_schema)

        # Backing proto (POD data structure) for this artifact.
        self._artifact = pipeline_spec_pb2.RuntimeArtifact()

        # Metadata dict exposed to users of this Artifact.
        self.metadata = {}

        self._artifact.type.CopyFrom(
            pipeline_spec_pb2.ArtifactTypeSchema(
                instance_schema=self._instance_schema))

        # Guard flag, set last (presumably consulted by a custom
        # __getattr__/__setattr__ to avoid recursion during init).
        self._initialized = True
# Example 2
    def __init__(self, instance_schema: Optional[str] = None):
        """Constructs an instance of Artifact.

        Args:
            instance_schema: YAML schema string containing at least 'title'
                and 'properties'. Required when instantiating the Artifact
                base class directly; must not be passed for subclasses, which
                supply their own schema via self._get_artifact_type().

        Raises:
            ValueError: If instance_schema is missing for the base class,
                lacks a 'properties' section, or is passed to a subclass.
        """
        if self.__class__ == Artifact:
            if not instance_schema:
                raise ValueError(
                    'The "instance_schema" argument must be passed to specify a '
                    'type for this Artifact.')
            schema_yaml = yaml.safe_load(instance_schema)
            if 'properties' not in schema_yaml:
                raise ValueError(
                    'Invalid instance_schema, properties must be present. '
                    'Got %s' % instance_schema)
            schema = schema_yaml['properties']
            # Reuse the already-parsed YAML instead of parsing it a second
            # time just to read the title.
            self.TYPE_NAME = schema_yaml['title']
            self.PROPERTIES = {}
            for k, v in schema.items():
                self.PROPERTIES[k] = Property.from_dict(v)
        else:
            if instance_schema:
                # Fixed: this error previously referred to a
                # "mlmd_artifact_type" argument, which is not the name of
                # this parameter.
                raise ValueError(
                    'The "instance_schema" argument must not be passed for '
                    'Artifact subclass %s.' % self.__class__)
            instance_schema = self._get_artifact_type()

        # MLMD artifact type schema string.
        self._type_schema = instance_schema
        # Instantiate a RuntimeArtifact pb message as the POD data structure.
        self._artifact = pipeline_spec_pb2.RuntimeArtifact()
        self._artifact.type.CopyFrom(
            pipeline_spec_pb2.ArtifactTypeSchema(
                instance_schema=instance_schema))
        # Initialization flag to prevent recursive getattr / setattr errors.
        self._initialized = True
# Example 3
    def setUp(self):
        """Builds a canned ExecutorInput and loads testdata fixtures."""
        super().setUp()

        invocation = pipeline_pb2.ExecutorInput()
        invocation.outputs.output_file = _TEST_OUTPUT_METADATA_JSON

        # Populate the input parameters expected by the executor under test.
        params = invocation.inputs.parameters
        params['input_base'].string_value = _TEST_INPUT_DIR
        params['output_config'].string_value = '{}'
        splits = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN}/split1/*'),
            example_gen_pb2.Input.Split(name='s2',
                                        pattern='span{SPAN}/split2/*'),
        ]
        params['input_config'].string_value = json_format.MessageToJson(
            example_gen_pb2.Input(splits=splits))

        # Attach a single Examples output artifact with its type schema.
        invocation.outputs.artifacts['examples'].artifacts.append(
            pipeline_pb2.RuntimeArtifact(
                type=pipeline_pb2.ArtifactTypeSchema(
                    instance_schema=compiler_utils.get_artifact_schema(
                        standard_artifacts.Examples))))
        self._executor_invocation = invocation

        # Read fixtures relative to this file, before the cwd change below.
        testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
        self._executor_invocation_from_file = fileio.open(
            os.path.join(testdata_dir, 'executor_invocation.json'),
            'r').read()
        logging.debug('Executor invocation under test: %s',
                      self._executor_invocation_from_file)

        self._expected_result_from_file = fileio.open(
            os.path.join(testdata_dir, 'expected_output_metadata.json'),
            'r').read()
        logging.debug('Expecting output metadata JSON: %s',
                      self._expected_result_from_file)

        # Change working directory after all the testdata files have been read.
        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))

        fileio.makedirs(os.path.dirname(_TEST_INPUT_DIR))
def to_runtime_artifact(
        artifact_instance: artifact.Artifact,
        name_from_id: Mapping[int, str]) -> pipeline_pb2.RuntimeArtifact:
    """Converts TFX artifact instance to RuntimeArtifact proto message."""
    metadata_struct = struct_pb2.Struct()
    json_format.ParseDict(
        _get_json_metadata_mapping(artifact_instance), metadata_struct)
    runtime_artifact = pipeline_pb2.RuntimeArtifact(
        uri=artifact_instance.uri, metadata=metadata_struct)
    # TODO(b/135056715): Change to a unified getter/setter of Artifact type
    # once it's ready.
    # Try convert tfx artifact id to string-typed name. This should be the
    # case when running on an environment where metadata access layer is not
    # running in user space.
    artifact_id = getattr(artifact_instance, 'id', None)
    if artifact_id is None or artifact_id not in name_from_id:
        logging.warning(
            'Cannot convert ID back to runtime name for artifact %s',
            artifact_instance)
    else:
        runtime_artifact.name = name_from_id[artifact_id]
    return runtime_artifact
# Example 5
 def deserialize(cls, data: str) -> Any:
   """Deserializes an Artifact object from JSON dict.

   Args:
     data: JSON-serialized RuntimeArtifact proto message.

   Returns:
     An instance of the matching first-party artifact class when it can be
     imported, otherwise a generic Artifact built from the instance schema.
   """
   artifact = pipeline_spec_pb2.RuntimeArtifact()
   json_format.Parse(data, artifact, ignore_unknown_fields=True)
   instance_schema = yaml.safe_load(artifact.type.instance_schema)
   # Schema titles are namespaced as 'kfp.<TypeName>'; strip the prefix to
   # recover the ontology class name.
   type_name = instance_schema['title'][len('kfp.'):]
   result = None
   try:
     artifact_cls = getattr(
         importlib.import_module(_KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
     # TODO(numerology): Add deserialization tests for first party classes.
     result = artifact_cls()
   except (AttributeError, ImportError, ValueError):
     logging.warning((
         'Could not load artifact class %s.%s; using fallback deserialization '
         'for the relevant artifact. Please make sure that any artifact '
         'classes can be imported within your container or environment.'),
         _KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
   if result is None:
     # Compare against None explicitly: a successfully constructed artifact
     # whose class overrides __bool__/__len__ must not be mistaken for a
     # failed load (the previous 'if not result' truthiness check could).
     result = Artifact(instance_schema=artifact.type.instance_schema)
   result.runtime_artifact = artifact
   return result
# Example 6
 def deserialize(cls, data: str) -> Any:
     """Deserializes an Artifact object from JSON dict."""
     # Parse the serialized proto, tolerating fields added by newer versions.
     runtime_artifact = pipeline_spec_pb2.RuntimeArtifact()
     json_format.Parse(data, runtime_artifact, ignore_unknown_fields=True)
     # Delegate construction of the typed artifact wrapper to the class hook.
     return cls.get_from_runtime_artifact(runtime_artifact)