예제 #1
0
def create_new_execution_in_existing_run_context(
    store,
    execution_type_name: str,
    context_id: int,
    # TODO: Remove when UX stops relying on these properties
    pipeline_name: str = None,
    run_id: str = None,
    instance_id: str = None,
) -> metadata_store_pb2.Execution:
    """Creates an execution of ``execution_type_name`` in an existing run context.

    Missing identifiers are synthesized. A falsy ``pipeline_name`` or ``run_id``
    is derived from ``context_id``. A falsy ``instance_id`` falls back to
    ``execution_type_name``.

    Args:
      store: Store handle forwarded to ``create_new_execution_in_existing_context``.
      execution_type_name: Name of the execution type.
      context_id: Id of the run context the execution is created in.
      pipeline_name: Value stored in the pipeline-name property.
      run_id: Value stored in the run-id property.
      instance_id: Value stored in the component-id property.

    Returns:
      The result of ``create_new_execution_in_existing_context``.
    """
    context_prefix = 'Context_' + str(context_id)
    if not pipeline_name:
        pipeline_name = context_prefix + '_pipeline'
    if not run_id:
        run_id = context_prefix + '_run'
    if not instance_id:
        instance_id = execution_type_name

    # TODO: Remove these properties when the UX stops relying on them.
    property_values = {
        # Mistakenly used for grouping in the UX.
        EXECUTION_PIPELINE_NAME_PROPERTY_NAME: pipeline_name,
        EXECUTION_RUN_ID_PROPERTY_NAME: run_id,
        # Should be set to the task ID, not the component ID.
        EXECUTION_COMPONENT_ID_PROPERTY_NAME: instance_id,
    }
    string_type = metadata_store_pb2.STRING
    return create_new_execution_in_existing_context(
        store=store,
        execution_type_name=execution_type_name,
        context_id=context_id,
        execution_type_properties={
            name: string_type for name in property_values
        },
        properties={
            name: metadata_store_pb2.Value(string_value=text)
            for name, text in property_values.items()
        },
    )
예제 #2
0
    def _get_or_create_run_context(self):
        """Gets or creates the run context named after ``self.workflow_name``.

        The context has type ``RUN_CONTEXT_TYPE_NAME``. It carries three string
        properties: the KFP run-id URI built from ``self.run_uuid``, the pipeline
        name, and the workflow name.

        Returns:
          The result of ``self._get_or_create_context_with_type``.
        """
        run_id_uri = kfputils.format_kfp_run_id_uri(self.run_uuid)
        workflow = self.workflow_name
        pipeline = self.pipeline_name

        field_values = {
            "run_id": run_id_uri,
            "pipeline_name": pipeline,
            "workflow_name": workflow,
        }
        return self._get_or_create_context_with_type(
            context_name=workflow,
            type_name=RUN_CONTEXT_TYPE_NAME,
            property_types={
                key: metadata_store_pb2.STRING for key in field_values
            },
            properties={
                key: metadata_store_pb2.Value(string_value=text)
                for key, text in field_values.items()
            })
예제 #3
0
def create_new_execution_in_existing_run_context(
    store,
    execution_type_name: str,
    context_id: int,
    pod_name: str,
    # TODO: Remove when UX stops relying on these properties
    pipeline_name: str = None,
    run_id: str = None,
    instance_id: str = None,
    custom_properties=None,
) -> metadata_store_pb2.Execution:
    """Creates an execution in an existing run context, tagged with its pod name.

    Identifier defaults:
    - A falsy ``pipeline_name`` or ``run_id`` is derived from ``context_id``.
    - A falsy ``instance_id`` falls back to ``execution_type_name``.

    Each ``custom_properties`` entry is converted with ``value_to_mlmd_value``.
    The pod name is then written under ``KFP_POD_NAME_EXECUTION_PROPERTY_NAME``,
    overriding any caller-supplied value for that key.

    Args:
      store: Store handle forwarded to ``create_new_execution_in_existing_context``.
      execution_type_name: Name of the execution type.
      context_id: Id of the run context the execution is created in.
      pod_name: Name of the pod running the execution.
      pipeline_name: Value stored in the pipeline-name property.
      run_id: Value stored in the run-id property.
      instance_id: Value stored in the component-id property.
      custom_properties: Optional mapping of extra plain values to record.

    Returns:
      The result of ``create_new_execution_in_existing_context``.
    """
    prefix = 'Context_' + str(context_id)
    if not pipeline_name:
        pipeline_name = prefix + '_pipeline'
    if not run_id:
        run_id = prefix + '_run'
    if not instance_id:
        instance_id = execution_type_name

    mlmd_custom_properties = {
        name: value_to_mlmd_value(raw)
        for name, raw in (custom_properties or {}).items()
    }
    mlmd_custom_properties[KFP_POD_NAME_EXECUTION_PROPERTY_NAME] = (
        metadata_store_pb2.Value(string_value=pod_name))

    string_type = metadata_store_pb2.STRING
    return create_new_execution_in_existing_context(
        store=store,
        execution_type_name=execution_type_name,
        context_id=context_id,
        execution_type_properties={
            EXECUTION_PIPELINE_NAME_PROPERTY_NAME: string_type,
            EXECUTION_RUN_ID_PROPERTY_NAME: string_type,
            EXECUTION_COMPONENT_ID_PROPERTY_NAME: string_type,
        },
        # TODO: Remove these properties when the UX stops relying on them.
        properties={
            # Mistakenly used for grouping in the UX.
            EXECUTION_PIPELINE_NAME_PROPERTY_NAME:
                metadata_store_pb2.Value(string_value=pipeline_name),
            EXECUTION_RUN_ID_PROPERTY_NAME:
                metadata_store_pb2.Value(string_value=run_id),
            # Should be set to the task ID, not the component ID.
            EXECUTION_COMPONENT_ID_PROPERTY_NAME:
                metadata_store_pb2.Value(string_value=instance_id),
        },
        custom_properties=mlmd_custom_properties,
    )
