示例#1
0
    def create_artifact_with_type(
            self, artifact: metadata_store_pb2.Artifact,
            artifact_type: metadata_store_pb2.ArtifactType) -> int:
        """Stores an artifact under the given type and returns the artifact id.

        The type is written first with `put_artifact_type`; a copy of
        `artifact` stamped with the returned type id is then written with
        `put_artifacts`. The caller's `artifact` message is never modified, and
        whatever `type_id` it carries is overwritten on the copy.

        These are two separate store calls, so the operation as a whole is not
        one transaction.

        Args:
          artifact: the artifact to store; its `type_id` is ignored.
          artifact_type: the type to register for the new artifact.

        Returns:
          The id reported by `put_artifacts` for the stored artifact.

        Raises:
          InvalidArgument: presumably surfaced by the underlying store calls
            when the type conflicts with a stored type of the same name
            (TODO confirm against the store implementation).
        """
        registered_type_id = self.put_artifact_type(artifact_type)
        stored = metadata_store_pb2.Artifact()
        stored.CopyFrom(artifact)
        stored.type_id = registered_type_id
        (new_id,) = self.put_artifacts([stored])
        return new_id
示例#2
0
    def testArtifactAndEventPairs(self):
        """Checks the (artifact, event) pair built for one INPUT artifact."""
        examples = standard_artifacts.Examples()
        examples.uri = 'example'
        examples.id = 1

        # text_format.Parse hands back the message it filled in, so the
        # expected protos can be built in a single expression each.
        want_artifact = text_format.Parse(
            """
        id: 1
        type_id: 1
        uri: 'example'""", metadata_store_pb2.Artifact())
        want_event = text_format.Parse(
            """
        path {
          steps {
            key: 'example'
          }
          steps {
            index: 0
          }
        }
        type: INPUT""", metadata_store_pb2.Event())

        inputs = {'example': [examples]}
        with metadata.Metadata(connection_config=self._connection_config) as m:
            pairs = execution_lib._create_artifact_and_event_pairs(
                m, inputs, metadata_store_pb2.Event.INPUT)

            self.assertCountEqual([(want_artifact, want_event)], pairs)
示例#3
0
    def test_put_artifacts_get_artifacts_by_uri(self):
        """Checks that a URI lookup returns only the artifact stored at it."""
        store = _get_metadata_store()
        type_id = store.put_artifact_type(_create_example_artifact_type())

        # Two artifacts of the same type that differ only in their URI.
        wanted = metadata_store_pb2.Artifact(type_id=type_id, uri="test_uri")
        decoy = metadata_store_pb2.Artifact(uri="other_uri", type_id=type_id)

        wanted_id, _ = store.put_artifacts([wanted, decoy])
        matches = store.get_artifacts_by_uri(wanted.uri)
        self.assertLen(matches, 1)
        self.assertEqual(matches[0].id, wanted_id)
  def test_put_artifacts_get_artifact_by_type_and_name(self):
    """Checks lookup by (type name, artifact name) for hits and misses."""
    # Store one named artifact under a freshly registered type.
    store = _get_metadata_store()
    artifact_type = _create_example_artifact_type(self._get_test_type_name())
    type_id = store.put_artifact_type(artifact_type)
    artifact = metadata_store_pb2.Artifact(
        type_id=type_id, name=self._get_test_type_name())
    (artifact_id,) = store.put_artifacts([artifact])

    # Found case: both the type name and the artifact name match.
    found = store.get_artifact_by_type_and_name(artifact_type.name,
                                                artifact.name)
    self.assertEqual(found.id, artifact_id)
    self.assertEqual(found.type_id, type_id)
    self.assertEqual(found.name, artifact.name)

    # Not-found cases: wrong type name, wrong artifact name, then both wrong.
    misses = (
        ("random_name", artifact.name),
        (artifact_type.name, "random_name"),
        ("random_name", "random_name"),
    )
    for type_name, artifact_name in misses:
      self.assertIsNone(
          store.get_artifact_by_type_and_name(type_name, artifact_name))
