def testGetValueFailed(self):
  """get_value must reject a Value carrying a runtime_parameter payload."""
  tfx_value = pipeline_pb2.Value()
  text_format.Parse(
      """
        runtime_parameter {
          name: 'rp'
        }""", tfx_value)
  # Only field_value payloads are resolvable; anything else is an error.
  with self.assertRaisesRegex(RuntimeError, 'Expecting field_value but got'):
    data_types_utils.get_value(tfx_value)
def get_executions(
    metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
  """Returns all executions for the given pipeline node.

  This finds all executions having the same set of contexts as the pipeline
  node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to obtain executions.

  Returns:
    List of executions for the given node in MLMD db.
  """
  if not node.contexts.contexts:
    return []
  # One clause per context; an execution must match every context of the node.
  # NOTE(review): type/name values are interpolated unescaped into the MLMD
  # filter query — assumes they never contain single quotes; confirm upstream.
  clauses = []
  for idx, context_spec in enumerate(node.contexts.contexts):
    type_name = context_spec.type.name
    name_value = data_types_utils.get_value(context_spec.name)
    clauses.append(
        f"(contexts_{idx}.type = '{type_name}' AND "
        f"contexts_{idx}.name = '{name_value}')")
  return metadata_handler.store.get_executions(
      list_options=mlmd.ListOptions(filter_query=' AND '.join(clauses)))
def get_executions(
    metadata_handler: metadata.Metadata,
    node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
  """Returns all executions for the given pipeline node.

  This finds all executions having the same set of contexts as the pipeline
  node.

  Args:
    metadata_handler: A handler to access MLMD db.
    node: The pipeline node for which to obtain executions.

  Returns:
    List of executions for the given node in MLMD db.
  """
  resolved_contexts = []
  for context_spec in node.contexts.contexts:
    ctx = metadata_handler.store.get_context_by_type_and_name(
        context_spec.type.name,
        data_types_utils.get_value(context_spec.name))
    if ctx is None:
      # An unregistered context guarantees the node has no executions yet.
      return []
    resolved_contexts.append(ctx)
  return execution_lib.get_executions_associated_with_all_contexts(
      metadata_handler, resolved_contexts)
def testGetValue(self):
  """get_value should unwrap an int stored in a field_value payload."""
  tfx_value = pipeline_pb2.Value()
  text_format.Parse(
      """
        field_value {
          int_value: 1
        }""", tfx_value)
  self.assertEqual(data_types_utils.get_value(tfx_value), 1)
def _resolve_single_channel(
    metadata_handler: metadata.Metadata,
    channel: pipeline_pb2.InputSpec.Channel) -> List[types.Artifact]:
  """Resolves input artifacts from a single channel."""
  resolved = []
  for query in channel.context_queries:
    ctx = metadata_handler.store.get_context_by_type_and_name(
        query.type.name, data_types_utils.get_value(query.name))
    # Contexts that are not (yet) registered in MLMD are silently skipped.
    if ctx:
      resolved.append(ctx)
  return get_qualified_artifacts(
      metadata_handler=metadata_handler,
      contexts=resolved,
      artifact_type=channel.artifact_query.type,
      output_key=channel.output_key or None)
def _register_context_if_not_exist(
    metadata_handler: metadata.Metadata,
    context_spec: pipeline_pb2.ContextSpec,
) -> metadata_store_pb2.Context:
  """Registers a context if not exist, otherwise returns the existing one.

  Args:
    metadata_handler: A handler to access MLMD store.
    context_spec: A pipeline_pb2.ContextSpec message that instructs registering
      of a context.

  Returns:
    An MLMD context.
  """
  type_name = context_spec.type.name
  name = data_types_utils.get_value(context_spec.name)

  existing = metadata_handler.store.get_context_by_type_and_name(
      type_name=type_name, context_name=name)
  if existing is not None:
    return existing
  logging.debug('Failed to get context of type %s and name %s',
                type_name, name)

  # Context not found: build a proto and try to register it ourselves.
  context = _generate_context_proto(
      metadata_handler=metadata_handler, context_spec=context_spec)
  try:
    [context_id] = metadata_handler.store.put_contexts([context])
    context.id = context_id
  except mlmd.errors.AlreadyExistsError:
    # Another execution registered the same context concurrently; fetch the
    # one that won the race instead of failing.
    logging.debug('Context %s already exists.', name)
    context = metadata_handler.store.get_context_by_type_and_name(
        type_name=type_name, context_name=name)
    assert context is not None, ('Context is missing for %s while put_contexts '
                                 'reports that it existed.') % (name)
  logging.debug('ID of context %s is %s.', context_spec, context.id)
  return context
def _generate_context_proto(
    metadata_handler: metadata.Metadata,
    context_spec: pipeline_pb2.ContextSpec) -> metadata_store_pb2.Context:
  """Generates metadata_pb2.Context based on the ContextSpec message.

  Args:
    metadata_handler: A handler to access MLMD store.
    context_spec: A pipeline_pb2.ContextSpec message that instructs registering
      of a context.

  Returns:
    A metadata_store_pb2.Context message.

  Raises:
    RuntimeError: When actual property type does not match provided metadata
      type schema.
  """
  context_type = common_utils.register_type_if_not_exist(
      metadata_handler, context_spec.type)
  context_name = data_types_utils.get_value(context_spec.name)
  assert isinstance(context_name, Text), 'context name should be string.'
  result = metadata_store_pb2.Context(
      type_id=context_type.id, name=context_name)
  for key, value in context_spec.properties.items():
    if key not in context_type.properties:
      # Keys absent from the registered type schema go to custom_properties.
      data_types_utils.set_metadata_value(result.custom_properties[key], value)
      continue
    declared_type = context_type.properties.get(key)
    actual_type = data_types_utils.get_metadata_value_type(value)
    if declared_type != actual_type:
      raise RuntimeError(
          'Property type %s different from provided metadata type property type %s for key %s'
          % (actual_type, declared_type, key))
    data_types_utils.set_metadata_value(result.properties[key], value)
  return result
def _as_dict(proto_map) -> Dict[str, types.Property]:
  """Converts a proto map of Value messages into a plain Python dict."""
  result = {}
  for key, value in proto_map.items():
    result[key] = data_types_utils.get_value(value)
  return result