예제 #4
0
    def _get_context_id(self, reuse_workspace_if_exists):
        """Returns the MLMD context id for this workspace, creating it if needed.

        Args:
          reuse_workspace_if_exists: If a context for this workspace already
            exists, controls whether its id is returned (True) or a
            ValueError is raised (False).

        Returns:
          The id of the existing or newly created context.

        Raises:
          ValueError: If the workspace already exists and
            ``reuse_workspace_if_exists`` is false.
        """
        existing = self._get_existing_context()
        if existing is not None:
            if not reuse_workspace_if_exists:
                raise ValueError(
                    'Workspace name {} already exists with id {}. You can initialize workspace with reuse_workspace_if_exists=True if want to reuse it'
                    .format(self.name, existing.id))
            return existing.id

        # Register the workspace context type (or obtain the id of the
        # already-registered one).
        context_type = mlpb.ContextType(
            name=self.CONTEXT_TYPE_NAME,
            properties={"description": mlpb.STRING, "labels": mlpb.STRING})
        type_id = _retry(lambda: self.store.put_context_type(context_type))

        # Only set the optional properties that were provided. Labels are
        # stored as a JSON string.
        context_properties = {}
        if self.description is not None:
            context_properties["description"] = mlpb.Value(
                string_value=self.description)
        if self.labels is not None:
            context_properties["labels"] = mlpb.Value(
                string_value=json.dumps(self.labels))

        new_context = mlpb.Context(
            type_id=type_id, name=self.name, properties=context_properties)
        return _retry(lambda: self.store.put_contexts([new_context])[0])
예제 #5
0
def value_to_mlmd_value(value) -> metadata_store_pb2.Value:
    """Converts a plain Python value into an MLMD ``Value`` proto.

    Mapping:
    - ``None`` becomes an empty ``Value``.
    - ``int`` (including ``bool``, an ``int`` subclass) sets ``int_value``.
    - ``float`` sets ``double_value``.
    - Anything else is passed through ``str()`` into ``string_value``.
    """
    if value is None:
        return metadata_store_pb2.Value()
    if isinstance(value, int):
        field = 'int_value'
    elif isinstance(value, float):
        field = 'double_value'
    else:
        field, value = 'string_value', str(value)
    return metadata_store_pb2.Value(**{field: value})