示例#5
0
def deserialize_artifact(
    artifact_type: metadata_store_pb2.ArtifactType,
    artifact: Optional[metadata_store_pb2.Artifact] = None) -> Artifact:
  """Reconstruct Artifact object from MLMD proto descriptors.

  Internal method, no backwards compatibility guarantees.

  Args:
    artifact_type: A metadata_store_pb2.ArtifactType proto object describing the
      type of the artifact.
    artifact: A metadata_store_pb2.Artifact proto object describing the contents
      of the artifact.  If not provided (None), an Artifact of the desired type
      with empty contents is created.

  Returns:
    Artifact subclass object for the given MLMD proto descriptors.

  Raises:
    ValueError: If `artifact_type` is not a metadata_store_pb2.ArtifactType, or
      if `artifact` is neither None nor a metadata_store_pb2.Artifact.
  """
  # Validate inputs.
  if not isinstance(artifact_type, metadata_store_pb2.ArtifactType):
    raise ValueError(
        ('Expected metadata_store_pb2.ArtifactType for artifact_type, got %s '
         'instead') % (artifact_type,))
  # Compare against None explicitly: a plain truthiness test would let falsy
  # values of the wrong type (e.g. '' or {}) skip validation and be silently
  # replaced by an empty artifact below.
  if (artifact is not None and
      not isinstance(artifact, metadata_store_pb2.Artifact)):
    raise ValueError(
        ('Expected metadata_store_pb2.Artifact for artifact, got %s '
         'instead') % (artifact,))

  # Get the artifact's class and construct the Artifact object.
  artifact_cls = get_artifact_type_class(artifact_type)
  result = artifact_cls()
  result.artifact_type.CopyFrom(artifact_type)
  if artifact is None:
    # No contents supplied: fall back to an empty MLMD artifact proto.
    artifact = metadata_store_pb2.Artifact()
  result.set_mlmd_artifact(artifact)
  return result
示例#6
0
  def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
    """Rebuilds an artifact object from its JSON-dict form.

    The dict must provide '__artifact_class_module__',
    '__artifact_class_name__', 'artifact' and 'artifact_type'. The last two are
    parsed into MLMD Artifact / ArtifactType protos via json_format.

    Args:
      dict_data: The JSON-compatible dict describing the artifact.

    Returns:
      An instance of the named artifact class when it can be imported and is
      not the base Artifact class; otherwise a generic Artifact built from the
      parsed artifact type. Either way the parsed type and artifact protos are
      attached to the result.
    """
    module_name = dict_data['__artifact_class_module__']
    class_name = dict_data['__artifact_class_name__']
    artifact = metadata_store_pb2.Artifact()
    artifact_type = metadata_store_pb2.ArtifactType()
    json_format.Parse(json.dumps(dict_data['artifact']), artifact)
    json_format.Parse(json.dumps(dict_data['artifact_type']), artifact_type)

    # First, try to resolve the specific class used for the artifact; if this
    # is not possible, use a generic artifact.Artifact object.
    result = None
    try:
      artifact_cls = getattr(importlib.import_module(module_name), class_name)
      # If the artifact type is the base Artifact class, do not construct the
      # object here since that constructor requires the mlmd_artifact_type
      # argument.
      if artifact_cls != Artifact:
        result = artifact_cls()
    except (AttributeError, ImportError, ValueError):
      # Pass the values as logging arguments so the message is only formatted
      # when the record is actually emitted.
      absl.logging.warning(
          'Could not load artifact class %s.%s; using fallback deserialization '
          'for the relevant artifact. This behavior may not be supported in '
          'the future; please make sure that any artifact classes can be '
          'imported within your container or environment.', module_name,
          class_name)
    # Test for None rather than truthiness: a successfully constructed artifact
    # object must never be discarded just because it evaluates as falsy.
    if result is None:
      result = Artifact(mlmd_artifact_type=artifact_type)
    result.set_mlmd_artifact_type(artifact_type)
    result.set_mlmd_artifact(artifact)
    return result
示例#7
0
文件: types.py 项目: dizcology/tfx
    def __init__(self, type_name: Text, split: Text = ''):
        """Construct an instance of TfxType.

        Each instance wraps an ArtifactType proto and an Artifact proto. The
        artifact starts without a URI; per the original documentation, the
        orchestration system fills it in before first use.

        Args:
          type_name: Name of the underlying ArtifactType.
          split: Which split this instance of the artifact maps to.
        """
        # Properties declared on the type, in declaration order.
        declared_properties = (
            ('type_name', metadata_store_pb2.STRING),
            # This is a temporary solution due to b/123435989.
            ('name', metadata_store_pb2.STRING),
            # State of the artifact; one of PENDING, PUBLISHED, MISSING,
            # DELETING, DELETED.
            # TODO(ruoyu): Maybe switch to artifact top-level state if it's
            # supported.
            ('state', metadata_store_pb2.STRING),
            # Span number of the artifact. For the same artifact type produced
            # by the same executor, this number should always increase.
            ('span', metadata_store_pb2.INT),
            # Comma separated splits recognized. An empty string means the
            # artifact has no split.
            ('split', metadata_store_pb2.STRING),
        )
        type_proto = metadata_store_pb2.ArtifactType()
        type_proto.name = type_name
        for property_name, property_kind in declared_properties:
            type_proto.properties[property_name] = property_kind
        self.artifact_type = type_proto

        # Only the type name and the split are recorded on the artifact itself.
        artifact_proto = metadata_store_pb2.Artifact()
        for property_name, text in (('type_name', type_name),
                                    ('split', split)):
            artifact_proto.properties[property_name].string_value = text
        self.artifact = artifact_proto
