Ejemplo n.º 1
0
 def testGetValue(self):
     """Checks that get_value unwraps an int `field_value` into the int 1."""
     # text_format.Parse hands back the message it filled in.
     wrapped_int = text_format.Parse(
         """
     field_value {
       int_value: 1
     }""", pipeline_pb2.Value())
     extracted = data_types_utils.get_value(wrapped_int)
     self.assertEqual(extracted, 1)
Ejemplo n.º 2
0
 def testGetMetadataValueType(self):
     """Checks that an int `field_value` is reported as metadata_store_pb2.INT."""
     # text_format.Parse hands back the message it filled in.
     wrapped_int = text_format.Parse(
         """
     field_value {
       int_value: 1
     }""", pipeline_pb2.Value())
     detected_type = data_types_utils.get_metadata_value_type(wrapped_int)
     self.assertEqual(detected_type, metadata_store_pb2.INT)
Ejemplo n.º 3
0
 def testGetValueFailed(self):
     """Checks that get_value raises RuntimeError for a runtime_parameter."""
     # A Value carrying only a runtime_parameter has no field_value to return.
     unresolved = text_format.Parse(
         """
     runtime_parameter {
       name: 'rp'
     }""", pipeline_pb2.Value())
     expected_message = 'Expecting field_value but got'
     with self.assertRaisesRegex(RuntimeError, expected_message):
         data_types_utils.get_value(unresolved)
Ejemplo n.º 4
0
 def testSetMetadataValueWithTfxValue(self):
     """Checks that set_metadata_value writes int_value 1 into the MLMD Value."""
     source_value = text_format.Parse(
         """
     field_value {
         int_value: 1
     }""", pipeline_pb2.Value())
     target_property = metadata_store_pb2.Value()
     data_types_utils.set_metadata_value(
         metadata_value=target_property, value=source_value)
     # The target is filled in place, so it is inspected directly.
     self.assertProtoEquals('int_value: 1', target_property)
Ejemplo n.º 5
0
 def testSetMetadataValueWithTfxValueFailed(self):
     """Checks that set_metadata_value rejects a runtime_parameter value."""
     unresolved = text_format.Parse(
         """
     runtime_parameter {
       name: 'rp'
     }""", pipeline_pb2.Value())
     target_property = metadata_store_pb2.Value()
     expected_message = 'Expecting field_value but got'
     with self.assertRaisesRegex(ValueError, expected_message):
         data_types_utils.set_metadata_value(
             metadata_value=target_property, value=unresolved)
Ejemplo n.º 6
0
def build_pipeline_value_dict(
    value_dict: Dict[str, types.ExecPropertyTypes]
) -> Dict[str, pipeline_pb2.Value]:
  """Wraps each non-None entry of a plain dict via set_parameter_value.

  Args:
    value_dict: Mapping from property name to plain value. A falsy argument
      (None or an empty dict) yields an empty result.

  Returns:
    A new dict holding, for every key whose value is not None, the result of
    calling `set_parameter_value` with a fresh pipeline_pb2.Value and that
    plain value. Keys mapped to None are omitted.
  """
  # `or {}` folds the None/empty case into the comprehension below.
  return {
      key: set_parameter_value(pipeline_pb2.Value(), plain)
      for key, plain in (value_dict or {}).items()
      if plain is not None
  }
Ejemplo n.º 7
0
def _replace_pipeline_run_id_in_channel(channel: p_pb2.InputSpec.Channel,
                                        pipeline_run_id: str):
  """Points the channel's pipeline-run context query at `pipeline_run_id`.

  The channel is modified in place. If a context query whose type name equals
  dsl_constants.PIPELINE_RUN_CONTEXT_TYPE_NAME already exists, the first such
  query has its name's field_value overwritten; otherwise a new context query
  with that type name and the given run id is appended.

  Args:
    channel: The InputSpec.Channel proto to mutate.
    pipeline_run_id: The run id stored as the context query's name.
  """
  run_type_name = dsl_constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
  run_id_value = mlmd_pb2.Value(string_value=pipeline_run_id)

  # Only the first matching query is touched.
  matching = next(
      (query for query in channel.context_queries
       if query.type.name == run_type_name), None)

  if matching is None:
    channel.context_queries.append(
        p_pb2.InputSpec.Channel.ContextQuery(
            type=mlmd_pb2.ContextType(name=run_type_name),
            name=p_pb2.Value(field_value=run_id_value)))
  else:
    matching.name.field_value.CopyFrom(run_id_value)
