コード例 #1
0
    def testOutOfOrderRepeated(self):
        """Decode repeated fields whose wire fragments are interleaved.

        Three serialized TestValue fragments (one `double_value` entry and
        two `message_value` entries) are concatenated in every possible
        order. Each concatenation is decoded for every subset of the two
        fields, including the empty subset and both fields together.
        """
        abc_message = test_example_pb2.PrimitiveValue(string_value='abc')
        def_message = test_example_pb2.PrimitiveValue(string_value='def')
        fragments = [
            test_example_pb2.TestValue(double_value=[1.0]).SerializeToString(),
            test_example_pb2.TestValue(
                message_value=[abc_message]).SerializeToString(),
            test_example_pb2.TestValue(
                message_value=[def_message]).SerializeToString(),
        ]
        abc_bytes = abc_message.SerializeToString()
        def_bytes = def_message.SerializeToString()

        all_fields_to_parse = ['double_value', 'message_value']
        field_types = {
            'double_value': dtypes.double,
            'message_value': dtypes.string,
        }
        # Test against all 3! permutations of fragments, and for each
        # permutation test parsing every combination of the 2 fields.
        for indices in itertools.permutations(range(len(fragments))):
            proto = b''.join(fragments[i] for i in indices)

            # Fragments 1 and 2 both carry `message_value`; the expected
            # order of the decoded entries follows whichever fragment comes
            # first in the concatenated bytes.
            if indices.index(1) < indices.index(2):
                expected_message_values = [abc_bytes, def_bytes]
            else:
                expected_message_values = [def_bytes, abc_bytes]

            expected_field_values = {
                'double_value': [[1.0]],
                'message_value': [expected_message_values],
            }

            # `+ 1` so that the combination containing *both* fields is
            # exercised too; without it only 0- and 1-field subsets ran.
            for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
                for comb in itertools.combinations(all_fields_to_parse,
                                                   num_fields_to_parse):
                    parsed_values = self.evaluate(
                        self._decode_module.decode_proto(
                            [proto],
                            message_type='tensorflow.contrib.proto.TestValue',
                            field_names=comb,
                            output_types=[field_types[f] for f in comb],
                            sanitize=False)).values
                    self.assertLen(parsed_values, len(comb))
                    for field_name, parsed in zip(comb, parsed_values):
                        self.assertAllEqual(
                            parsed, expected_field_values[field_name],
                            'perm: {}, comb: {}'.format(indices, comb))
コード例 #2
0
  def _compareProtos(self, batch_shape, sizes, fields, field_dict):
    """Compare protos of type TestValue.

    For every expected field this checks the decoded tensor's dtype and its
    leading (batch) shape, then compares the flattened decoded values with
    the expected values stored on the field spec.

    Args:
      batch_shape: the shape of the input tensor of serialized messages.
      sizes: int matrix of repeat counts returned by decode_proto. Not read
        by this method.
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      field_dict: map from field names to decoded numpy tensors of values
    """

    # Maps a tensor dtype to the name of the attribute on `field.value` that
    # holds the expected values. Built once; it does not depend on the field.
    tf_type_to_primitive_value_field = {
        dtypes.bool: 'bool_value',
        dtypes.float32: 'float_value',
        dtypes.float64: 'double_value',
        dtypes.int8: 'int8_value',
        dtypes.int32: 'int32_value',
        dtypes.int64: 'int64_value',
        dtypes.string: 'string_value',
        dtypes.uint8: 'uint8_value',
        dtypes.uint32: 'uint32_value',
        dtypes.uint64: 'uint64_value',
    }

    # Check that expected values match.
    for field in fields:
      values = field_dict[field.name]
      self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)

      is_extension = 'ext_value' in field.name
      # Extension fields are fully handled by the nested-message branch
      # below, which never needs a field descriptor.
      if is_extension:
        fd = None
      else:
        fd = field.value.DESCRIPTOR.fields_by_name[field.name]

      # Values has the same shape as the input plus an extra
      # dimension for repeats.
      self.assertEqual(list(values.shape)[:-1], batch_shape)

      # Nested messages are represented as TF strings, requiring
      # some special handling.
      if field.name == 'message_value' or is_extension:
        vs = []
        for buf in values.flat:
          msg = test_example_pb2.PrimitiveValue()
          msg.ParseFromString(buf)
          vs.append(msg)
        if is_extension:
          evs = field.value.Extensions[test_example_pb2.ext_value]
        else:
          evs = getattr(field.value, field.name)
        if len(vs) != len(evs):
          # Report `field.name` rather than a descriptor attribute: for
          # extension fields there is no descriptor object to read a name
          # from, and the failure message itself must not raise.
          self.fail('Field %s decoded %d outputs, expected %d' %
                    (field.name, len(vs), len(evs)))
        for v, ev in zip(vs, evs):
          self.assertEqual(v, ev)
        continue

      # Both enum fields keep their expected values under `enum_value`.
      if field.name in ['enum_value', 'enum_value_with_default']:
        tf_field_name = 'enum_value'
      else:
        tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
      if tf_field_name is None:
        self.fail('Unhandled tensorflow type %d' % field.dtype)

      self._compareValues(fd, values.flat,
                          getattr(field.value, tf_field_name))