示例#8
0
    def test_put_duplicated_attributions_and_empty_associations(self):
        """Checks a duplicated attribution yields one artifact-context link."""
        store = _get_metadata_store()

        # One context ...
        context_type_id = store.put_context_type(
            _create_example_context_type())
        want_context = metadata_store_pb2.Context(
            type_id=context_type_id, name="context")
        [want_context.id] = store.put_contexts([want_context])

        # ... and one artifact.
        artifact_type_id = store.put_artifact_type(
            _create_example_artifact_type())
        want_artifact = metadata_store_pb2.Artifact(
            type_id=artifact_type_id, uri="testuri")
        [want_artifact.id] = store.put_artifacts([want_artifact])

        # Submit the same attribution twice, with no associations at all.
        link = metadata_store_pb2.Attribution(
            artifact_id=want_artifact.id, context_id=want_context.id)
        store.put_attributions_and_associations([link, link], [])

        contexts = store.get_contexts_by_artifact(want_artifact.id)
        self.assertLen(contexts, 1)
        self.assertEqual(contexts[0].id, want_context.id)
        self.assertEqual(contexts[0].name, want_context.name)
        artifacts = store.get_artifacts_by_context(want_context.id)
        self.assertLen(artifacts, 1)
        self.assertEqual(artifacts[0].uri, want_artifact.uri)
        self.assertEmpty(store.get_executions_by_context(want_context.id))
示例#9
0
    def test_put_events_get_events(self):
        """Checks a stored event is returned by artifact and execution id."""
        store = _get_metadata_store()

        execution_type_id = store.put_execution_type(
            metadata_store_pb2.ExecutionType(name="execution_type"))
        (execution_id,) = store.put_executions(
            [metadata_store_pb2.Execution(type_id=execution_type_id)])

        artifact_type_id = store.put_artifact_type(
            metadata_store_pb2.ArtifactType(name="artifact_type"))
        (artifact_id,) = store.put_artifacts(
            [metadata_store_pb2.Artifact(type_id=artifact_type_id)])

        declared_output = metadata_store_pb2.Event.DECLARED_OUTPUT
        store.put_events([
            metadata_store_pb2.Event(
                type=declared_output,
                artifact_id=artifact_id,
                execution_id=execution_id)
        ])

        def check_single_event(events):
            # Exactly one event is expected, linking the stored pair.
            (only,) = events
            self.assertEqual(only.artifact_id, artifact_id)
            self.assertEqual(only.execution_id, execution_id)
            self.assertEqual(only.type, declared_output)

        check_single_event(store.get_events_by_artifact_ids([artifact_id]))
        check_single_event(store.get_events_by_execution_ids([execution_id]))
    def _write_tfdv(self,
                    tfdv_path: str,
                    train_dataset_name: str,
                    train_features: List[str],
                    eval_dataset_name: str,
                    eval_features: List[str],
                    store: Optional[mlmd.MetadataStore] = None):
        """Writes train/eval feature statistics files under `tfdv_path`.

        One DatasetFeatureStatisticsList is serialized to
        `<tfdv_path>/Split-train/FeatureStats.pb` and another to
        `<tfdv_path>/Split-eval/FeatureStats.pb`. Every feature receives the
        same three-bucket rank histogram ('a', 'b', 'c'). Train features are
        identified by `name`, eval features by `path.step`.

        Args:
          tfdv_path: Directory under which the two split directories are made.
          train_dataset_name: Name stored on the train dataset statistics.
          train_features: Feature names to emit for the train split.
          eval_dataset_name: Name stored on the eval dataset statistics.
          eval_features: Feature names to emit for the eval split.
          store: Optional metadata store. When given, an ExampleStatistics
            artifact type is registered and an artifact whose URI is
            `tfdv_path` is stored under it.
        """
        histogram_buckets = [
            statistics_pb2.RankHistogram.Bucket(low_rank=0,
                                                high_rank=0,
                                                label='a',
                                                sample_count=4.0),
            statistics_pb2.RankHistogram.Bucket(low_rank=1,
                                                high_rank=1,
                                                label='b',
                                                sample_count=3.0),
            statistics_pb2.RankHistogram.Bucket(low_rank=2,
                                                high_rank=2,
                                                label='c',
                                                sample_count=2.0),
        ]

        train_stats = statistics_pb2.DatasetFeatureStatistics()
        train_stats.name = train_dataset_name
        for feature in train_features:
            # Populate the message returned by add(); writing to features[0]
            # would overwrite the first entry on every iteration and leave
            # all later entries empty.
            feature_stats = train_stats.features.add()
            feature_stats.name = feature
            feature_stats.string_stats.rank_histogram.buckets.extend(
                histogram_buckets)
        train_stats_list = statistics_pb2.DatasetFeatureStatisticsList(
            datasets=[train_stats])
        train_stats_file = os.path.join(tfdv_path, 'Split-train',
                                        'FeatureStats.pb')
        os.makedirs(os.path.dirname(train_stats_file), exist_ok=True)
        with open(train_stats_file, mode='wb') as f:
            f.write(train_stats_list.SerializeToString())

        eval_stats = statistics_pb2.DatasetFeatureStatistics()
        eval_stats.name = eval_dataset_name
        for feature in eval_features:
            # Eval features are keyed by path step instead of by name.
            feature_stats = eval_stats.features.add()
            feature_stats.path.step.append(feature)
            feature_stats.string_stats.rank_histogram.buckets.extend(
                histogram_buckets)
        eval_stats_list = statistics_pb2.DatasetFeatureStatisticsList(
            datasets=[eval_stats])
        eval_stats_file = os.path.join(tfdv_path, 'Split-eval',
                                       'FeatureStats.pb')
        os.makedirs(os.path.dirname(eval_stats_file), exist_ok=True)
        with open(eval_stats_file, mode='wb') as f:
            f.write(eval_stats_list.SerializeToString())

        if store is not None:
            # Register the statistics directory as an artifact in the store.
            stats_type = metadata_store_pb2.ArtifactType()
            stats_type.name = standard_artifacts.ExampleStatistics.TYPE_NAME
            stats_type_id = store.put_artifact_type(stats_type)

            artifact = metadata_store_pb2.Artifact()
            artifact.uri = tfdv_path
            artifact.type_id = stats_type_id
            store.put_artifacts([artifact])
