예제 #1
0
  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> base_resolver.ResolveResult:
    """Resolves artifacts for each source channel within the pipeline context.

    For every channel, the qualified artifacts recorded under the pipeline
    context are queried, deserialized, and handed to `self._resolve`. A key is
    reported as resolved when `self._resolve` returned at least
    `self._desired_num_of_artifact` artifacts for it.

    Args:
      pipeline_info: Identifies the pipeline whose context is searched.
      metadata_handler: Metadata handle used to look up the context and the
        qualified artifacts.
      source_channels: Mapping from key to the channel to resolve.

    Returns:
      A ResolveResult holding the per-key artifacts and per-key resolve state.

    Raises:
      RuntimeError: If no pipeline context exists for `pipeline_info`.
    """
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # Name the pipeline that was looked up; the context itself is None here,
      # so formatting it would only ever print "None".
      raise RuntimeError('Pipeline context absent for %s' % (pipeline_info,))

    candidate_dict = {}
    for k, c in source_channels.items():
      candidate_artifacts = metadata_handler.get_qualified_artifacts(
          contexts=[pipeline_context],
          type_name=c.type_name,
          producer_component_id=c.producer_component_id,
          output_key=c.output_key)
      candidate_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in candidate_artifacts
      ]

    resolved_dict = self._resolve(candidate_dict)
    resolve_state_dict = {
        k: len(artifact_list) >= self._desired_num_of_artifact
        for k, artifact_list in resolved_dict.items()
    }

    return base_resolver.ResolveResult(
        per_key_resolve_result=resolved_dict,
        per_key_resolve_state=resolve_state_dict)
예제 #2
0
  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> resolver.ResolveResult:
    """Resolves a Model channel and a ModelBlessing channel together.

    Exactly two channels are expected: one whose type is a subclass of
    `standard_artifacts.Model` and one whose type is a subclass of
    `standard_artifacts.ModelBlessing`. The qualified artifacts of both
    channels within the pipeline context are deserialized and handed to
    `self._resolve`; a key counts as resolved when its resulting list is
    non-empty.

    Args:
      pipeline_info: Identifies the pipeline whose context is searched.
      metadata_handler: Metadata handle used to look up the context and the
        qualified artifacts.
      source_channels: Mapping from key to channel; must hold one Model and one
        ModelBlessing channel.

    Returns:
      A ResolveResult holding the per-key artifacts and per-key resolve state.

    Raises:
      AssertionError: If the channels are not exactly one Model and one
        ModelBlessing channel.
      RuntimeError: If a channel has an unexpected type, or if no pipeline
        context exists for `pipeline_info`.
    """
    # First, checks whether we have exactly Model and ModelBlessing Channels.
    model_channel_key = None
    model_blessing_channel_key = None
    assert len(source_channels) == 2, 'Expecting 2 input Channels'
    for k, c in source_channels.items():
      if issubclass(c.type, standard_artifacts.Model):
        model_channel_key = k
      elif issubclass(c.type, standard_artifacts.ModelBlessing):
        model_blessing_channel_key = k
      else:
        raise RuntimeError('Only expecting Model or ModelBlessing, got %s' %
                           c.type)
    assert model_channel_key is not None, 'Expecting Model as input'
    assert model_blessing_channel_key is not None, ('Expecting ModelBlessing as'
                                                    ' input')

    # Gets the pipeline context as the artifact search space.
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # Name the pipeline that was looked up; the context itself is None here,
      # so formatting it would only ever print "None".
      raise RuntimeError('Pipeline context absent for %s' % (pipeline_info,))

    # Gets all Model artifacts, then all ModelBlessing artifacts, in the search
    # space. No ordering is applied here; the candidates are passed as returned.
    candidate_dict = {}
    for key in (model_channel_key, model_blessing_channel_key):
      channel = source_channels[key]
      qualified_artifacts = metadata_handler.get_qualified_artifacts(
          contexts=[pipeline_context],
          type_name=channel.type_name,
          producer_component_id=channel.producer_component_id,
          output_key=channel.output_key)
      candidate_dict[key] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in qualified_artifacts
      ]

    resolved_dict = self._resolve(candidate_dict, model_channel_key,
                                  model_blessing_channel_key)
    resolve_state_dict = {
        k: bool(artifact_list) for k, artifact_list in resolved_dict.items()
    }

    return resolver.ResolveResult(
        per_key_resolve_result=resolved_dict,
        per_key_resolve_state=resolve_state_dict)
