def testGetValue(self):
  """Verifies that a primitive int is extracted from a Value proto."""
  value_pb = pipeline_pb2.Value()
  text_format.Parse(""" field_value { int_value: 1 }""", value_pb)
  self.assertEqual(data_types_utils.get_value(value_pb), 1)
def testGetMetadataValueType(self):
  """Verifies the MLMD type deduced for an int-valued Value proto."""
  value_pb = pipeline_pb2.Value()
  text_format.Parse(""" field_value { int_value: 1 }""", value_pb)
  self.assertEqual(
      data_types_utils.get_metadata_value_type(value_pb),
      metadata_store_pb2.INT)
def testGetValueFailed(self):
  """get_value rejects a Value that carries a runtime parameter."""
  value_pb = pipeline_pb2.Value()
  text_format.Parse(""" runtime_parameter { name: 'rp' }""", value_pb)
  with self.assertRaisesRegex(RuntimeError, 'Expecting field_value but got'):
    data_types_utils.get_value(value_pb)
def testSetMetadataValueWithTfxValue(self):
  """A field_value-backed tfx Value is copied into an MLMD Value."""
  source = pipeline_pb2.Value()
  text_format.Parse(""" field_value { int_value: 1 }""", source)
  target = metadata_store_pb2.Value()
  data_types_utils.set_metadata_value(metadata_value=target, value=source)
  self.assertProtoEquals('int_value: 1', target)
def testSetMetadataValueWithTfxValueFailed(self):
  """Runtime-parameter Values cannot be converted to MLMD values."""
  source = pipeline_pb2.Value()
  text_format.Parse(""" runtime_parameter { name: 'rp' }""", source)
  target = metadata_store_pb2.Value()
  with self.assertRaisesRegex(ValueError, 'Expecting field_value but got'):
    data_types_utils.set_metadata_value(metadata_value=target, value=source)
def build_pipeline_value_dict(
    value_dict: Dict[str, types.ExecPropertyTypes]
) -> Dict[str, pipeline_pb2.Value]:
  """Converts plain value dict into pipeline_pb2.Value dict.

  Args:
    value_dict: Mapping from keys to plain exec-property values. May be None
      or empty; entries whose value is None are skipped.

  Returns:
    A dict mapping each surviving key to a pipeline_pb2.Value populated by
    set_parameter_value.
  """
  # A falsy input (None or {}) short-circuits; .items() on None would raise.
  if not value_dict:
    return {}
  # set_parameter_value fills the proto it is handed and returns it, so the
  # comprehension allocates exactly one Value per retained key.
  return {
      k: set_parameter_value(pipeline_pb2.Value(), v)
      for k, v in value_dict.items()
      if v is not None
  }
def _replace_pipeline_run_id_in_channel(channel: p_pb2.InputSpec.Channel,
                                        pipeline_run_id: str):
  """Update in place.

  Rewrites the pipeline-run context query of `channel` to point at
  `pipeline_run_id`, appending a new query when none exists yet.
  """
  run_id_value = mlmd_pb2.Value(string_value=pipeline_run_id)
  for query in channel.context_queries:
    if query.type.name == dsl_constants.PIPELINE_RUN_CONTEXT_TYPE_NAME:
      # An existing pipeline-run query: overwrite its name and stop.
      query.name.field_value.CopyFrom(run_id_value)
      return
  # No pipeline-run query was present; add one.
  channel.context_queries.append(
      p_pb2.InputSpec.Channel.ContextQuery(
          type=mlmd_pb2.ContextType(
              name=dsl_constants.PIPELINE_RUN_CONTEXT_TYPE_NAME),
          name=p_pb2.Value(field_value=run_id_value)))
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> metadata_store_pb2.Execution:
  """Creates an execution proto based on the information provided.

  Args:
    metadata_handler: A handler to access MLMD store.
    execution_type: A metadata_pb2.ExecutionType message describing the type
      of the execution.
    state: The state of the execution.
    exec_properties: Execution properties that need to be attached.

  Returns:
    A metadata_store_pb2.Execution message.
  """
  execution = metadata_store_pb2.Execution()
  execution.last_known_state = state
  # Registers the execution type in MLMD (if absent) so a valid type id can
  # be stamped on the execution.
  execution.type_id = common_utils.register_type_if_not_exist(
      metadata_handler, execution_type).id
  exec_properties = exec_properties or {}
  # For every execution property, put it in execution.properties if its key is
  # in execution type schema. Otherwise, put it in execution.custom_properties.
  for k, v in exec_properties.items():
    value = pipeline_pb2.Value()
    value = data_types_utils.set_parameter_value(value, v)

    if value.HasField('schema'):
      # Stores schema in custom_properties for non-primitive types to allow
      # parsing in later stages.
      data_types_utils.set_metadata_value(
          execution.custom_properties[get_schema_key(k)],
          proto_utils.proto_to_json(value.schema))

    # The property goes to the typed `properties` map only when the declared
    # type in the execution type schema matches the value's MLMD type;
    # mismatched or undeclared keys fall back to custom_properties.
    if (execution_type.properties.get(k) ==
        data_types_utils.get_metadata_value_type(v)):
      execution.properties[k].CopyFrom(value.field_value)
    else:
      execution.custom_properties[k].CopyFrom(value.field_value)
  logging.debug('Prepared EXECUTION:\n %s', execution)
  return execution