示例#11
0
文件: types.py 项目: zhitaoli/tfx
def _create_tfx_artifact(type_name, split):
    """Builds the Artifact proto used by TfxType.__init__().

    The returned artifact carries two string properties: 'type_name' and
    'split', set from the arguments of the same names.
    """
    proto = metadata_store_pb2.Artifact()
    # TODO(martinz): consider whether type_name needs to be hard-coded into the
    # artifact.
    for key, text in (('type_name', type_name), ('split', split)):
        proto.properties[key].string_value = text
    return proto
示例#12
0
  def test_put_artifacts_get_artifacts_by_type(self):
    """Checks a type-name lookup returns only artifacts of that type."""
    store = _get_metadata_store()
    first_type_id = store.put_artifact_type(_create_example_artifact_type())
    second_type = _create_example_artifact_type_2()
    second_type_id = store.put_artifact_type(second_type)

    # One artifact per type; only the first one carries properties.
    of_first_type = metadata_store_pb2.Artifact(type_id=first_type_id)
    of_first_type.properties["foo"].int_value = 3
    of_first_type.properties["bar"].string_value = "Hello"
    of_second_type = metadata_store_pb2.Artifact(type_id=second_type_id)

    _, second_id = store.put_artifacts([of_first_type, of_second_type])
    matches = store.get_artifacts_by_type(second_type.name)
    self.assertLen(matches, 1)
    self.assertEqual(matches[0].id, second_id)
 def test_put_artifacts_get_artifacts_by_id_with_set(self):
   """Checks get_artifacts_by_id accepts a set of ids."""
   store = _get_metadata_store()
   type_id = store.put_artifact_type(
       _create_example_artifact_type(self._get_test_type_name()))
   stored = metadata_store_pb2.Artifact(type_id=type_id)
   (stored_id,) = store.put_artifacts([stored])
   # The ids are deliberately passed as a set rather than a list.
   (fetched,) = store.get_artifacts_by_id({stored_id})
   self.assertEqual(fetched.type_id, stored.type_id)