Ejemplo n.º 8
0
def prepare_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    state: metadata_store_pb2.Execution.State,
    exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> metadata_store_pb2.Execution:
    """Builds an MLMD Execution proto from a type, a state and properties.

    Args:
        metadata_handler: A handler to access MLMD store; passed to
            common_utils.register_type_if_not_exist to obtain the type id.
        execution_type: A metadata_store_pb2.ExecutionType message describing
            the type of the execution.
        state: The state stored as the execution's last_known_state.
        exec_properties: Execution properties to attach. None is treated as
            an empty mapping.

    Returns:
        A metadata_store_pb2.Execution message.
    """
    execution = metadata_store_pb2.Execution()
    execution.last_known_state = state
    registered_type = common_utils.register_type_if_not_exist(
        metadata_handler, execution_type)
    execution.type_id = registered_type.id

    for key, raw in (exec_properties or {}).items():
        wrapped = data_types_utils.set_parameter_value(
            pipeline_pb2.Value(), raw)

        if wrapped.HasField('schema'):
            # Non-primitive values carry a schema; keep it as JSON under a
            # dedicated custom property so later stages can parse the value.
            schema_slot = execution.custom_properties[get_schema_key(key)]
            data_types_utils.set_metadata_value(
                schema_slot, proto_utils.proto_to_json(wrapped.schema))

        # A property goes into `properties` only when the execution type
        # declares the key with the same metadata value type; everything else
        # is stored in `custom_properties`.
        declared_type = execution_type.properties.get(key)
        actual_type = data_types_utils.get_metadata_value_type(raw)
        if declared_type == actual_type:
            destination = execution.properties
        else:
            destination = execution.custom_properties
        destination[key].CopyFrom(wrapped.field_value)

    logging.debug('Prepared EXECUTION:\n %s', execution)
    return execution
Ejemplo n.º 9
0
def set_metadata_value(
    metadata_value: metadata_store_pb2.Value,
    value: types.ExecPropertyTypes) -> metadata_store_pb2.Value:
  """Fills an MLMD Value proto from an execution property value.

  The value is first written into a temporary pipeline_pb2.Value through
  `set_parameter_value`; the temporary's `field_value` is then copied into
  `metadata_value`.

  Args:
    metadata_value: The metadata_store_pb2.Value message to overwrite.
    value: The property value to store.

  Returns:
    The same `metadata_value` object, after it has been filled.

  Raises:
    ValueError: Presumably propagated from `set_parameter_value` when the
      value is unsupported or still a RuntimeParameter (as the original
      docstring stated) - confirm against that helper.
  """
  # set_schema=False: only field_value is copied out below.
  scratch = pipeline_pb2.Value()
  set_parameter_value(scratch, value, set_schema=False)
  metadata_value.CopyFrom(scratch.field_value)
  return metadata_value
Ejemplo n.º 10
0
def register_context_if_not_exists(
    metadata_handler: metadata.Metadata,
    context_type_name: Text,
    context_name: Text,
) -> metadata_store_pb2.Context:
    """Registers a context if not exist, otherwise returns the existing one.

    Convenience wrapper around `_register_context_if_not_exist`: it builds a
    pipeline_pb2.ContextSpec from just a context type name and a context name
    and delegates to that function.

    Args:
        metadata_handler: A handler to access MLMD store.
        context_type_name: The name of the context type.
        context_name: The name of the context.

    Returns:
        An MLMD context.
    """
    # The context name is wrapped as a string field_value in the spec.
    name_value = pipeline_pb2.Value(
        field_value=metadata_store_pb2.Value(string_value=context_name))
    context_type = metadata_store_pb2.ContextType(name=context_type_name)
    spec = pipeline_pb2.ContextSpec(name=name_value, type=context_type)
    return _register_context_if_not_exist(
        metadata_handler=metadata_handler, context_spec=spec)
Ejemplo n.º 11
0
 def testSetParameterValueUnsupportedType(self):
     """Checks that set_parameter_value raises ValueError for a dict input."""
     unsupported_payload = {'a': 1}
     target = pipeline_pb2.Value()
     with self.assertRaises(ValueError):
         data_types_utils.set_parameter_value(target, unsupported_payload)
