def test_bool_key_type(self):
  """Parses key True of bool_string_map across two messages.

  Only the second message carries a True key, so the parsed field should
  contain that single value, attributed to message index 1.
  """
  map_field = "bool_string_map[1]"
  first = test_map_pb2.MessageWithMap()
  first.bool_string_map[False] = "hello"
  second = test_map_pb2.MessageWithMap()
  second.bool_string_map[True] = "goodbye"
  batch = [first, second]
  with self.session() as sess:
    result = _run_parse_message_level_ex(batch, {map_field}, sess)
    self.assertIn(map_field, result)
    parsed = result[map_field]
    self.assertAllEqual(parsed[_VALUE], [b"goodbye"])
    self.assertAllEqual(parsed[_INDEX], [1])
def test_invalid_bool_key(self):
  """Requesting key "2" from bool_string_map raises InvalidArgumentError.

  Bool map keys are requested elsewhere as "0"/"1"; "2" is expected to be
  rejected with a "Failed to parse ... string" error.
  """
  message_with_map = test_map_pb2.MessageWithMap()
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                              "Failed to parse .*string"):
    self.evaluate(
        self._parse_map_entry([message_with_map], "bool_string_map", ["2"]))
def test_bool_key_type(self):
  """Only the False entry of bool_string_map is set; key True yields nothing."""
  message_with_map = test_map_pb2.MessageWithMap()
  message_with_map.bool_string_map[False] = "hello"
  entries = self._parse_map_entry([message_with_map], "bool_string_map",
                                  ["0", "1"])
  # Results come back in the order the keys were requested: "0", then "1".
  [(false_values, false_indices), (true_values, true_indices)] = entries
  self.assertAllEqual(true_values, [])
  self.assertAllEqual(false_values, [b"hello"])
  self.assertAllEqual(true_indices, [])
  self.assertAllEqual(false_indices, [0])
def test_enum_value_type(self):
  """An enum-valued map entry is returned as the enum's integer value."""
  message_with_map = test_map_pb2.MessageWithMap()
  message_with_map.string_enum_map["foo"] = test_map_pb2.BAZ
  expected_value = int(test_map_pb2.BAZ)
  # The empty-string key is absent from the map, so it should yield nothing.
  [(foo_values, foo_indices), (missing_values, missing_indices)] = (
      self._parse_map_entry([message_with_map], "string_enum_map",
                            ["foo", ""]))
  self.assertAllEqual(foo_values, [expected_value])
  self.assertAllEqual(missing_values, [])
  self.assertAllEqual(foo_indices, [0])
  self.assertAllEqual(missing_indices, [])
def test_multiple_messages(self):
  """Values for each key are gathered across a batch with message indices."""
  first = test_map_pb2.MessageWithMap(string_string_map={
      "key1": "foo",
      "key3": "bar",
  })
  empty = test_map_pb2.MessageWithMap()
  third = test_map_pb2.MessageWithMap(string_string_map={
      "key2": "baz",
      "key1": "kaz",
  })
  [key1_entry, key2_entry, key3_entry] = self._parse_map_entry(
      [first, empty, third], "string_string_map", ["key1", "key2", "key3"])
  # Each case: (parsed entry, expected values, expected positions in batch).
  cases = (
      (key1_entry, [b"foo", b"kaz"], [0, 2]),
      (key2_entry, [b"baz"], [2]),
      (key3_entry, [b"bar"], [0]),
  )
  for (values, indices), expected_values, expected_indices in cases:
    self.assertAllEqual(values, expected_values)
    self.assertAllEqual(indices, expected_indices)
def test_message_value_type(self):
  """A message-valued map entry is returned as the sub-message's bytes."""
  sub_message = test_map_pb2.SubMessage(repeated_int64=[1, 2, 3])
  message_with_map = test_map_pb2.MessageWithMap()
  message_with_map.string_message_map["foo"].MergeFrom(sub_message)
  # The empty-string key is not present, so it should produce no values.
  [(foo_values, foo_indices), (missing_values, missing_indices)] = (
      self._parse_map_entry([message_with_map], "string_message_map",
                            ["foo", ""]))
  expected_bytes = sub_message.SerializeToString()
  self.assertAllEqual(foo_values, [expected_bytes])
  self.assertAllEqual(missing_values, [])
  self.assertAllEqual(foo_indices, [0])
  self.assertAllEqual(missing_indices, [])
def test_fp_value_types(self, value_type):
  """A floating-point map value is parsed back unchanged.

  Args:
    value_type: Suffix selecting the `string_<value_type>_map` field
      (presumably supplied by a parameterizing decorator not visible here).
  """
  field_name = "string_{}_map".format(value_type)
  message_with_map = test_map_pb2.MessageWithMap()
  # Reuse field_name so the populated field and the parsed field cannot drift.
  map_entry = getattr(message_with_map, field_name)
  # 0.5 is exactly representable, so an exact equality check is safe.
  map_entry["foo"] = 0.5
  [(values_foo, indices_foo), (values_null, indices_null)
  ] = self._parse_map_entry([message_with_map], field_name, ["foo", ""])
  self.assertAllEqual(values_foo, [0.5])
  self.assertAllEqual(values_null, [])
  self.assertAllEqual(indices_foo, [0])
  self.assertAllEqual(indices_null, [])
def test_unsigned_integer_key_types(self, key_type):
  """An integer-keyed map returns values only for keys that are present.

  Args:
    key_type: Prefix selecting the `<key_type>_string_map` field
      (presumably supplied by a parameterizing decorator not visible here).
  """
  field_name = "{}_string_map".format(key_type)
  message_with_map = test_map_pb2.MessageWithMap()
  # Reuse field_name so the populated field and the parsed field cannot drift.
  map_entry = getattr(message_with_map, field_name)
  map_entry[42] = "hello"
  # Keys are requested as strings; "0" is absent from the map.
  [(values_42, indices_42), (values_0, indices_0)
  ] = self._parse_map_entry([message_with_map], field_name, ["42", "0"])
  self.assertAllEqual(values_42, [b"hello"])
  self.assertAllEqual(values_0, [])
  self.assertAllEqual(indices_42, [0])
  self.assertAllEqual(indices_0, [])
def test_signed_integer_value_types(self, value_type):
  """Positive and negative integer map values are parsed back unchanged.

  Args:
    value_type: Suffix selecting the `string_<value_type>_map` field
      (presumably supplied by a parameterizing decorator not visible here).
  """
  field_name = "string_{}_map".format(value_type)
  message_with_map = test_map_pb2.MessageWithMap()
  # Reuse field_name so the populated field and the parsed field cannot drift.
  map_entry = getattr(message_with_map, field_name)
  map_entry["foo"] = 42
  map_entry["bar"] = -42
  [(values_foo, indices_foo), (values_bar, indices_bar),
   (values_null, indices_null)] = self._parse_map_entry(
       [message_with_map], field_name, ["foo", "bar", ""])
  self.assertAllEqual(values_foo, [42])
  self.assertAllEqual(values_bar, [-42])
  self.assertAllEqual(values_null, [])
  # Both keys live in the single message at batch index 0.
  self.assertAllEqual(indices_foo, [0])
  self.assertAllEqual(indices_bar, [0])
  self.assertAllEqual(indices_null, [])
def test_invalid_int32_key(self):
  """Requesting a non-numeric key from int32_string_map raises an error.

  The key "foo" is expected to be rejected with a
  "Failed to parse ... string" InvalidArgumentError.
  """
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                              "Failed to parse .*string"):
    self.evaluate(
        self._parse_map_entry([test_map_pb2.MessageWithMap()],
                              "int32_string_map", ["foo"]))