def _create_any_protos():
    """Build three payload messages and wrap each one with `_create_any`.

    The payloads are, in order: an AllSimple with optional_int32 explicitly
    set to 0, a default UserInfo, and an AllSimple with optional_int32 set
    to 20.

    Returns:
      A list holding the result of `_create_any` for each payload, in the
      order above.
    """
    zero_simple = test_pb2.AllSimple()
    # Assigned explicitly (rather than left unset) to mirror the 20 case below.
    zero_simple.optional_int32 = 0

    user_info = test_pb2.UserInfo()

    twenty_simple = test_pb2.AllSimple()
    twenty_simple.optional_int32 = 20

    payloads = (zero_simple, user_info, twenty_simple)
    return list(map(_create_any, payloads))
 def test_proto2_optional_field_with_honor_proto3_optional_semantic(self):
     """Parse optional_string from a batch where only the middle proto sets it.

     The batch is [unset, "a", unset] and parsing is requested with
     honor_proto3_optional_semantics=True. The test expects exactly one
     parsed entry, at index 1 with value b"a".
     """
     unset_bytes = test_pb2.AllSimple().SerializeToString()
     set_bytes = test_pb2.AllSimple(optional_string="a").SerializeToString()
     serialized_batch = tf.constant([unset_bytes, set_bytes, unset_bytes])

     parsed_fields = struct2tensor_ops.parse_message_level(
         serialized_batch,
         test_pb2.AllSimple.DESCRIPTOR,
         ["optional_string"],
         honor_proto3_optional_semantics=True)

     # Collect the index and value tensors, keyed by the parsed field name.
     index_by_field = {}
     value_by_field = {}
     for parsed_field in parsed_fields:
         index_by_field[parsed_field.field_name] = parsed_field.index
         value_by_field[parsed_field.field_name] = parsed_field.value

     # Only the second proto has value. No default value should be inserted.
     for field_index in index_by_field.values():
         self.assertAllEqual([1], field_index)
     for field_value in value_by_field.values():
         self.assertAllEqual([b"a"], field_value)
Пример #3
0
 def test_parse_full_message_level_for_all_simple_repeated_repeated(self):
     """Test seven messages, three of which hold two values per repeated field.

     The populated message sits at batch positions 0, 3 and 5; the other
     four entries are default AllSimple messages. Every checked field is
     therefore expected to report indices [0, 0, 3, 3, 5, 5] and its pair
     of values repeated three times.
     """
     populated = test_pb2.AllSimple(
         repeated_string=["foo", "foo2"],
         repeated_int32=[32, 322],
         repeated_uint32=[123, 1232],
         repeated_int64=[123456, 1234562],
         repeated_uint64=[123, 1232],
         repeated_float=[1.0, 2.0],
         repeated_double=[1.5, 2.5],
     )
     batch = [
         populated,
         test_pb2.AllSimple(),
         test_pb2.AllSimple(),
         populated,
         test_pb2.AllSimple(),
         populated,
         test_pb2.AllSimple(),
     ]
     result = _get_full_message_level_runnable(batch)

     expected_index = [0, 0, 3, 3, 5, 5]
     # (field name, the two values one populated message contributes).
     expected_pairs = (
         ("repeated_string", [b"foo", b"foo2"]),
         ("repeated_int32", [32, 322]),
         ("repeated_uint32", [123, 1232]),
         ("repeated_int64", [123456, 1234562]),
         ("repeated_uint64", [123, 1232]),
         ("repeated_float", [1.0, 2.0]),
         ("repeated_double", [1.5, 2.5]),
     )
     for field_name, value_pair in expected_pairs:
         parsed = result[field_name]
         self.assertAllEqual(parsed[INDEX], expected_index)
         # Three populated messages -> the pair shows up three times in a row.
         self.assertAllEqual(parsed[VALUE], value_pair * 3)