예제 #6
0
class CompilerUtilsTest(tf.test.TestCase, parameterized.TestCase):
    """Tests for the ``compiler_utils`` helpers."""

    @parameterized.named_parameters(
        ("IntValue", 42, metadata_store_pb2.Value(int_value=42)),
        ("FloatValue", 42.0, metadata_store_pb2.Value(double_value=42.0)),
        ("StrValue", "42", metadata_store_pb2.Value(string_value="42")))
    def testSetFieldValuePb(self, value, expected_pb):
        """set_field_value_pb fills the field matching the Python type."""
        actual = metadata_store_pb2.Value()
        compiler_utils.set_field_value_pb(actual, value)
        self.assertEqual(actual, expected_pb)

    def testSetFieldValuePbUnsupportedType(self):
        """A bool is rejected with ValueError."""
        target = metadata_store_pb2.Value()
        with self.assertRaises(ValueError):
            compiler_utils.set_field_value_pb(target, True)

    def testSetRuntimeParameterPb(self):
        """set_runtime_parameter_pb fills name, type and default value."""
        actual = pipeline_pb2.RuntimeParameter()
        compiler_utils.set_runtime_parameter_pb(
            actual, "test_name", str, "test_default_value")
        default = metadata_store_pb2.Value(string_value="test_default_value")
        expected = pipeline_pb2.RuntimeParameter(
            name="test_name",
            type=pipeline_pb2.RuntimeParameter.Type.STRING,
            default_value=default)
        self.assertEqual(expected, actual)

    def testIsResolver(self):
        """Only a ResolverNode is recognized as a resolver."""
        resolver_node = ResolverNode(
            instance_name="test_resolver_name",
            resolver_class=(
                latest_blessed_model_resolver.LatestBlessedModelResolver))
        self.assertTrue(compiler_utils.is_resolver(resolver_node))

        csv_gen = CsvExampleGen(input=external_input("data_path"))
        self.assertFalse(compiler_utils.is_resolver(csv_gen))

    def testIsImporter(self):
        """Only an ImporterNode is recognized as an importer."""
        importer_node = ImporterNode(
            instance_name="import_schema",
            source_uri="uri/to/schema",
            artifact_type=standard_artifacts.Schema)
        self.assertTrue(compiler_utils.is_importer(importer_node))

        csv_gen = CsvExampleGen(input=external_input("data_path"))
        self.assertFalse(compiler_utils.is_importer(csv_gen))

    def testEnsureTopologicalOrder(self):
        """Only orderings with "a" ahead of both "b" and "c" are accepted."""
        node_a = EmptyComponent(name="a")
        node_b = EmptyComponent(name="b")
        node_c = EmptyComponent(name="c")
        node_a.add_downstream_node(node_b)
        node_a.add_downstream_node(node_c)

        # "b" and "c" have no edge between them, so both relative orders pass.
        valid_orders = {"abc", "acb"}
        for candidate in itertools.permutations([node_a, node_b, node_c]):
            names = "".join(node._instance_name for node in candidate)
            if names in valid_orders:
                self.assertTrue(
                    compiler_utils.ensure_topological_order(candidate))
            else:
                self.assertFalse(
                    compiler_utils.ensure_topological_order(candidate))
예제 #7
0
    def setUp(self):
        """Builds matching proto-based and plain-Python fixtures.

        The fixtures come in two pairs:
        - ``artifact_struct_dict`` and ``artifact_dict`` describe the same two
          artifacts (ids 123/456, types 't1'/'t2'): the first as
          ``ArtifactStructList`` protos, the second as deserialized artifacts.
        - ``metadata_value_dict`` and ``value_dict`` hold the same four
          properties: the first as MLMD ``Value`` protos, the second as plain
          Python values.
        """
        super().setUp()

        # Artifacts as ArtifactStructList protos parsed from text format.
        self.artifact_struct_dict = {
            'a1':
            text_format.Parse(
                """
                elements {
                  artifact {
                    artifact {
                      id: 123
                    }
                    type {
                      name: 't1'
                    }
                  }
                }
                """, metadata_store_service_pb2.ArtifactStructList()),
            'a2':
            text_format.Parse(
                """
                elements {
                  artifact {
                    artifact {
                      id: 456
                    }
                    type {
                      name: 't2'
                    }
                  }
                }
                """, metadata_store_service_pb2.ArtifactStructList())
        }

        # The same artifacts, deserialized via artifact_utils.
        self.artifact_dict = {
            'a1': [
                artifact_utils.deserialize_artifact(
                    metadata_store_pb2.ArtifactType(name='t1'),
                    metadata_store_pb2.Artifact(id=123))
            ],
            'a2': [
                artifact_utils.deserialize_artifact(
                    metadata_store_pb2.ArtifactType(name='t2'),
                    metadata_store_pb2.Artifact(id=456))
            ]
        }

        # Property values as MLMD Value protos; 'p3' is an empty string.
        self.metadata_value_dict = {
            'p0': metadata_store_pb2.Value(int_value=0),
            'p1': metadata_store_pb2.Value(int_value=1),
            'p2': metadata_store_pb2.Value(string_value='hello'),
            'p3': metadata_store_pb2.Value(string_value='')
        }
        # Plain-Python equivalents of metadata_value_dict.
        self.value_dict = {'p0': 0, 'p1': 1, 'p2': 'hello', 'p3': ''}
예제 #8
0
def mlpb_artifact(type_id, uri, workspace, name=None, version=None):
  """Builds an ``mlpb.Artifact`` tagged with a workspace name.

  ``name`` and ``version`` become string properties only when they are truthy.
  The workspace is always recorded as a custom property under
  ``metadata._WORKSPACE_PROPERTY_NAME``.
  """
  optional_fields = (("name", name), ("version", version))
  properties = {
      key: mlpb.Value(string_value=text)
      for key, text in optional_fields
      if text
  }
  workspace_value = mlpb.Value(string_value=workspace)
  return mlpb.Artifact(
      uri=uri,
      type_id=type_id,
      properties=properties,
      custom_properties={metadata._WORKSPACE_PROPERTY_NAME: workspace_value})