示例#14
0
  def testResolveInputArtifacts(self):
    """Exercises resolve_input_artifacts on a cache miss and a cache hit."""
    # Create input splits: one data file per split, each with its timestamps
    # pinned through os.utime.
    split_files = (
        ('split1', 'testing', (0, 1)),
        ('split2', 'testing2', (0, 3)),
    )
    for split_dir, content, times in split_files:
      data_path = os.path.join(self._input_base_path, split_dir, 'data')
      io_utils.write_string_file(data_path, content)
      os.utime(data_path, times)

    # Mock artifacts with ids 4, 3, 2, 1 (in that order). Only odd ids receive
    # the two-split fingerprint; even ids get 'not_match'.
    matching_fingerprint = 'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\nsplit:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3'
    artifacts = []
    for artifact_id in (4, 3, 2, 1):
      candidate = metadata_store_pb2.Artifact()
      candidate.id = artifact_id
      candidate.uri = self._input_base_path
      candidate.custom_properties['span'].string_value = '0'
      is_odd = artifact_id % 2 == 1
      candidate.custom_properties['input_fingerprint'].string_value = (
          matching_fingerprint if is_odd else 'not_match')
      artifacts.append(candidate)

    # Create exec properties.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
        example_gen_pb2.Input.Split(name='s2', pattern='split2/*')
    ])
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                input_config, preserving_proto_field_name=True),
    }

    # Each scenario: (value returned by get_artifacts_by_uri, value returned by
    # publish_artifacts, expected id of the resolved artifact).
    scenarios = (
        # Cache not hit.
        ([artifacts[0]], [artifacts[3]], 1),
        # Cache hit.
        (artifacts, [], 3),
    )
    for by_uri, published, expected_id in scenarios:
      self._mock_metadata.get_artifacts_by_uri.return_value = by_uri
      self._mock_metadata.publish_artifacts.return_value = published
      resolved = self._example_gen_driver.resolve_input_artifacts(
          self._input_channels, exec_properties, None, None)
      self.assertEqual(1, len(resolved))
      self.assertEqual(1, len(resolved['input_base']))
      input_base = resolved['input_base'][0]
      self.assertEqual(expected_id, input_base.id)
      self.assertEqual(self._input_base_path, input_base.uri)
示例#15
0
 def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
     """Rebuilds an Artifact from the 'artifact' and 'artifact_type' entries.

     Both entries are re-serialized to JSON and parsed into their MLMD protos;
     the result is an Artifact named after the parsed type, with both protos
     attached.
     """
     # json_format.Parse returns the message it populated.
     artifact_proto = json_format.Parse(
         json.dumps(dict_data['artifact']), metadata_store_pb2.Artifact())
     type_proto = json_format.Parse(
         json.dumps(dict_data['artifact_type']),
         metadata_store_pb2.ArtifactType())
     rebuilt = Artifact(type_proto.name)
     rebuilt.set_artifact_type(type_proto)
     rebuilt.set_artifact(artifact_proto)
     return rebuilt
示例#16
0
文件: types.py 项目: rohithreddy/tfx
 def parse_from_json_dict(cls, d):
     """Creates an instance of TfxType from a JSON-deserialized dict.

     The 'artifact' and 'artifact_type' entries of `d` are parsed into their
     MLMD protos, which are then attached to a new TfxType.
     """
     # json_format.Parse returns the message it populated.
     artifact_proto = json_format.Parse(
         json.dumps(d['artifact']), metadata_store_pb2.Artifact())
     type_proto = json_format.Parse(
         json.dumps(d['artifact_type']), metadata_store_pb2.ArtifactType())
     # NOTE(review): the second positional argument receives artifact.uri;
     # confirm that this matches the TfxType constructor's signature.
     rebuilt = TfxType(type_proto.name, artifact_proto.uri)
     rebuilt.set_artifact_type(type_proto)
     rebuilt.set_artifact(artifact_proto)
     return rebuilt
示例#17
0
 def test_create_artifact_with_type_get_artifacts_by_id(self):
   """Checks properties survive create_artifact_with_type and a read back."""
   store = _get_metadata_store()
   example_type = _create_example_artifact_type()
   to_store = metadata_store_pb2.Artifact()
   to_store.properties["foo"].int_value = 3
   to_store.properties["bar"].string_value = "Hello"
   stored_id = store.create_artifact_with_type(to_store, example_type)
   (fetched,) = store.get_artifacts_by_id([stored_id])
   self.assertEqual(fetched.properties["bar"].string_value, "Hello")
   self.assertEqual(fetched.properties["foo"].int_value, 3)
示例#18
0
 def _create_artifact_with_type(self, uri: str, type_name: str,
                                property_types: dict = None,
                                properties: dict = None,
                                custom_properties: dict = None):
     """Stores an artifact of the named type and returns it with its id set.

     The type is obtained through `_get_or_create_artifact_type`; the artifact
     is then built from the given URI and property dicts, written with
     `self.store.put_artifacts`, and returned with the first id from that
     call assigned to its `id` field.
     """
     registered_type = self._get_or_create_artifact_type(
         type_name=type_name, properties=property_types)
     new_artifact = metadata_store_pb2.Artifact(
         uri=uri,
         type_id=registered_type.id,
         properties=properties,
         custom_properties=custom_properties)
     stored_ids = self.store.put_artifacts([new_artifact])
     new_artifact.id = stored_ids[0]
     return new_artifact
