def test_bool_key_type(self):
     """Looks up the key `1` of a bool-keyed map across two messages.

     The first message only has an entry under `False` and the second only
     has an entry under `True`. The result for the path
     `bool_string_map[1]` is expected to hold the single value
     `b"goodbye"` with an `_INDEX` entry of `[1]`.
     """
     path = "bool_string_map[1]"
     false_only = test_map_pb2.MessageWithMap()
     false_only.bool_string_map[False] = "hello"
     true_only = test_map_pb2.MessageWithMap()
     true_only.bool_string_map[True] = "goodbye"
     messages = [false_only, true_only]
     with self.session() as sess:
         result = _run_parse_message_level_ex(messages, {path}, sess)
         self.assertIn(path, result)
         parsed = result[path]
         self.assertAllEqual(parsed[_VALUE], [b"goodbye"])
         self.assertAllEqual(parsed[_INDEX], [1])
示例#2
0
 def test_invalid_bool_key(self):
     """Checks that the key string "2" is rejected for a bool-keyed map.

     Evaluating the parse of `bool_string_map` with the key "2" is expected
     to raise `tf.errors.InvalidArgumentError` with a message matching
     "Failed to parse .*string".
     """
     message_with_map = test_map_pb2.MessageWithMap()
     # `assertRaisesRegexp` is a deprecated unittest alias that was removed in
     # Python 3.12; `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                 "Failed to parse .*string"):
         self.evaluate(
             self._parse_map_entry([message_with_map], "bool_string_map",
                                   ["2"]))
示例#3
0
 def test_bool_key_type(self):
   """Parses a bool-keyed map using the key strings "0" and "1".

   Only the `False` entry is populated, so the lookup for "0" is expected to
   yield `b"hello"` with index 0 while the lookup for "1" is expected to be
   empty.
   """
   msg = test_map_pb2.MessageWithMap()
   msg.bool_string_map[False] = "hello"
   false_entry, true_entry = self._parse_map_entry(
       [msg], "bool_string_map", ["0", "1"])
   false_values, false_indices = false_entry
   true_values, true_indices = true_entry
   self.assertAllEqual(true_values, [])
   self.assertAllEqual(false_values, [b"hello"])
   self.assertAllEqual(true_indices, [])
   self.assertAllEqual(false_indices, [0])
示例#4
0
 def test_enum_value_type(self):
     """Parses a string-to-enum map and compares values as integers.

     The key "foo" maps to `test_map_pb2.BAZ`; the empty-string key is not
     set and is expected to produce no values and no indices.
     """
     msg = test_map_pb2.MessageWithMap()
     msg.string_enum_map["foo"] = test_map_pb2.BAZ
     foo_entry, empty_entry = self._parse_map_entry(
         [msg], "string_enum_map", ["foo", ""])
     foo_values, foo_indices = foo_entry
     empty_values, empty_indices = empty_entry
     self.assertAllEqual(foo_values, [int(test_map_pb2.BAZ)])
     self.assertAllEqual(empty_values, [])
     self.assertAllEqual(foo_indices, [0])
     self.assertAllEqual(empty_indices, [])
示例#5
0
 def test_multiple_messages(self):
     """Parses three keys of a string-to-string map over three messages.

     The first message sets "key1" and "key3", the second sets nothing and
     the third sets "key2" and "key1". The expected indices are [0, 2] for
     "key1", [2] for "key2" and [0] for "key3".
     """
     first = test_map_pb2.MessageWithMap(
         string_string_map={"key1": "foo", "key3": "bar"})
     second = test_map_pb2.MessageWithMap()
     third = test_map_pb2.MessageWithMap(
         string_string_map={"key2": "baz", "key1": "kaz"})
     key1_entry, key2_entry, key3_entry = self._parse_map_entry(
         [first, second, third], "string_string_map",
         ["key1", "key2", "key3"])
     key1_values, key1_indices = key1_entry
     key2_values, key2_indices = key2_entry
     key3_values, key3_indices = key3_entry
     self.assertAllEqual(key1_values, [b"foo", b"kaz"])
     self.assertAllEqual(key2_values, [b"baz"])
     self.assertAllEqual(key3_values, [b"bar"])
     self.assertAllEqual(key1_indices, [0, 2])
     self.assertAllEqual(key2_indices, [2])
     self.assertAllEqual(key3_indices, [0])