예제 #9
0
def _replace_pipeline_run_id_in_channel(channel: p_pb2.InputSpec.Channel,
                                        pipeline_run_id: str):
  """Points ``channel`` at ``pipeline_run_id`` by mutating it in place.

  The first context query whose type is the pipeline-run context type has its
  name overwritten. If no such query exists, a new one is appended.
  """
  run_type_name = dsl_constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
  run_id_value = mlmd_pb2.Value(string_value=pipeline_run_id)

  matching_query = next(
      (query for query in channel.context_queries
       if query.type.name == run_type_name),
      None)
  if matching_query is not None:
    matching_query.name.field_value.CopyFrom(run_id_value)
    return

  channel.context_queries.append(
      p_pb2.InputSpec.Channel.ContextQuery(
          type=mlmd_pb2.ContextType(name=run_type_name),
          name=p_pb2.Value(field_value=run_id_value)))
예제 #10
0
 def testExecutionWatcher_Local(self):
     """An UpdateExecutionInfo update lands in the execution's custom properties."""
     update_value = metadata_store_pb2.Value(string_value='string_value')
     request = execution_watcher_pb2.UpdateExecutionInfoRequest(
         execution_id=self._execution.id)
     request.updates['test_key'].CopyFrom(update_value)

     response = self.stub.UpdateExecutionInfo(request)
     self.assertEqual(execution_watcher_pb2.UpdateExecutionInfoResponse(),
                      response)

     with self._mlmd_connection as m:
         executions = m.store.get_executions_by_id([self._execution.id])
     self.assertEqual(len(executions), 1)

     expected_text = """
   id: 1
   last_known_state: RUNNING
   custom_properties {
     key: "test_key"
     value {
       string_value: "string_value"
     }
   }
   """
     # Timestamps and the type id are assigned by the store, so ignore them.
     ignored = [
         'type_id', 'create_time_since_epoch', 'last_update_time_since_epoch'
     ]
     self.assertProtoPartiallyEquals(expected_text,
                                     executions[0],
                                     ignored_fields=ignored)
예제 #11
0
 def testJsonRoundTrip(self):
   """A Channel survives a to_json_dict / from_json_dict round trip."""
   original = channel.Channel(
       type=_MyType,
       additional_properties={
           'string_value': metadata_store_pb2.Value(string_value='forty-two')
       },
       additional_custom_properties={
           'int_value': metadata_store_pb2.Value(int_value=42)
       })
   restored = channel.Channel.from_json_dict(original.to_json_dict())

   self.assertIs(original.type, restored.type)
   for attr in ('type_name', 'additional_properties',
                'additional_custom_properties'):
     self.assertEqual(getattr(original, attr), getattr(restored, attr))
예제 #12
0
def create_new_output_artifact(
    store,
    execution_id: int,
    context_id: int,
    uri: str,
    type_name: str,
    output_name: str,
    run_id: str = None,
    argo_artifact: dict = None,
) -> metadata_store_pb2.Artifact:
    """Creates an output artifact with its OUTPUT event and context attribution.

    Properties recorded:
    - The output name is always stored as the artifact's IO-name property.
    - When ``run_id`` is truthy, it is stored as a string in both the
      pipeline-name and run-id properties.
    - When ``argo_artifact`` is truthy, it is JSON-encoded with sorted keys
      into a custom property.

    Args:
      store: Store handle forwarded to ``create_new_artifact_event_and_attribution``.
      execution_id: Id of the producing execution.
      context_id: Id of the context to attribute the artifact to.
      uri: Artifact URI.
      type_name: Artifact type name.
      output_name: Name of the output; also used as the event path key.
      run_id: Optional run id.
      argo_artifact: Optional Argo artifact description.

    Returns:
      The result of ``create_new_artifact_event_and_attribution``.
    """
    properties = {
        ARTIFACT_IO_NAME_PROPERTY_NAME:
            metadata_store_pb2.Value(string_value=output_name),
    }
    if run_id:
        run_id_text = str(run_id)
        # NOTE(review): the pipeline-name property receives the run id, not a
        # pipeline name; presumably intentional for UX grouping — confirm.
        for property_name in (ARTIFACT_PIPELINE_NAME_PROPERTY_NAME,
                              ARTIFACT_RUN_ID_PROPERTY_NAME):
            properties[property_name] = metadata_store_pb2.Value(
                string_value=run_id_text)

    custom_properties = {}
    if argo_artifact:
        # sort_keys makes the stored JSON deterministic.
        argo_json = json.dumps(argo_artifact, sort_keys=True)
        custom_properties[ARTIFACT_ARGO_ARTIFACT_PROPERTY_NAME] = (
            metadata_store_pb2.Value(string_value=argo_json))

    output_path = metadata_store_pb2.Event.Path(
        steps=[metadata_store_pb2.Event.Path.Step(key=output_name)])
    string_type = metadata_store_pb2.STRING
    # No event timestamp is passed here; the original author noted it is
    # filled in automatically.
    return create_new_artifact_event_and_attribution(
        store=store,
        execution_id=execution_id,
        context_id=context_id,
        uri=uri,
        type_name=type_name,
        event_type=metadata_store_pb2.Event.OUTPUT,
        artifact_name_path=output_path,
        properties=properties,
        artifact_type_properties={
            ARTIFACT_IO_NAME_PROPERTY_NAME: string_type,
            ARTIFACT_PIPELINE_NAME_PROPERTY_NAME: string_type,
            ARTIFACT_RUN_ID_PROPERTY_NAME: string_type,
        },
        custom_properties=custom_properties,
    )