示例#19
0
  def test_update_artifact_get_artifact(self):
    """Checks that re-putting an artifact with its id updates it in place."""
    store = _get_metadata_store()
    type_id = store.put_artifact_type(_create_example_artifact_type())
    original = metadata_store_pb2.Artifact(type_id=type_id)
    original.properties["bar"].string_value = "Hello"
    (original_id,) = store.put_artifacts([original])

    # Same artifact, now carrying its id plus new property values.
    revised = metadata_store_pb2.Artifact()
    revised.CopyFrom(original)
    revised.id = original_id
    revised.properties["foo"].int_value = original_id
    revised.properties["bar"].string_value = "Goodbye"
    (revised_id,) = store.put_artifacts([revised])
    self.assertEqual(original_id, revised_id)

    (fetched,) = store.get_artifacts_by_id([original_id])
    self.assertEqual(fetched.properties["bar"].string_value, "Goodbye")
    self.assertEqual(fetched.properties["foo"].int_value, original_id)
示例#20
0
    def test_put_and_use_attributions_and_associations(self):
        """Checks context links to an artifact and an execution are queryable."""
        store = _get_metadata_store()

        # Context.
        context_type_id = store.put_context_type(
            _create_example_context_type(self._get_test_type_name()))
        want_context = metadata_store_pb2.Context(
            type_id=context_type_id, name=self._get_test_type_name())
        [want_context.id] = store.put_contexts([want_context])

        # Execution.
        execution_type_id = store.put_execution_type(
            _create_example_execution_type(self._get_test_type_name()))
        want_execution = metadata_store_pb2.Execution(
            type_id=execution_type_id)
        want_execution.properties["foo"].int_value = 3
        [want_execution.id] = store.put_executions([want_execution])

        # Artifact.
        artifact_type_id = store.put_artifact_type(
            _create_example_artifact_type(self._get_test_type_name()))
        want_artifact = metadata_store_pb2.Artifact(
            type_id=artifact_type_id, uri="testuri")
        [want_artifact.id] = store.put_artifacts([want_artifact])

        # Link the artifact and the execution to the context.
        store.put_attributions_and_associations(
            [
                metadata_store_pb2.Attribution(
                    artifact_id=want_artifact.id,
                    context_id=want_context.id)
            ],
            [
                metadata_store_pb2.Association(
                    execution_id=want_execution.id,
                    context_id=want_context.id)
            ])

        # Query the relationship from every direction.
        contexts = store.get_contexts_by_artifact(want_artifact.id)
        self.assertLen(contexts, 1)
        self.assertEqual(contexts[0].id, want_context.id)
        self.assertEqual(contexts[0].name, want_context.name)
        artifacts = store.get_artifacts_by_context(want_context.id)
        self.assertLen(artifacts, 1)
        self.assertEqual(artifacts[0].uri, want_artifact.uri)
        executions = store.get_executions_by_context(want_context.id)
        self.assertLen(executions, 1)
        self.assertEqual(executions[0].properties["foo"],
                         want_execution.properties["foo"])
        contexts = store.get_contexts_by_execution(want_execution.id)
        self.assertLen(contexts, 1)
        self.assertEqual(contexts[0].id, want_context.id)
        self.assertEqual(contexts[0].name, want_context.name)