Пример #4
0
def _get_expression_with_any():
  """Create an expression over three MessageWithAny protos.

  Each MessageWithAny packs one payload into its `my_any` field. The
  payloads are, in order: AllSimple with optional_int32 = 0, a default
  UserInfo, and AllSimple with optional_int32 = 20.

  Returns:
    The result of `proto.create_expression_from_proto` applied to the three
    serialized wrappers and the MessageWithAny descriptor.
  """
  zero_simple = test_pb2.AllSimple()
  zero_simple.optional_int32 = 0
  user_info = test_pb2.UserInfo()
  twenty_simple = test_pb2.AllSimple()
  twenty_simple.optional_int32 = 20

  serialized = []
  for payload in (zero_simple, user_info, twenty_simple):
    wrapper = test_any_pb2.MessageWithAny()
    wrapper.my_any.Pack(payload)
    serialized.append(wrapper.SerializeToString())

  return proto.create_expression_from_proto(
      serialized, test_any_pb2.MessageWithAny.DESCRIPTOR)
Пример #5
0
 def test_any_path(self):
   """Promote a field reached through an Any and check values and source paths.

   A single MessageWithAny packs an AllSimple with optional_int32 = 17. The
   field is promoted to `my_any.new_int32`; the test then checks the
   calculated leaf (parent index [0], value [17]) and that the reported
   source paths consist of the original path through the Any.
   """
   # Path steps from the root, through the packed AllSimple, to the int field.
   source_steps = [
       "my_any", "(struct2tensor.test.AllSimple)", "optional_int32"
   ]

   payload = test_pb2.AllSimple()
   payload.optional_int32 = 17
   wrapper = test_any_pb2.MessageWithAny()
   wrapper.my_any.Pack(payload)

   expr = proto.create_expression_from_proto(
       [wrapper.SerializeToString()], test_any_pb2.MessageWithAny.DESCRIPTOR)
   promoted_root = promote.promote(expr, path.Path(source_steps), "new_int32")
   promoted_field = promoted_root.get_descendant_or_error(
       path.Path(["my_any", "new_int32"]))

   prensors, proto_summaries = (
       calculate_with_source_paths.calculate_prensors_with_source_paths(
           [promoted_field]))
   self.assertLen(prensors, 1)
   self.assertLen(proto_summaries, 1)

   leaf = prensors[0].node
   self.assertAllEqual(leaf.parent_index, [0])
   self.assertAllEqual(leaf.values, [17])

   self.equal_ignore_order(
       proto_summaries[0].paths, [path.Path(source_steps)])
Пример #6
0
    def test_normal_field(self, message_format):
        """Test four messages, three of which carry a repeated string.

        The batch is [filled, empty, filled, filled], so the two strings of
        each filled message are expected at indices 0, 2 and 3. Only
        "repeated_string" is requested, so "repeated_bool" must be absent
        from the result.
        """
        filled = test_pb2.AllSimple(repeated_string=["foo", "foo2"])
        empty = test_pb2.AllSimple()
        batch = [filled, empty, filled, filled]

        result = _run_parse_message_level_ex(
            batch, {"repeated_string"}, message_format)

        # A field that was not requested must not show up in the output.
        self.assertNotIn("repeated_bool", result)

        parsed_strings = result["repeated_string"]
        self.assertAllEqual(parsed_strings[_INDEX], [0, 0, 2, 2, 3, 3])
        self.assertAllEqual(
            parsed_strings[_VALUE],
            [b"foo", b"foo2", b"foo", b"foo2", b"foo", b"foo2"])
Пример #7
0
    def test_parse_full_message_level_for_simple_action_multiple(self):
        """Test repeated_string parsing across three differently filled messages.

        Messages 0 and 2 each hold two repeated strings; message 1 holds
        none. The parsed field is expected to skip index 1 entirely.
        """
        first = test_pb2.AllSimple(
            optional_string="a", repeated_string=["b", "c"])
        second = test_pb2.AllSimple(optional_string="d", optional_int32=123)
        third = test_pb2.AllSimple(
            repeated_string=["d", "e"], optional_int32=123)

        fields_by_name = _parse_full_message_level_as_dict(
            [first, second, third])

        repeated_string_field = fields_by_name["repeated_string"]
        self.assertAllEqual(repeated_string_field.index, [0, 0, 2, 2])
        self.assertAllEqual(
            repeated_string_field.value, [b"b", b"c", b"d", b"e"])