예제 #13
0
    def serialized(self):
        """Serializes this execution into an ``mlpb.Execution`` proto.

        ``name``, ``create_time`` and ``description`` are stored as string
        properties, and the map is then passed through ``_del_none_properties``.
        The owning workspace and run, when present, are recorded by name as
        custom properties.
        """
        properties = {
            key: mlpb.Value(string_value=getattr(self, key))
            for key in ("name", "create_time", "description")
        }
        _del_none_properties(properties)

        custom_properties = {}
        owners = ((_WORKSPACE_PROPERTY_NAME, self.workspace),
                  (_RUN_PROPERTY_NAME, self.run))
        for property_name, owner in owners:
            if owner is not None:
                custom_properties[property_name] = mlpb.Value(
                    string_value=owner.name)

        return mlpb.Execution(type_id=self._type_id,
                              properties=properties,
                              custom_properties=custom_properties)
예제 #14
0
def get_or_create_run_context(
    store,
    run_id: str,
) -> metadata_store_pb2.Context:
    """Returns the run context named ``run_id``, creating it if necessary.

    The pipeline-name and run-id string properties are both set to ``run_id``.
    """
    property_names = (CONTEXT_PIPELINE_NAME_PROPERTY_NAME,
                      CONTEXT_RUN_ID_PROPERTY_NAME)
    return get_or_create_context_with_type(
        store=store,
        context_name=run_id,
        type_name=RUN_CONTEXT_TYPE_NAME,
        type_properties=dict.fromkeys(property_names,
                                      metadata_store_pb2.STRING),
        properties={
            name: metadata_store_pb2.Value(string_value=run_id)
            for name in property_names
        },
    )
예제 #15
0
 def testSetRuntimeParameterPb(self):
     """set_runtime_parameter_pb fills name, type and default value."""
     expected_pb = pipeline_pb2.RuntimeParameter(
         name="test_name",
         type=pipeline_pb2.RuntimeParameter.Type.STRING,
         default_value=metadata_store_pb2.Value(
             string_value="test_default_value"))

     actual_pb = pipeline_pb2.RuntimeParameter()
     compiler_utils.set_runtime_parameter_pb(
         actual_pb, "test_name", str, "test_default_value")
     self.assertEqual(expected_pb, actual_pb)
예제 #16
0
 def testSetMetadataValueWithTfxValue(self):
     """A pipeline Value holding a field_value is copied into the MLMD Value."""
     tfx_value = text_format.Parse(
         """
     field_value {
         int_value: 1
     }""", pipeline_pb2.Value())
     metadata_property = metadata_store_pb2.Value()

     data_types_utils.set_metadata_value(
         metadata_value=metadata_property, value=tfx_value)
     self.assertProtoEquals('int_value: 1', metadata_property)
예제 #17
0
 def test_log_invalid_artifacts_should_fail(self):
   """Logging artifacts that already carry workspace/run tags raises ValueError."""
   store = metadata.Store(grpc_host=GRPC_HOST, grpc_port=GRPC_PORT)
   workspace = metadata.Workspace(store=store,
                                  name="ws_1",
                                  description="a workspace for testing",
                                  labels={"n1": "v1"})
   execution = metadata.Execution(name="test execution", workspace=workspace)

   # Tagged with workspace "ws1", not the execution's workspace "ws_1".
   workspace_tagged = ArtifactFixture(
       mlpb.Artifact(
           uri="gs://uri",
           custom_properties={
               metadata._WORKSPACE_PROPERTY_NAME:
                   mlpb.Value(string_value="ws1"),
           }))
   with self.assertRaises(ValueError):
     execution.log_input(workspace_tagged)

   # Tagged with a run property.
   run_tagged = ArtifactFixture(
       mlpb.Artifact(
           uri="gs://uri",
           custom_properties={
               metadata._RUN_PROPERTY_NAME: mlpb.Value(string_value="run1"),
           }))
   with self.assertRaises(ValueError):
     execution.log_output(run_tagged)