예제 #3
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolves the highest-id published artifacts for each channel.

        All published artifacts of the channels' types are fetched from the
        pipeline context in one call. For each channel they are ordered by id,
        descending, and at most `self._desired_num_of_artifact` of them are
        kept. A key is reported as resolved only when at least that many
        artifacts were available.

        Args:
            pipeline_info: Identifies the pipeline whose context is searched.
            metadata_handler: Metadata handle used to look up the context and
                its published artifacts.
            source_channels: Mapping from key to the channel to resolve.

        Returns:
            A ResolveResult holding the per-key artifacts and per-key resolve
            state.

        Raises:
            RuntimeError: If no pipeline context exists for `pipeline_info`.
        """
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # Name the pipeline that was looked up; the context itself is None
            # here, so formatting it would only ever print "None".
            raise RuntimeError('Pipeline context absent for %s' %
                               (pipeline_info,))
        artifacts_in_context = metadata_handler.get_published_artifacts_by_type_within_context(
            [c.type_name for c in source_channels.values()],
            pipeline_context.id)

        artifacts_dict = {}
        resolve_state_dict = {}
        for k, c in source_channels.items():
            previous_artifacts = sorted(artifacts_in_context[c.type_name],
                                        key=lambda m: m.id,
                                        reverse=True)
            desired = self._desired_num_of_artifact
            # Slicing already yields every artifact when fewer than `desired`
            # exist, so one code path covers both the full and partial cases.
            artifacts_dict[k] = [
                _generate_tfx_artifact(a, c.type)
                for a in previous_artifacts[:desired]
            ]
            resolve_state_dict[k] = len(previous_artifacts) >= desired

        return base_resolver.ResolveResult(
            per_key_resolve_result=artifacts_dict,
            per_key_resolve_state=resolve_state_dict)
예제 #4
0
  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> base_resolver.ResolveResult:
    """Resolves the highest-id qualified artifacts for each channel.

    For each channel, the qualified artifacts within the pipeline context are
    ordered by artifact id, descending, and at most
    `self._desired_num_of_artifact` of them are deserialized and kept. A key is
    reported as resolved only when at least that many artifacts were available.

    Args:
      pipeline_info: Identifies the pipeline whose context is searched.
      metadata_handler: Metadata handle used to look up the context and the
        qualified artifacts.
      source_channels: Mapping from key to the channel to resolve.

    Returns:
      A ResolveResult holding the per-key artifacts and per-key resolve state.

    Raises:
      RuntimeError: If no pipeline context exists for `pipeline_info`.
    """
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # Name the pipeline that was looked up; the context itself is None here,
      # so formatting it would only ever print "None".
      raise RuntimeError('Pipeline context absent for %s' % (pipeline_info,))

    artifacts_dict = {}
    resolve_state_dict = {}
    for k, c in source_channels.items():
      candidate_artifacts = metadata_handler.get_qualified_artifacts(
          context=pipeline_context,
          type_name=c.type_name,
          producer_component_id=c.producer_component_id,
          output_key=c.output_key)
      previous_artifacts = sorted(
          candidate_artifacts, key=lambda a: a.artifact.id, reverse=True)
      desired = self._desired_num_of_artifact
      # Slicing already yields every artifact when fewer than `desired` exist,
      # so one code path covers both the full and partial cases.
      artifacts_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in previous_artifacts[:desired]
      ]
      resolve_state_dict[k] = len(previous_artifacts) >= desired

    return base_resolver.ResolveResult(
        per_key_resolve_result=artifacts_dict,
        per_key_resolve_state=resolve_state_dict)
예제 #5
0
    def resolve(
        self,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolves the highest-id artifacts of each channel's type.

        For each channel, all artifacts of the channel's type name are ordered
        by id, descending, and at most `self._desired_num_of_artifact` of them
        are kept. A key is reported as resolved only when at least that many
        artifacts were available.

        Args:
            metadata_handler: Metadata handle used to fetch artifacts by type.
            source_channels: Mapping from key to the channel to resolve.

        Returns:
            A ResolveResult holding the per-key artifacts and per-key resolve
            state.
        """
        picked_by_key = {}
        satisfied_by_key = {}
        for key, channel in source_channels.items():
            of_type = metadata_handler.get_artifacts_by_type(channel.type_name)
            newest_first = sorted(of_type,
                                  key=lambda mlmd_artifact: mlmd_artifact.id,
                                  reverse=True)
            wanted = self._desired_num_of_artifact
            has_enough = len(newest_first) >= wanted
            # Keep the top `wanted` when enough exist, otherwise keep them all.
            chosen = newest_first[:wanted] if has_enough else newest_first
            picked_by_key[key] = [
                _generate_tfx_artifact(mlmd_artifact, channel.type)
                for mlmd_artifact in chosen
            ]
            satisfied_by_key[key] = has_enough

        return base_resolver.ResolveResult(
            per_key_resolve_result=picked_by_key,
            per_key_resolve_state=satisfied_by_key)