Пример #8
0
  def test_out_of_order_repeated_fields_3(self):
    """Parse a payload whose repeated_string entries are split by another field.

    Three serialized AllSimple messages are concatenated into one byte
    string, so the repeated_string entries appear on both sides of the
    repeated_int64 entry. Several field selections are parsed from that
    payload and each returned field is compared with its expected values.
    """
    # The original note calls this a "3-5-3 wire number pattern"; the field
    # numbers are not visible in this file -- TODO confirm against the .proto.
    wire_segments = [
        test_pb2.AllSimple(repeated_string=["aaa"]).SerializeToString(),
        test_pb2.AllSimple(repeated_int64=[12345]).SerializeToString(),
        test_pb2.AllSimple(repeated_string=["bbb"]).SerializeToString(),
    ]
    payload = wire_segments[0] + wire_segments[1] + wire_segments[2]

    expected_by_field = {
        "repeated_string": [b"aaa", b"bbb"],
        "repeated_int64": [12345],
        "repeated_int32": [],
        "repeated_uint32": []
    }
    field_selections = [
        ["repeated_int64"],
        ["repeated_string"],
        ["repeated_string", "repeated_uint32", "repeated_int32"],
    ]

    for selection in field_selections:
      parsed_fields = struct2tensor_ops.parse_message_level(
          [payload], test_pb2.AllSimple.DESCRIPTOR, selection)
      for parsed_field in parsed_fields:
        expected = expected_by_field[parsed_field.field_name]
        self.assertAllEqual(expected, parsed_field.value)
Пример #9
0
    def test_parse_full_message_level_for_all_simple(self):
        """Test a single message with optional and repeated primitives set.

        The index/value tensors of optional_string and repeated_string are
        checked to be rank 1. Each optional field is then checked to hold its
        single value at index 0. The repeated fields are populated but only
        the shape of repeated_string is asserted here.
        """
        message = test_pb2.AllSimple(
            optional_string="foo",
            optional_int32=-5,
            optional_uint32=2**31,
            optional_int64=100123,
            optional_uint64=2**63,
            optional_float=6.5,
            optional_double=-7.0,
            repeated_string=["foo"],
            repeated_int32=[32],
            repeated_uint32=[123],
            repeated_int64=[123456],
            repeated_uint64=[123],
            repeated_float=[1.0],
            repeated_double=[1.5],
        )
        parsed = _get_full_message_level_runnable([message])

        for field_name in ("optional_string", "repeated_string"):
            self.assertLen(parsed[field_name][INDEX].shape.dims, 1)
            self.assertLen(parsed[field_name][VALUE].shape.dims, 1)

        # (field name, the single value expected for the only message).
        expected_optionals = (
            ("optional_string", b"foo"),
            ("optional_int32", -5),
            ("optional_uint32", 2**31),
            ("optional_int64", 100123),
            ("optional_uint64", 2**63),
            ("optional_float", 6.5),
            ("optional_double", -7.0),
        )
        for field_name, expected_value in expected_optionals:
            self.assertAllEqual(parsed[field_name][INDEX], [0])
            self.assertAllEqual(parsed[field_name][VALUE], [expected_value])
def _get_optional_int32(serialized_all_simple):
    """Parse serialized test_pb2.AllSimple bytes and return optional_int32.

    Args:
      serialized_all_simple: the serialized form of a test_pb2.AllSimple.

    Returns:
      The `optional_int32` attribute of the parsed message.
    """
    parsed = test_pb2.AllSimple()
    parsed.ParseFromString(serialized_all_simple)
    return parsed.optional_int32
Пример #11
0
def _get_empty_all_simple():
    """Return the serialized form of a default-constructed test_pb2.AllSimple."""
    empty_message = test_pb2.AllSimple()
    return empty_message.SerializeToString()