예제 #18
0
 def testSetMetadataValueWithTfxValueFailed(self):
     """A pipeline Value holding a runtime_parameter cannot be converted."""
     tfx_value = text_format.Parse(
         """
     runtime_parameter {
       name: 'rp'
     }""", pipeline_pb2.Value())
     metadata_property = metadata_store_pb2.Value()

     with self.assertRaisesRegex(ValueError, 'Expecting field_value but got'):
         data_types_utils.set_metadata_value(metadata_value=metadata_property,
                                             value=tfx_value)
예제 #19
0
def build_metadata_value_dict(
    value_dict: Mapping[str, types.ExecPropertyTypes]
) -> Dict[str, metadata_store_pb2.Value]:
  """Converts a plain value dict into an MLMD Value dict.

  A falsy ``value_dict`` (``None`` or empty) yields ``{}``. Entries whose value
  is ``None`` are dropped. Every other value is converted by
  ``set_metadata_value`` into a fresh ``metadata_store_pb2.Value``, and that
  call's return value is stored.
  """
  if not value_dict:
    return {}
  return {
      key: set_metadata_value(metadata_store_pb2.Value(), raw)
      for key, raw in value_dict.items()
      if raw is not None
  }
예제 #20
0
def parse_execution_properties(dict_data: Dict[Text, Any]) -> Dict[Text, Any]:
  """Parses a dict from key to JSON-form Value proto into plain values.

  Args:
    dict_data: Mapping from property name to the JSON (dict) form of a
      ``metadata_store_pb2.Value`` message.

  Returns:
    Mapping from property name to whatever field of the ``value`` oneof is set.

  Raises:
    TypeError: If an entry sets no field of the ``value`` oneof.
  """
  result = {}
  for k, v in dict_data.items():
    # Translate each field from Value pb to plain value (via JSON text so that
    # json_format can populate the proto).
    value_pb = metadata_store_pb2.Value()
    json_format.Parse(json.dumps(v), value_pb)
    # WhichOneof returns None when no member of the oneof is set. Check this
    # before getattr, which would otherwise fail with an unrelated error and
    # leave the message below unreachable.
    field_name = value_pb.WhichOneof('value')
    if field_name is None:
      raise TypeError('Unrecognized type encountered at field %s of execution'
                      ' properties %s' % (k, dict_data))
    result[k] = getattr(value_pb, field_name)

  return result
예제 #21
0
def get_mlmd_value(
    kubeflow_value: pipeline_pb2.Value) -> metadata_store_pb2.Value:
  """Converts a Kubeflow pipeline ``Value`` message into an MLMD ``Value``.

  Only the ``int_value``, ``double_value`` and ``string_value`` members of the
  ``value`` oneof are supported.

  Raises:
    TypeError: If the oneof holds any other member or is unset.
  """
  which = kubeflow_value.WhichOneof('value')
  if which not in ('int_value', 'double_value', 'string_value'):
    raise TypeError('Get unknown type of value: {}'.format(kubeflow_value))

  # The supported members have identical names in both messages.
  result = metadata_store_pb2.Value()
  setattr(result, which, getattr(kubeflow_value, which))
  return result
예제 #22
0
def _build_proto_exec_property_dict(exec_properties):
    """Builds PythonExecutorExecutionInfo.execution_properties.

    Each value becomes a ``metadata_store_pb2.Value``:
    - ``str`` sets ``string_value``.
    - ``int`` (including ``bool``) sets ``int_value``.
    - ``float`` sets ``double_value``.

    Raises:
      RuntimeError: For a value of any other type.
    """
    # Checked in order, so str is matched before int and int before float.
    field_by_type = (
        (str, 'string_value'),
        (int, 'int_value'),
        (float, 'double_value'),
    )
    proto_dict = {}
    for key, raw in exec_properties.items():
        field = next(
            (name for py_type, name in field_by_type
             if isinstance(raw, py_type)),
            None)
        if field is None:
            raise RuntimeError('Unsupported type {} for key {}'.format(
                type(raw), key))
        proto_dict[key] = metadata_store_pb2.Value(**{field: raw})
    return proto_dict