def set_metadata_value(
    metadata_value: metadata_store_pb2.Value,
    value: types.ExecPropertyTypes) -> metadata_store_pb2.Value:
  """Sets metadata property based on tfx value.

  Args:
    metadata_value: A metadata_store_pb2.Value message to be set.
    value: The value of the property in pipeline_pb2.Value form.

  Returns:
    A Value proto filled with the provided value.

  Raises:
    ValueError: If value type is not supported or is still RuntimeParameter.
  """
  # Normalize through set_parameter_value first, then copy only the resulting
  # field_value into the MLMD proto; the schema is intentionally not set.
  holder = pipeline_pb2.Value()
  set_parameter_value(holder, value, set_schema=False)
  metadata_value.CopyFrom(holder.field_value)
  return metadata_value
def register_context_if_not_exists(
    metadata_handler: metadata.Metadata,
    context_type_name: Text,
    context_name: Text,
) -> metadata_store_pb2.Context:
  """Registers a context if not exist, otherwise returns the existing one.

  This is a simplified wrapper around the method above which only takes
  context type and context name.

  Args:
    metadata_handler: A handler to access MLMD store.
    context_type_name: The name of the context type.
    context_name: The name of the context.

  Returns:
    An MLMD context.
  """
  # Build the spec piecewise for readability, then delegate to the full
  # registration helper.
  name_value = pipeline_pb2.Value(
      field_value=metadata_store_pb2.Value(string_value=context_name))
  context_spec = pipeline_pb2.ContextSpec(
      name=name_value,
      type=metadata_store_pb2.ContextType(name=context_type_name))
  return _register_context_if_not_exist(
      metadata_handler=metadata_handler, context_spec=context_spec)
def testSetParameterValueUnsupportedType(self):
  """Dict values are unsupported and must raise ValueError."""
  target = pipeline_pb2.Value()
  with self.assertRaises(ValueError):
    data_types_utils.set_parameter_value(target, {'a': 1})