Ejemplo n.º 12
0
    def testSetParameterValue(self):
        """Checks the Value proto set_parameter_value builds per input type.

        Covers int, str, bool, a proto message, a list of bools and a list of
        strings. Per the expected protos below, int and str land directly in
        `field_value`, while the other inputs are stored as a string together
        with a `schema` describing the original type.
        """
        # int: stored as int_value, no schema.
        want_int = text_format.Parse(
            """
          field_value {
            int_value: 1
          }
        """, pipeline_pb2.Value())
        got_int = data_types_utils.set_parameter_value(
            pipeline_pb2.Value(), 1)
        self.assertEqual(want_int, got_int)

        # str: stored as string_value, no schema.
        want_str = text_format.Parse(
            """
          field_value {
            string_value: 'hello'
          }
        """, pipeline_pb2.Value())
        got_str = data_types_utils.set_parameter_value(
            pipeline_pb2.Value(), 'hello')
        self.assertEqual(want_str, got_str)

        # bool: stored as the string 'true' with a boolean_type schema.
        want_bool = text_format.Parse(
            """
          field_value {
            string_value: 'true'
          }
          schema {
            value_type {
              boolean_type {}
            }
          }
        """, pipeline_pb2.Value())
        got_bool = data_types_utils.set_parameter_value(
            pipeline_pb2.Value(), True)
        self.assertEqual(want_bool, got_bool)

        # proto message: stored as JSON text with a proto_type schema.
        want_proto = text_format.Parse(
            """
          field_value {
            string_value: '{\\n  "string_value": "hello"\\n}'
          }
          schema {
            value_type {
              proto_type {
                message_type: 'ml_metadata.Value'
              }
            }
          }
        """, pipeline_pb2.Value())
        got_proto = pipeline_pb2.Value()
        data_types_utils.set_parameter_value(
            got_proto, metadata_store_pb2.Value(string_value='hello'))
        # file_descriptors is not part of the expectation; drop it before
        # comparing.
        got_proto.schema.value_type.proto_type.ClearField('file_descriptors')
        self.assertProtoPartiallyEquals(want_proto, got_proto)

        # list of bools: stored as a JSON list with a list_type/boolean_type
        # schema.
        want_bool_list = text_format.Parse(
            """
          field_value {
            string_value: '[false, true]'
          }
          schema {
            value_type {
              list_type {
                boolean_type {}
              }
            }
          }
        """, pipeline_pb2.Value())
        got_bool_list = data_types_utils.set_parameter_value(
            pipeline_pb2.Value(), [False, True])
        self.assertEqual(want_bool_list, got_bool_list)

        # list of strings: stored as a JSON list with an empty list_type
        # schema.
        want_str_list = text_format.Parse(
            """
          field_value {
            string_value: '["true", "false"]'
          }
          schema {
            value_type {
              list_type {}
            }
          }
        """, pipeline_pb2.Value())
        got_str_list = data_types_utils.set_parameter_value(
            pipeline_pb2.Value(), ['true', 'false'])
        self.assertEqual(want_str_list, got_str_list)
Ejemplo n.º 13
0
 def testBuildParsedValueDict(self):
     """Checks that build_parsed_value_dict decodes Value protos to Python.

     The input maps names to pipeline_pb2.Value messages covering int, str,
     bool, proto, list-of-bool and list-of-str encodings; the expected output
     maps the same names to the corresponding plain Python / proto values.
     """

     def parse_value(text_proto):
         # Builds a pipeline_pb2.Value from its text-format representation.
         return text_format.Parse(text_proto, pipeline_pb2.Value())

     wrapped_values = {
         'int_val': parse_value("""
       field_value {
         int_value: 1
       }
     """),
         'string_val': parse_value("""
       field_value {
         string_value: 'random str'
       }
     """),
         'bool_val': parse_value("""
       field_value {
         string_value: 'false'
       }
       schema {
         value_type {
           boolean_type {}
         }
       }
     """),
         'proto_val': parse_value("""
       field_value {
         string_value: '{"string_value":"hello"}'
       }
       schema {
         value_type {
           proto_type {
             message_type: 'ml_metadata.Value'
           }
         }
       }
     """),
         'list_boolean_value': parse_value("""
       field_value {
         string_value: '[false, true]'
       }
       schema {
         value_type {
           list_type {
             boolean_type {}
           }
         }
       }
     """),
         'list_str_value': parse_value("""
       field_value {
         string_value: '["true", "false", "random"]'
       }
       schema {
         value_type {
           list_type {}
         }
       }
     """),
     }

     plain_values = {
         'int_val': 1,
         'string_val': 'random str',
         'bool_val': False,
         'proto_val': metadata_store_pb2.Value(string_value='hello'),
         'list_boolean_value': [False, True],
         'list_str_value': ['true', 'false', 'random'],
     }

     parsed = data_types_utils.build_parsed_value_dict(wrapped_values)
     self.assertEqual(plain_values, parsed)