예제 #23
0
def _populate_exec_properties(
        executor_output: execution_result_pb2.ExecutorOutput,
        exec_properties: Dict[str, Any]):
    """Copies supported exec_properties into ``executor_output``.

    ``str``, ``int`` (including ``bool``) and ``float`` values are written to
    ``executor_output.execution_properties``. Values of any other type are
    logged at INFO level and skipped.
    """
    supported = ((str, 'string_value'), (int, 'int_value'),
                 (float, 'double_value'))
    for key, value in exec_properties.items():
        field = next(
            (name for kind, name in supported if isinstance(value, kind)),
            None)
        if field is None:
            logging.info(
                'Value type %s of key %s in exec_properties is not '
                'supported, going to drop it', type(value), key)
            continue
        executor_output.execution_properties[key].CopyFrom(
            metadata_store_pb2.Value(**{field: value}))
예제 #24
0
    def _create_execution_in_run_context(self):
        """Creates a RUNNING execution for this component in the run context.

        The execution type is named after ``self.component_id``.

        Properties (all strings):
        - the KFP run-id URI, the pipeline name and the component id;
        - the Kale RUNNING state.

        Custom properties:
        - the execution hash;
        - the pod name, also recorded as the cache pod name;
        - the pod namespace;
        - the Kale state.

        An Association linking the execution to ``self.run_context`` is
        written before the execution is returned.
        """
        property_values = {
            "run_id": metadata_store_pb2.Value(
                string_value=kfputils.format_kfp_run_id_uri(self.run_uuid)),
            "pipeline_name": metadata_store_pb2.Value(
                string_value=self.pipeline_name),
            "component_id": metadata_store_pb2.Value(
                string_value=self.component_id),
        }
        running_value = metadata_store_pb2.Value(
            string_value=KALE_EXECUTION_STATE_RUNNING)
        property_values[MLMD_EXECUTION_STATE_KEY] = running_value
        property_types = {
            key: metadata_store_pb2.STRING for key in property_values
        }

        # The same pod-name value is stored under both pod-name keys.
        pod_name_value = metadata_store_pb2.Value(string_value=self.pod_name)
        custom_props = {
            MLMD_EXECUTION_HASH_PROPERTY_KEY: metadata_store_pb2.Value(
                string_value=self.execution_hash),
            MLMD_EXECUTION_POD_NAME_PROPERTY_KEY: pod_name_value,
            MLMD_EXECUTION_CACHE_POD_NAME_PROPERTY_KEY: pod_name_value,
            MLMD_EXECUTION_POD_NAMESPACE_PROPERTY_KEY: metadata_store_pb2.Value(
                string_value=self.pod_namespace),
            KALE_EXECUTION_STATE_KEY: running_value,
        }

        execution = self._create_execution_with_type(
            type_name=self.component_id,
            property_types=property_types,
            properties=property_values,
            custom_properties=custom_props,
            state=metadata_store_pb2.Execution.RUNNING)

        link = metadata_store_pb2.Association(
            execution_id=execution.id, context_id=self.run_context.id)
        self.store.put_attributions_and_associations([], [link])
        return execution
예제 #25
0
def _build_proto_exec_property_dict(
    exec_properties: Mapping[str, types.Property]
) -> Dict[str, metadata_store_pb2.Value]:
  """Builds PythonExecutorExecutionInfo.execution_properties.

  A falsy ``exec_properties`` yields ``{}``. Otherwise each value becomes a
  ``metadata_store_pb2.Value``:
  - ``str`` sets ``string_value``.
  - ``int`` (including ``bool``) sets ``int_value``.
  - ``float`` sets ``double_value``.

  Raises:
    RuntimeError: For a value of any other type.
  """
  result = {}
  if not exec_properties:
    return result
  for key, prop in exec_properties.items():
    if isinstance(prop, str):
      field = 'string_value'
    elif isinstance(prop, int):
      field = 'int_value'
    elif isinstance(prop, float):
      field = 'double_value'
    else:
      raise RuntimeError('Unsupported type {} for key {}'.format(
          type(prop), key))
    result[key] = metadata_store_pb2.Value(**{field: prop})
  return result
예제 #26
0
def build_metadata_value_dict(
    value_dict: Mapping[str, types.Property]
) -> Dict[str, metadata_store_pb2.Value]:
    """Converts a plain value dict into an MLMD Value dict.

    A falsy ``value_dict`` yields ``{}``. Otherwise each value becomes a
    ``metadata_store_pb2.Value``:
    - ``str`` sets ``string_value``.
    - ``int`` (including ``bool``) sets ``int_value``.
    - ``float`` sets ``double_value``.

    Raises:
      RuntimeError: For a value of any other type.
    """
    def _to_value(key, raw):
        # Convert one plain value, checking str, then int, then float.
        if isinstance(raw, str):
            return metadata_store_pb2.Value(string_value=raw)
        if isinstance(raw, int):
            return metadata_store_pb2.Value(int_value=raw)
        if isinstance(raw, float):
            return metadata_store_pb2.Value(double_value=raw)
        raise RuntimeError('Unsupported type {} for key {}'.format(
            type(raw), key))

    if not value_dict:
        return {}
    return {key: _to_value(key, raw) for key, raw in value_dict.items()}
