def testOutOfOrderRepeated(self):
    """Decodes a message whose field fragments appear in arbitrary order.

    The serialized input is the concatenation of three TestValue fragments:
    one holding `double_value=[1.0]` and two each holding a single
    `message_value` ('abc' and 'def'). Every permutation of the fragments is
    decoded, and for each permutation every subset of the two fields
    (including the empty subset and both fields together) is requested.
    The expected repeated `message_value` order follows the order in which
    the two message fragments appear in the buffer.
    """
    abc_msg = test_example_pb2.PrimitiveValue(string_value='abc')
    def_msg = test_example_pb2.PrimitiveValue(string_value='def')
    fragments = [
        test_example_pb2.TestValue(double_value=[1.0]).SerializeToString(),
        test_example_pb2.TestValue(message_value=[abc_msg]).SerializeToString(),
        test_example_pb2.TestValue(message_value=[def_msg]).SerializeToString(),
    ]
    # Nested messages are decoded as serialized strings, so compare bytes.
    abc_bytes = abc_msg.SerializeToString()
    def_bytes = def_msg.SerializeToString()

    all_fields_to_parse = ['double_value', 'message_value']
    field_types = {
        'double_value': dtypes.double,
        'message_value': dtypes.string,
    }

    # Test against all 3! permutations of fragments, and for each permutation
    # test parsing every combination (of every size) of the 2 fields.
    for indices in itertools.permutations(range(len(fragments))):
        proto = b''.join(fragments[i] for i in indices)

        # Only the relative position of the two message fragments (indices 1
        # and 2) determines the expected order of the repeated values.
        if indices.index(1) < indices.index(2):
            expected_message_values = [abc_bytes, def_bytes]
        else:
            expected_message_values = [def_bytes, abc_bytes]

        expected_field_values = {
            'double_value': [[1.0]],
            'message_value': [expected_message_values],
        }

        # `+ 1` so subset sizes run 0..len inclusive; without it, decoding
        # all fields together would never be exercised.
        for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
            for comb in itertools.combinations(all_fields_to_parse,
                                               num_fields_to_parse):
                parsed_values = self.evaluate(
                    self._decode_module.decode_proto(
                        [proto],
                        message_type='tensorflow.contrib.proto.TestValue',
                        field_names=comb,
                        output_types=[field_types[f] for f in comb],
                        sanitize=False)).values
                self.assertLen(parsed_values, len(comb))
                for field_name, parsed in zip(comb, parsed_values):
                    self.assertAllEqual(
                        parsed, expected_field_values[field_name],
                        'perm: {}, comb: {}'.format(indices, comb))
def _compareProtos(self, batch_shape, sizes, fields, field_dict):
    """Compare protos of type TestValue.

    Args:
      batch_shape: the shape of the input tensor of serialized messages.
      sizes: int matrix of repeat counts returned by decode_proto. Not
        currently inspected by this method.
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      field_dict: map from field names to decoded numpy tensors of values
    """
    # Maps a TF dtype to the PrimitiveValue field holding expected values of
    # that type. It does not depend on the loop variable, so build it once.
    tf_type_to_primitive_value_field = {
        dtypes.bool: 'bool_value',
        dtypes.float32: 'float_value',
        dtypes.float64: 'double_value',
        dtypes.int8: 'int8_value',
        dtypes.int32: 'int32_value',
        dtypes.int64: 'int64_value',
        dtypes.string: 'string_value',
        dtypes.uint8: 'uint8_value',
        dtypes.uint32: 'uint32_value',
        dtypes.uint64: 'uint64_value',
    }

    # Check that expected values match.
    for field in fields:
        values = field_dict[field.name]
        self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)

        # Values has the same shape as the input plus an extra
        # dimension for repeats.
        self.assertEqual(list(values.shape)[:-1], batch_shape)

        is_extension = 'ext_value' in field.name

        # Nested messages are represented as TF strings, requiring
        # some special handling.
        if is_extension or field.name == 'message_value':
            vs = []
            for buf in values.flat:
                msg = test_example_pb2.PrimitiveValue()
                msg.ParseFromString(buf)
                vs.append(msg)
            if is_extension:
                evs = field.value.Extensions[test_example_pb2.ext_value]
            else:
                evs = getattr(field.value, field.name)
            if len(vs) != len(evs):
                # Report field.name: extension fields have no field
                # descriptor in TestValue to take a name from.
                self.fail('Field %s decoded %d outputs, expected %d' %
                          (field.name, len(vs), len(evs)))
            for v, ev in zip(vs, evs):
                self.assertEqual(v, ev)
            continue

        # Only non-message, non-extension fields reach this point, so the
        # TestValue descriptor lookup is always valid here.
        fd = field.value.DESCRIPTOR.fields_by_name[field.name]

        if field.name in ['enum_value', 'enum_value_with_default']:
            tf_field_name = 'enum_value'
        else:
            tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
        if tf_field_name is None:
            # %s rather than %d so the message works for DType objects too.
            self.fail('Unhandled tensorflow type %s' % field.dtype)

        self._compareValues(fd, values.flat,
                            getattr(field.value, tf_field_name))