예제 #6
0
  def test_fetch_previous_result(self):
    """Checks `previous_run` lookup and `fetch_previous_result_artifacts`."""
    with Metadata(
        connection_config=self._connection_config,
        logger=self._logger) as handler:

      # Record a finished execution that the later lookups can match against.
      run_properties = {'log_root': 'path'}
      execution_id = handler.prepare_execution('Test', run_properties)
      consumed = types.TfxType(type_name='ExamplesPath')
      handler.publish_artifacts([consumed])
      produced = types.TfxType(type_name='ExamplesPath')
      inputs = {'input': [consumed]}
      outputs = {'output': [produced]}
      handler.publish_execution(execution_id, inputs, outputs)

      # A lookup that differs in exec properties, inputs or type name must not
      # find the recorded execution.
      non_matching_lookups = (
          ('Test', inputs, {}),
          ('Test', {}, run_properties),
          ('Test2', inputs, run_properties),
      )
      for lookup_args in non_matching_lookups:
        self.assertEqual(None, handler.previous_run(*lookup_args))
      # The exact same type name, inputs and exec properties find it.
      self.assertEqual(execution_id,
                       handler.previous_run('Test', inputs, run_properties))

      # A fresh, unpublished output artifact should be replaced by the artifact
      # the previous execution published.
      fresh_output = types.TfxType(type_name='ExamplesPath')
      self.assertNotEqual(types.ARTIFACT_STATE_PUBLISHED, fresh_output.state)
      refreshed_outputs = handler.fetch_previous_result_artifacts(
          {'output': [fresh_output]}, execution_id)
      earlier = outputs['output'][-1].artifact
      fetched = refreshed_outputs['output'][-1].artifact
      self.assertEqual(types.ARTIFACT_STATE_PUBLISHED,
                       fetched.properties['state'].string_value)
      self.assertEqual(earlier.id, fetched.id)
      self.assertEqual(earlier.type_id, fetched.type_id)
예제 #7
0
    def test_execution(self):
        """Checks `prepare_execution` and `publish_execution` end to end.

        Verifies the execution record stored after preparation, then the
        artifact state, execution state and input/output events recorded after
        publishing.
        """
        with Metadata(connection_config=self._connection_config,
                      logger=self._logger) as m:

            # Test prepare_execution: the only stored execution should be in
            # the "new" state.
            exec_properties = {}
            eid = m.prepare_execution('Test', exec_properties)
            [execution] = m.store.get_executions()
            self.assertProtoEquals(
                """
        id: 1
        type_id: 1
        properties {
          key: "state"
          value {
            string_value: "new"
          }
        }""", execution)

            # Test publish_execution with one published input artifact and one
            # not-yet-published output artifact.
            input_artifact = types.TfxArtifact(type_name='ExamplesPath')
            m.publish_artifacts([input_artifact])
            output_artifact = types.TfxArtifact(type_name='ExamplesPath')
            input_dict = {'input': [input_artifact]}
            output_dict = {'output': [output_artifact]}
            m.publish_execution(eid, input_dict, output_dict)
            # Make sure artifacts in output_dict are published.
            self.assertEqual(types.ARTIFACT_STATE_PUBLISHED,
                             output_artifact.state)
            # Make sure the execution state is changed to "complete".
            [execution] = m.store.get_executions_by_id([eid])
            self.assertEqual('complete',
                             execution.properties['state'].string_value)
            # Make sure events are published: one DECLARED_INPUT event followed
            # by one DECLARED_OUTPUT event, each with a (key, index 0) path.
            events = m.store.get_events_by_execution_ids([eid])
            self.assertEqual(2, len(events))
            self.assertEqual(input_artifact.id, events[0].artifact_id)
            self.assertEqual(metadata_store_pb2.Event.DECLARED_INPUT,
                             events[0].type)
            self.assertProtoEquals(
                """
          steps {
            key: "input"
          }
          steps {
            index: 0
          }""", events[0].path)
            self.assertEqual(output_artifact.id, events[1].artifact_id)
            self.assertEqual(metadata_store_pb2.Event.DECLARED_OUTPUT,
                             events[1].type)
            self.assertProtoEquals(
                """
          steps {
            key: "output"
          }
          steps {
            index: 0
          }""", events[1].path)