예제 #27
0
 def serialization(self):
     """Serializes these metrics into an ``mlpb.Artifact``.

     Each listed attribute is stored as a string property.
     ``_ALL_META_PROPERTY_NAME`` additionally holds ``json.dumps(self.__dict__)``.
     The artifact's property map is then passed through ``_del_none_properties``.
     """
     field_names = ("name", "create_time", "description", "metrics_type",
                    "data_set_id", "model_id", "owner")
     properties = {
         field: mlpb.Value(string_value=getattr(self, field))
         for field in field_names
     }
     properties[_ALL_META_PROPERTY_NAME] = mlpb.Value(
         string_value=json.dumps(self.__dict__))

     metrics_artifact = mlpb.Artifact(uri=self.uri, properties=properties)
     _del_none_properties(metrics_artifact.properties)
     return metrics_artifact
예제 #28
0
 def serialization(self):
     """Serializes this data set into an ``mlpb.Artifact``.

     Each listed attribute is stored as a string property.
     ``_ALL_META_PROPERTY_NAME`` additionally holds ``json.dumps(self.__dict__)``.
     The artifact's property map is then passed through ``_del_none_properties``.
     """
     properties = {}
     for field in ("name", "create_time", "description", "query", "version",
                   "owner"):
         properties[field] = mlpb.Value(string_value=getattr(self, field))
     properties[_ALL_META_PROPERTY_NAME] = mlpb.Value(
         string_value=json.dumps(self.__dict__))

     data_set_artifact = mlpb.Artifact(uri=self.uri, properties=properties)
     _del_none_properties(data_set_artifact.properties)
     return data_set_artifact
예제 #29
0
def register_context_if_not_exists(
    metadata_handler: metadata.Metadata,
    context_type_name: Text,
    context_name: Text,
) -> metadata_store_pb2.Context:
    """Registers a context if it does not exist; otherwise returns the existing one.

    This is a simplified wrapper around ``_register_context_if_not_exist``. It
    takes only a context type name and a context name, and builds the
    ``ContextSpec`` from them.

    Args:
      metadata_handler: A handler to access the MLMD store.
      context_type_name: The name of the context type.
      context_name: The name of the context.

    Returns:
      An MLMD context.
    """
    name_value = pipeline_pb2.Value(
        field_value=metadata_store_pb2.Value(string_value=context_name))
    context_type = metadata_store_pb2.ContextType(name=context_type_name)
    spec = pipeline_pb2.ContextSpec(name=name_value, type=context_type)
    return _register_context_if_not_exist(metadata_handler=metadata_handler,
                                          context_spec=spec)
예제 #30
0
    def testSetParameterValue(self):
        """set_parameter_value encodes ints, strings, bools, protos and lists."""

        def _parsed(text):
            return text_format.Parse(text, pipeline_pb2.Value())

        def _assert_encodes(expected_text, raw_value):
            # Compare the parsed expectation with the value returned by
            # set_parameter_value on a fresh pipeline Value.
            expected = _parsed(expected_text)
            actual = data_types_utils.set_parameter_value(
                pipeline_pb2.Value(), raw_value)
            self.assertEqual(expected, actual)

        _assert_encodes(
            """
          field_value {
            int_value: 1
          }
        """, 1)

        _assert_encodes(
            """
          field_value {
            string_value: 'hello'
          }
        """, 'hello')

        _assert_encodes(
            """
          field_value {
            string_value: 'true'
          }
          schema {
            value_type {
              boolean_type {}
            }
          }
        """, True)

        # Proto values are JSON-encoded. file_descriptors is cleared because
        # the expectation does not include it.
        actual_proto = pipeline_pb2.Value()
        expected_proto = _parsed(
            """
          field_value {
            string_value: '{\\n  "string_value": "hello"\\n}'
          }
          schema {
            value_type {
              proto_type {
                message_type: 'ml_metadata.Value'
              }
            }
          }
        """)
        data_types_utils.set_parameter_value(
            actual_proto, metadata_store_pb2.Value(string_value='hello'))
        actual_proto.schema.value_type.proto_type.ClearField(
            'file_descriptors')
        self.assertProtoPartiallyEquals(expected_proto, actual_proto)

        _assert_encodes(
            """
          field_value {
            string_value: '[false, true]'
          }
          schema {
            value_type {
              list_type {
                boolean_type {}
              }
            }
          }
        """, [False, True])

        _assert_encodes(
            """
          field_value {
            string_value: '["true", "false"]'
          }
          schema {
            value_type {
              list_type {}
            }
          }
        """, ['true', 'false'])