def testSetParameterValue(self):
  """Covers int, str, bool, proto and list inputs to set_parameter_value."""
  # Plain int: serialized directly into field_value, no schema attached.
  actual_int = pipeline_pb2.Value()
  expected_int = text_format.Parse(
      """ field_value { int_value: 1 } """, pipeline_pb2.Value())
  self.assertEqual(expected_int,
                   data_types_utils.set_parameter_value(actual_int, 1))
  # Plain string: serialized directly, no schema attached.
  actual_str = pipeline_pb2.Value()
  expected_str = text_format.Parse(
      """ field_value { string_value: 'hello' } """, pipeline_pb2.Value())
  self.assertEqual(
      expected_str, data_types_utils.set_parameter_value(actual_str, 'hello'))
  # Bool: encoded as the string 'true' with a boolean_type schema recording
  # the original type.
  actual_bool = pipeline_pb2.Value()
  expected_bool = text_format.Parse(
      """ field_value { string_value: 'true' } schema { value_type { boolean_type {} } } """,
      pipeline_pb2.Value())
  self.assertEqual(
      expected_bool, data_types_utils.set_parameter_value(actual_bool, True))
  # Proto: JSON-encoded into string_value with a proto_type schema naming the
  # message type. file_descriptors is cleared because its contents vary, so
  # only a partial comparison is meaningful.
  actual_proto = pipeline_pb2.Value()
  expected_proto = text_format.Parse(
      """ field_value { string_value: '{\\n "string_value": "hello"\\n}' } schema { value_type { proto_type { message_type: 'ml_metadata.Value' } } } """,
      pipeline_pb2.Value())
  data_types_utils.set_parameter_value(
      actual_proto, metadata_store_pb2.Value(string_value='hello'))
  actual_proto.schema.value_type.proto_type.ClearField('file_descriptors')
  self.assertProtoPartiallyEquals(expected_proto, actual_proto)
  # Boolean list: JSON-encoded with a list_type { boolean_type } schema.
  actual_list = pipeline_pb2.Value()
  expected_list = text_format.Parse(
      """ field_value { string_value: '[false, true]' } schema { value_type { list_type { boolean_type {} } } } """,
      pipeline_pb2.Value())
  self.assertEqual(
      expected_list,
      data_types_utils.set_parameter_value(actual_list, [False, True]))
  # String list: JSON-encoded; a bare list_type schema suffices.
  actual_list = pipeline_pb2.Value()
  expected_list = text_format.Parse(
      """ field_value { string_value: '["true", "false"]' } schema { value_type { list_type {} } } """,
      pipeline_pb2.Value())
  self.assertEqual(
      expected_list,
      data_types_utils.set_parameter_value(actual_list, ['true', 'false']))
def testBuildParsedValueDict(self):
  """build_parsed_value_dict decodes every supported Value encoding."""
  # Plain int carried directly in field_value.
  int_value = text_format.Parse(
      """ field_value { int_value: 1 } """, pipeline_pb2.Value())
  # Plain string carried directly in field_value.
  string_value = text_format.Parse(
      """ field_value { string_value: 'random str' } """, pipeline_pb2.Value())
  # Bool encoded as the string 'false' plus a boolean_type schema.
  bool_value = text_format.Parse(
      """ field_value { string_value: 'false' } schema { value_type { boolean_type {} } } """,
      pipeline_pb2.Value())
  # Proto encoded as JSON plus a proto_type schema naming the message type.
  proto_value = text_format.Parse(
      """ field_value { string_value: '{"string_value":"hello"}' } schema { value_type { proto_type { message_type: 'ml_metadata.Value' } } } """,
      pipeline_pb2.Value())
  # List of bools encoded as JSON with a list_type { boolean_type } schema.
  list_boolean_value = text_format.Parse(
      """ field_value { string_value: '[false, true]' } schema { value_type { list_type { boolean_type {} } } } """,
      pipeline_pb2.Value())
  # List of strings encoded as JSON; a bare list_type schema suffices.
  list_str_value = text_format.Parse(
      """ field_value { string_value: '["true", "false", "random"]' } schema { value_type { list_type {} } } """,
      pipeline_pb2.Value())
  value_dict = {
      'int_val': int_value,
      'string_val': string_value,
      'bool_val': bool_value,
      'proto_val': proto_value,
      'list_boolean_value': list_boolean_value,
      'list_str_value': list_str_value,
  }
  # Each entry should parse back to its original Python value.
  expected_parsed_dict = {
      'int_val': 1,
      'string_val': 'random str',
      'bool_val': False,
      'list_boolean_value': [False, True],
      'list_str_value': ['true', 'false', 'random'],
      'proto_val': metadata_store_pb2.Value(string_value='hello')
  }
  self.assertEqual(expected_parsed_dict,
                   data_types_utils.build_parsed_value_dict(value_dict))