예제 #8
0
 def test_empty_artifact(self):
     """Checks that publishing with no artifacts still completes the execution."""
     with Metadata(self._connection_config) as handler:
         # Publishing an empty artifact list must be accepted.
         handler.publish_artifacts([])
         execution_id = handler.prepare_execution('Test', {})
         # No inputs and no outputs: only the execution state should change.
         handler.publish_execution(execution_id, {}, {})
         (stored_execution,) = handler.store.get_executions_by_id(
             [execution_id])
         self.assertProtoEquals(
             """
     id: 1
     type_id: 1
     properties {
       key: "state"
       value {
         string_value: "complete"
       }
     }""", stored_execution)
예제 #9
0
  def resolve(
      self,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> base_resolver.ResolveResult:
    """Resolves, for each channel, the artifact of its type with the largest id.

    Each key maps to a `(artifacts, resolved)` pair: a single-element list and
    True when at least one artifact of the channel's type exists, otherwise an
    empty list and False.

    Args:
      metadata_handler: Metadata handle used to fetch artifacts by type.
      source_channels: Mapping from key to the channel to resolve.

    Returns:
      A ResolveResult built from the per-key `(artifacts, resolved)` pairs.
    """
    outcome_by_key = {}
    for key, channel in source_channels.items():
      known = metadata_handler.get_artifacts_by_type(channel.type_name)
      if not known:
        # Nothing of this type has been recorded.
        outcome_by_key[key] = ([], False)
        continue
      newest = max(known, key=lambda mlmd_artifact: mlmd_artifact.id)
      wrapped = types.Artifact(type_name=channel.type_name)
      wrapped.set_artifact(newest)
      outcome_by_key[key] = ([wrapped], True)

    return base_resolver.ResolveResult(per_key_resolve_result=outcome_by_key)
예제 #10
0
파일: metadata_test.py 프로젝트: zwcdp/tfx
  def test_artifact(self):
    """Checks artifact publishing, retrieval and state transitions."""
    with Metadata(
        connection_config=self._connection_config,
        logger=self._logger) as handler:
      # The store starts out with no artifacts.
      self.assertListEqual([], handler.get_all_artifacts())

      # Publish one artifact and compare the stored record field by field.
      examples = types.TfxType(type_name='ExamplesPath')
      handler.publish_artifacts([examples])
      (stored,) = handler.store.get_artifacts()
      self.assertProtoEquals(
          """id: 1
        type_id: 1
        uri: ""
        properties {
          key: "split"
          value {
            string_value: ""
          }
        }
        properties {
          key: "state"
          value {
            string_value: "published"
          }
        }
        properties {
          key: "type_name"
          value {
            string_value: "ExamplesPath"
          }
        }""", stored)

      # Retrieval returns exactly the stored record.
      self.assertListEqual([stored], handler.get_all_artifacts())

      # The state check passes for the current state and raises once the
      # artifact has been moved to another state.
      handler.check_artifact_state(stored, types.ARTIFACT_STATE_PUBLISHED)
      handler.update_artifact_state(stored, types.ARTIFACT_STATE_DELETED)
      handler.check_artifact_state(stored, types.ARTIFACT_STATE_DELETED)
      with self.assertRaises(RuntimeError):
        handler.check_artifact_state(stored, types.ARTIFACT_STATE_PUBLISHED)
예제 #11
0
    def test_get_cached_execution_ids(self):
        """Checks `_get_cached_execution_id` against a mocked event store.

        The mocked store returns three successive INPUT event lists (artifact
        ids [1], [1, 2, 3] and [1, 2]); with inputs carrying artifact ids 1 and
        2 and candidate execution ids [3, 2, 1], the expected result is 1.
        """

        def input_events(*artifact_ids):
            # One INPUT event per given artifact id, in the given order.
            return [
                metadata_store_pb2.Event(
                    artifact_id=artifact_id,
                    type=metadata_store_pb2.Event.INPUT)
                for artifact_id in artifact_ids
            ]

        with Metadata(connection_config=self._connection_config,
                      logger=self._logger) as m:
            fake_store = mock.Mock()
            # Each call to get_events_by_execution_ids consumes the next list.
            fake_store.get_events_by_execution_ids.side_effect = [
                input_events(1),
                input_events(1, 2, 3),
                input_events(1, 2),
            ]
            m._store = fake_store

            first_input = types.TfxArtifact(type_name='ExamplesPath')
            first_input.id = 1
            second_input = types.TfxArtifact(type_name='ExamplesPath')
            second_input.id = 2

            inputs = {
                'input_one': [first_input],
                'input_two': [second_input],
            }

            cached_id = m._get_cached_execution_id(inputs, [3, 2, 1])
            self.assertEqual(1, cached_id)
예제 #12
0
def _matches_requested_properties(
    artifact: types.Artifact,
    properties: Dict[Text, Any],
    custom_properties: Dict[Text, Any],
) -> bool:
    """Returns True if `artifact` carries every requested property value.

    Stops at the first mismatch. Custom property values that are neither int
    nor string / bytes are not compared.
    """
    for key, value in properties.items():
        if getattr(artifact, key) != value:
            return False
    for key, value in custom_properties.items():
        if isinstance(value, int):
            if artifact.get_int_custom_property(key) != value:
                return False
        elif isinstance(value, (Text, bytes)):
            if artifact.get_string_custom_property(key) != value:
                return False
    return True


def _prepare_artifact(
    metadata_handler: metadata.Metadata, uri: Text,
    properties: Dict[Text, Any], custom_properties: Dict[Text, Any],
    reimport: bool, output_artifact_class: Type[types.Artifact],
    mlmd_artifact_type: Optional[metadata_store_pb2.ArtifactType]
) -> types.Artifact:
    """Prepares the Importer's output artifact.

    If there is already an artifact in MLMD with the same URI and properties /
    custom properties, that artifact will be reused unless the `reimport`
    argument is set to True. When several such artifacts exist, the one with
    the largest id is reused.

    Args:
      metadata_handler: The handler of MLMD.
      uri: The uri of the artifact.
      properties: The properties of the artifact, given as a dictionary from
        string keys to integer / string values. Must conform to the declared
        properties of the destination channel's output type.
      custom_properties: The custom properties of the artifact, given as a
        dictionary from string keys to integer / string values.
      reimport: If set to True, will register a new artifact even if it already
        exists in the database.
      output_artifact_class: The class of the output artifact.
      mlmd_artifact_type: The MLMD artifact type of the Artifact to be created.

    Returns:
      An Artifact object representing the imported artifact.

    Raises:
      ValueError: If a custom property value is not an int, string or bytes.
    """
    # Pass the values as logging args so formatting is deferred to the logger.
    absl.logging.info(
        'Processing source uri: %s, properties: %s, custom_properties: %s',
        uri, properties, custom_properties)

    # Check types of custom properties.
    for key, value in custom_properties.items():
        if not isinstance(value, (int, Text, bytes)):
            raise ValueError((
                'Custom property value for key %r must be a string or integer '
                '(got %r instead)') % (key, value))

    unfiltered_previous_artifacts = metadata_handler.get_artifacts_by_uri(uri)
    # Only consider previous artifacts as candidates to reuse, if the properties
    # of the imported artifact match those of the existing artifact.
    previous_artifacts = []
    for candidate_mlmd_artifact in unfiltered_previous_artifacts:
        # Wrap the raw MLMD artifact so its properties can be read through the
        # output artifact class.
        candidate_artifact = output_artifact_class(mlmd_artifact_type)
        candidate_artifact.set_mlmd_artifact(candidate_mlmd_artifact)
        if _matches_requested_properties(candidate_artifact, properties,
                                         custom_properties):
            previous_artifacts.append(candidate_mlmd_artifact)

    result = output_artifact_class(mlmd_artifact_type)
    result.uri = uri
    for key, value in properties.items():
        setattr(result, key, value)
    for key, value in custom_properties.items():
        if isinstance(value, int):
            result.set_int_custom_property(key, value)
        elif isinstance(value, (Text, bytes)):
            result.set_string_custom_property(key, value)

    # If a registered artifact has the same uri and properties and the user does
    # not explicitly ask for reimport, reuse that artifact.
    if previous_artifacts and not reimport:
        absl.logging.info('Reusing existing artifact')
        result.set_mlmd_artifact(max(previous_artifacts, key=lambda m: m.id))

    return result
예제 #13
0
    def resolve(
        self,
        pipeline_info: data_types.PipelineInfo,
        metadata_handler: metadata.Metadata,
        source_channels: Dict[Text, types.Channel],
    ) -> base_resolver.ResolveResult:
        """Resolves the highest-id blessed Model and its ModelBlessing.

        Exactly two channels are expected: one Model channel and one
        ModelBlessing channel. Among the published artifacts in the pipeline
        context, the Model with the largest id that has a blessing whose
        blessed flag equals 1 is selected together with that blessing. If no
        such model exists, both keys map to an empty list and a False state.

        Args:
            pipeline_info: Identifies the pipeline whose context is searched.
            metadata_handler: Metadata handle used to look up the context and
                its published artifacts.
            source_channels: Mapping from key to channel; must hold one Model
                and one ModelBlessing channel.

        Returns:
            A ResolveResult holding the per-key artifacts and per-key resolve
            state.

        Raises:
            AssertionError: If the channels are not exactly one Model and one
                ModelBlessing channel.
            RuntimeError: If a channel has an unexpected type, or if no
                pipeline context exists for `pipeline_info`.
        """
        # First, checks whether we have exactly Model and ModelBlessing Channels.
        model_channel_key = None
        model_blessing_channel_key = None
        assert len(source_channels) == 2, 'Expecting 2 input Channels'
        for k, c in source_channels.items():
            if issubclass(c.type, standard_artifacts.Model):
                model_channel_key = k
            elif issubclass(c.type, standard_artifacts.ModelBlessing):
                model_blessing_channel_key = k
            else:
                raise RuntimeError(
                    'Only expecting Model or ModelBlessing, got %s' % c.type)
        assert model_channel_key is not None, 'Expecting Model as input'
        assert model_blessing_channel_key is not None, (
            'Expecting ModelBlessing as'
            ' input')

        # Gets the pipeline context as the artifact search space.
        pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
        if pipeline_context is None:
            # Name the pipeline that was looked up; the context itself is None
            # here, so formatting it would only ever print "None".
            raise RuntimeError('Pipeline context absent for %s' %
                               (pipeline_info,))

        model_type_name = source_channels[model_channel_key].type_name
        blessing_type_name = source_channels[
            model_blessing_channel_key].type_name
        # Gets all artifacts of interests within context with one call.
        artifacts_in_context = metadata_handler.get_published_artifacts_by_type_within_context(
            [model_type_name, blessing_type_name], pipeline_context.id)
        # Gets all models in the search space and sort in reverse order by id.
        all_models = sorted(
            artifacts_in_context[model_type_name],
            key=lambda m: m.id,
            reverse=True)
        # Gets all ModelBlessing artifacts in the search space.
        all_model_blessings = artifacts_in_context[blessing_type_name]
        # Makes a dict of {model_id : ModelBlessing artifact} for blessed models.
        # If several blessings point at the same model id, the last one wins.
        blessed_key = model_validator.ARTIFACT_PROPERTY_BLESSED_KEY
        model_id_key = model_validator.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY
        all_blessed_model_ids = {
            a.custom_properties[model_id_key].int_value: a
            for a in all_model_blessings
            if a.custom_properties[blessed_key].int_value == 1
        }

        artifacts_dict = {
            model_channel_key: [],
            model_blessing_channel_key: []
        }
        resolve_state_dict = {
            model_channel_key: False,
            model_blessing_channel_key: False
        }
        # Iterates all models, if blessed, set as result. As the model list was
        # sorted, it is guaranteed to get the latest blessed model.
        for model in all_models:
            if model.id in all_blessed_model_ids:
                artifacts_dict[model_channel_key] = [
                    _generate_tfx_artifact(model, standard_artifacts.Model)
                ]
                artifacts_dict[model_blessing_channel_key] = [
                    _generate_tfx_artifact(all_blessed_model_ids[model.id],
                                           standard_artifacts.ModelBlessing)
                ]
                resolve_state_dict[model_channel_key] = True
                resolve_state_dict[model_blessing_channel_key] = True
                break

        return base_resolver.ResolveResult(
            per_key_resolve_result=artifacts_dict,
            per_key_resolve_state=resolve_state_dict)