示例#21
0
 def test_log_invalid_artifacts_should_fail(self):
   """Checks log_input / log_output raise ValueError for these artifacts."""
   store = metadata.Store(grpc_host=GRPC_HOST, grpc_port=GRPC_PORT)
   workspace = metadata.Workspace(store=store,
                                  name="ws_1",
                                  description="a workspace for testing",
                                  labels={"n1": "v1"})
   execution = metadata.Execution(name="test execution", workspace=workspace)

   def fixture_with(property_name, text):
     # Artifact fixture whose only custom property is `property_name`.
     return ArtifactFixture(
         mlpb.Artifact(
             uri="gs://uri",
             custom_properties={
                 property_name: mlpb.Value(string_value=text),
             }))

   # Artifact pre-populated with a workspace property, logged as an input.
   with_workspace = fixture_with(metadata._WORKSPACE_PROPERTY_NAME, "ws1")
   self.assertRaises(ValueError, execution.log_input, with_workspace)
   # Artifact pre-populated with a run property, logged as an output.
   with_run = fixture_with(metadata._RUN_PROPERTY_NAME, "run1")
   self.assertRaises(ValueError, execution.log_output, with_run)
 def test_put_invalid_artifact(self):
   """Checks putting an artifact with a valueless property is rejected."""
   store = _get_metadata_store()
   type_id = store.put_artifact_type(
       _create_example_artifact_type(self._get_test_type_name()))
   invalid = metadata_store_pb2.Artifact(type_id=type_id, uri="testuri")
   # Create the Value message for "foo" but don't populate its value.
   invalid.properties["foo"]  # pylint: disable=pointless-statement
   with self.assertRaisesRegex(errors.InvalidArgumentError,
                               "Found unmatched property type: foo"):
     store.put_artifacts([invalid])
 def test_put_artifacts_get_artifacts_by_id(self):
   """Checks properties of a stored artifact can be read back by id."""
   store = _get_metadata_store()
   type_id = store.put_artifact_type(
       _create_example_artifact_type(self._get_test_type_name()))
   to_store = metadata_store_pb2.Artifact(type_id=type_id)
   to_store.properties["foo"].int_value = 3
   to_store.properties["bar"].string_value = "Hello"
   (stored_id,) = store.put_artifacts([to_store])
   (fetched,) = store.get_artifacts_by_id([stored_id])
   self.assertEqual(fetched.properties["bar"].string_value, "Hello")
   self.assertEqual(fetched.properties["foo"].int_value, 3)
示例#24
0
  def test_put_artifacts_get_artifacts(self):
    """Checks get_artifacts returns both stored artifacts, in either order."""
    store = _get_metadata_store()
    type_id = store.put_artifact_type(_create_example_artifact_type())
    with_properties = metadata_store_pb2.Artifact(type_id=type_id)
    with_properties.properties["foo"].int_value = 3
    with_properties.properties["bar"].string_value = "Hello"
    bare = metadata_store_pb2.Artifact(type_id=type_id)

    id_with_properties, id_bare = store.put_artifacts([with_properties, bare])
    fetched = store.get_artifacts()
    # The result order is not assumed; pair each result with its id. The
    # unpacking also requires that exactly two artifacts come back.
    if fetched[0].id != id_with_properties:
      fetched_bare, fetched_with_properties = fetched
    else:
      fetched_with_properties, fetched_bare = fetched
    self.assertEqual(fetched_with_properties.id, id_with_properties)
    self.assertEqual(fetched_with_properties.properties["bar"].string_value,
                     "Hello")
    self.assertEqual(fetched_with_properties.properties["foo"].int_value, 3)
    self.assertEqual(fetched_bare.id, id_bare)
示例#25
0
def mlpb_artifact(type_id, uri, workspace, name=None, version=None):
  """Builds an mlpb.Artifact tagged with a workspace custom property.

  `name` and `version` become string properties only when they are truthy;
  `workspace` is always stored under metadata._WORKSPACE_PROPERTY_NAME.
  """
  optional_fields = (("name", name), ("version", version))
  properties = {
      key: mlpb.Value(string_value=text)
      for key, text in optional_fields
      if text
  }
  custom_properties = {
      metadata._WORKSPACE_PROPERTY_NAME: mlpb.Value(string_value=workspace),
  }
  return mlpb.Artifact(uri=uri,
                       type_id=type_id,
                       properties=properties,
                       custom_properties=custom_properties)
示例#26
0
  def test_publish_execution(self):
    """Checks put_execution stores two artifacts and records one event."""
    store = _get_metadata_store()
    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name="execution_type"))
    execution = metadata_store_pb2.Execution(type_id=execution_type_id)

    artifact_type_id = store.put_artifact_type(
        metadata_store_pb2.ArtifactType(name="artifact_type"))
    # The first artifact is passed without an event, the second with one.
    without_event = metadata_store_pb2.Artifact(type_id=artifact_type_id)
    with_event = metadata_store_pb2.Artifact(type_id=artifact_type_id)
    event = metadata_store_pb2.Event(
        type=metadata_store_pb2.Event.DECLARED_INPUT)

    execution_id, artifact_ids = store.put_execution(
        execution, [[without_event], [with_event, event]])
    self.assertLen(artifact_ids, 2)
    self.assertLen(store.get_events_by_execution_ids([execution_id]), 1)