示例#6
0
 def test_message_value_type(self):
     """Parses a string-to-message map and compares serialized submessages.

     The value under "foo" is expected to equal the serialized form of the
     `SubMessage` merged into the map; the empty-string key is not set and
     is expected to produce no values and no indices.
     """
     expected_sub = test_map_pb2.SubMessage(repeated_int64=[1, 2, 3])
     msg = test_map_pb2.MessageWithMap()
     msg.string_message_map["foo"].MergeFrom(expected_sub)
     foo_entry, empty_entry = self._parse_map_entry(
         [msg], "string_message_map", ["foo", ""])
     foo_values, foo_indices = foo_entry
     empty_values, empty_indices = empty_entry
     self.assertAllEqual(foo_values, [expected_sub.SerializeToString()])
     self.assertAllEqual(empty_values, [])
     self.assertAllEqual(foo_indices, [0])
     self.assertAllEqual(empty_indices, [])
示例#7
0
 def test_fp_value_types(self, value_type):
   """Parses a `string_<value_type>_map` field holding the value 0.5.

   Args:
     value_type: Name fragment used to build the map field name
       `string_{value_type}_map`.
   """
   field_name = "string_{}_map".format(value_type)
   msg = test_map_pb2.MessageWithMap()
   # The attribute on the message has the same name as the parsed field.
   getattr(msg, field_name)["foo"] = 0.5
   foo_entry, empty_entry = self._parse_map_entry(
       [msg], field_name, ["foo", ""])
   foo_values, foo_indices = foo_entry
   empty_values, empty_indices = empty_entry
   self.assertAllEqual(foo_values, [0.5])
   self.assertAllEqual(empty_values, [])
   self.assertAllEqual(foo_indices, [0])
   self.assertAllEqual(empty_indices, [])
示例#8
0
    def test_unsigned_integer_key_types(self, key_type):
        """Parses a `<key_type>_string_map` field using the keys "42" and "0".

        Args:
          key_type: Name fragment used to build the map field name
            `{key_type}_string_map`.
        """
        field_name = "{}_string_map".format(key_type)
        msg = test_map_pb2.MessageWithMap()
        # The attribute on the message has the same name as the parsed field.
        getattr(msg, field_name)[42] = "hello"

        entry_42, entry_0 = self._parse_map_entry(
            [msg], field_name, ["42", "0"])
        values_for_42, indices_for_42 = entry_42
        values_for_0, indices_for_0 = entry_0
        self.assertAllEqual(values_for_42, [b"hello"])
        self.assertAllEqual(values_for_0, [])
        self.assertAllEqual(indices_for_42, [0])
        self.assertAllEqual(indices_for_0, [])
示例#9
0
 def test_signed_integer_value_types(self, value_type):
   """Parses a `string_<value_type>_map` field holding 42 and -42.

   Args:
     value_type: Name fragment used to build the map field name
       `string_{value_type}_map`.
   """
   field_name = "string_{}_map".format(value_type)
   msg = test_map_pb2.MessageWithMap()
   # The attribute on the message has the same name as the parsed field.
   int_map = getattr(msg, field_name)
   int_map["foo"] = 42
   int_map["bar"] = -42
   foo_entry, bar_entry, empty_entry = self._parse_map_entry(
       [msg], field_name, ["foo", "bar", ""])
   foo_values, foo_indices = foo_entry
   bar_values, bar_indices = bar_entry
   empty_values, empty_indices = empty_entry
   self.assertAllEqual(foo_values, [42])
   self.assertAllEqual(bar_values, [-42])
   self.assertAllEqual(empty_values, [])
   self.assertAllEqual(foo_indices, [0])
   self.assertAllEqual(bar_indices, [0])
   self.assertAllEqual(empty_indices, [])
示例#10
0
 def test_invalid_int32_key(self):
     """Checks that the key string "foo" is rejected for an int32-keyed map.

     Evaluating the parse of `int32_string_map` with the key "foo" is
     expected to raise `tf.errors.InvalidArgumentError` with a message
     matching "Failed to parse .*string".
     """
     # `assertRaisesRegexp` is a deprecated unittest alias that was removed in
     # Python 3.12; `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                 "Failed to parse .*string"):
         self.evaluate(
             self._parse_map_entry([test_map_pb2.MessageWithMap()],
                                   "int32_string_map", ["foo"]))