示例#27
0
 def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
     """Rebuilds an artifact of the class named in `dict_data`.

     The class is looked up from '__artifact_class_module__' and
     '__artifact_class_name__', instantiated without arguments, and given the
     MLMD protos parsed from the 'artifact' and 'artifact_type' entries.
     """
     module = importlib.import_module(
         dict_data['__artifact_class_module__'])
     artifact_cls = getattr(module, dict_data['__artifact_class_name__'])
     # json_format.Parse returns the message it populated.
     artifact_proto = json_format.Parse(
         json.dumps(dict_data['artifact']), metadata_store_pb2.Artifact())
     type_proto = json_format.Parse(
         json.dumps(dict_data['artifact_type']),
         metadata_store_pb2.ArtifactType())
     rebuilt = artifact_cls()
     rebuilt.set_mlmd_artifact_type(type_proto)
     rebuilt.set_mlmd_artifact(artifact_proto)
     return rebuilt
示例#28
0
  def test_put_events_with_paths(self):
    """Checks event path steps are stored and returned with their keys."""
    store = _get_metadata_store()
    execution_type_id = store.put_execution_type(
        metadata_store_pb2.ExecutionType(name="execution_type"))
    (execution_id,) = store.put_executions(
        [metadata_store_pb2.Execution(type_id=execution_type_id)])
    artifact_type_id = store.put_artifact_type(
        metadata_store_pb2.ArtifactType(name="artifact_type"))
    first_artifact_id, second_artifact_id = store.put_artifacts([
        metadata_store_pb2.Artifact(type_id=artifact_type_id),
        metadata_store_pb2.Artifact(type_id=artifact_type_id),
    ])

    def declared_input(artifact_id, step_key):
      # DECLARED_INPUT event for the shared execution with one keyed step.
      event = metadata_store_pb2.Event(
          type=metadata_store_pb2.Event.DECLARED_INPUT,
          artifact_id=artifact_id,
          execution_id=execution_id)
      event.path.steps.add().key = step_key
      return event

    store.put_events([
        declared_input(first_artifact_id, "ggg"),
        declared_input(second_artifact_id, "fff"),
    ])
    # The assertions expect the events back in the order they were put.
    first_event, second_event = store.get_events_by_execution_ids(
        [execution_id])
    self.assertLen(first_event.path.steps, 1)
    self.assertEqual(first_event.path.steps[0].key, "ggg")
    self.assertLen(second_event.path.steps, 1)
    self.assertEqual(second_event.path.steps[0].key, "fff")
示例#29
0
    def test_put_artifacts_get_artifacts(self):
        """Checks two new artifacts appear in get_artifacts of a used store."""
        store = _get_metadata_store()
        type_id = store.put_artifact_type(
            _create_example_artifact_type(self._get_test_type_name()))
        with_properties = metadata_store_pb2.Artifact(type_id=type_id)
        with_properties.properties["foo"].int_value = 3
        with_properties.properties["bar"].string_value = "Hello"
        bare = metadata_store_pb2.Artifact(type_id=type_id)

        # The store may already hold artifacts; a NotFoundError from the
        # listing call is counted as zero existing artifacts.
        try:
            count_before = len(store.get_artifacts())
        except errors.NotFoundError:
            count_before = 0

        id_with_properties, id_bare = store.put_artifacts(
            [with_properties, bare])
        everything = store.get_artifacts()
        count_after = len(everything)
        # Keep only the two artifacts written by this test.
        new_ids = (id_with_properties, id_bare)
        ours = [found for found in everything if found.id in new_ids]

        # The result order is not assumed; pair each result with its id.
        if ours[0].id != id_with_properties:
            fetched_bare, fetched_with_properties = ours
        else:
            fetched_with_properties, fetched_bare = ours
        self.assertEqual(count_before + 2, count_after)
        self.assertEqual(fetched_with_properties.id, id_with_properties)
        self.assertEqual(
            fetched_with_properties.properties["bar"].string_value, "Hello")
        self.assertEqual(
            fetched_with_properties.properties["foo"].int_value, 3)
        self.assertEqual(fetched_bare.id, id_bare)
示例#30
0
    def put_artifact(self, properties: Dict[str, str]) -> int:
        """Stores an artifact carrying the given custom properties.

        The artifact always uses the URI 'test/path/' and the type id held in
        `self.artifact_type_id`.

        Args:
          properties: Custom property names mapped to the string values to
            store on the artifact.

        Returns:
          The first id returned by `self.store.put_artifacts` for the artifact.
        """
        new_artifact = metadata_store_pb2.Artifact()
        new_artifact.uri = 'test/path/'
        new_artifact.type_id = self.artifact_type_id
        for key in properties:
            new_artifact.custom_properties[key].string_value = properties[key]
        stored_ids = self.store.put_artifacts([new_artifact])
        return stored_ids[0]