def _TestCorruptProtobuf(self, sanitize):
    """Checks that decode_proto reports DataLossError for unparseable bytes.

    Only the error reporting is exercised here; a wide variety of corrupt
    protobufs is covered separately by fuzzing.

    Args:
      sanitize: value forwarded to decode_proto's `sanitize` argument.
    """
    # dtype=object stops numpy from silently truncating the string.
    garbage = np.array('This is not a binary protobuf', dtype=object)
    error_pattern = ('Unable to parse binary protobuf'
                     '|Failed to consume entire buffer')

    with self.test_session() as sess:
      sizes_t, values_t = decode_proto_op.decode_proto(
          garbage,
          message_type='tensorflow.contrib.proto.TestCase',
          field_names=['sizes'],
          output_types=[dtypes.int32],
          sanitize=sanitize)
      fetches = [sizes_t] + values_t
      with self.assertRaisesRegexp(errors.DataLossError, error_pattern):
        sess.run(fetches)
  def _TestCorruptProtobuf(self, sanitize):
    """Feeds a non-protobuf string to decode_proto and expects DataLossError.

    The goal is the error message, not coverage of corrupt inputs; the latter
    is handled by fuzzing.

    Args:
      sanitize: passed straight through to decode_proto.
    """
    bad_input = 'This is not a binary protobuf'
    # Keep the string in an object array; numpy otherwise may silently
    # truncate strings.
    bad_batch = np.array(bad_input, dtype=object)

    with self.test_session() as sess:
      decoded = decode_proto_op.decode_proto(
          bad_batch,
          message_type='tensorflow.contrib.proto.TestCase',
          field_names=['sizes'],
          output_types=[dtypes.int32],
          sanitize=sanitize)
      sizes_tensor, value_tensors = decoded
      with self.assertRaisesRegexp(
          errors.DataLossError,
          'Unable to parse binary protobuf'
          '|Failed to consume entire buffer'):
        sess.run([sizes_tensor] + value_tensors)
Example #3
0
 def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
   """Sends 20 TestCase protos through IncrementTestShapes and checks replies.

   Request i carries shape [i, i + 1, i + 2]; the test expects the decoded
   response for request i to be [i + 1, i + 2, i + 3].
   """
   batch_size = 20
   proto_name = 'tensorflow.contrib.rpc.TestCase'
   request_shapes = [[i, i + 1, i + 2] for i in range(batch_size)]
   with self.test_session() as sess:
     encoded = encode_proto_op.encode_proto(
         message_type=proto_name,
         field_names=['shape'],
         sizes=[[3] for _ in range(batch_size)],
         values=[request_shapes])
     raw_responses = self.rpc(
         method=self.get_method_name('IncrementTestShapes'),
         address=self._address,
         request=encoded)
     _, (decoded_shape,) = decode_proto_op.decode_proto(
         bytes=raw_responses,
         message_type=proto_name,
         field_names=['shape'],
         output_types=[dtypes.int32])
     actual = sess.run(decoded_shape)
   expected = [[i + 1, i + 2, i + 3] for i in range(batch_size)]
   self.assertAllEqual(expected, actual)
    def _testRoundtrip(self, in_bufs, message_type, fields):
        """Decode `in_bufs`, re-encode the result, and require an exact match.

        Args:
          in_bufs: serialized messages; `.shape` and `.flat` are used, so this
            is presumably a numpy object array.
          message_type: proto message name given to both decode and encode.
          fields: field specs, each providing `.name` and `.dtype`.
        """
        def parse(serialized):
            # Both sides are parsed as RepeatedPrimitiveValue regardless of
            # `message_type`.
            message = test_example_pb2.RepeatedPrimitiveValue()
            message.ParseFromString(serialized)
            return message

        names = [spec.name for spec in fields]
        decode_types = [spec.dtype for spec in fields]

        with self.test_session() as sess:
            decoded_sizes, decoded_values = decode_proto_op.decode_proto(
                in_bufs,
                message_type=message_type,
                field_names=names,
                output_types=decode_types)

            reencoded = encode_proto_op.encode_proto(
                decoded_sizes,
                decoded_values,
                message_type=message_type,
                field_names=names)

            (out_bufs,) = sess.run([reencoded])

            # Re-encoding must keep the batch shape.
            self.assertEqual(in_bufs.shape, out_bufs.shape)

            for before, after in zip(in_bufs.flat, out_bufs.flat):
                # The parsed messages must be equal...
                self.assertEqual(parse(before), parse(after))
                # ...and so must the raw bytes. A failure only here means the
                # new serialization differs but still parses: possibly
                # harmless (map ordering?) or bad (e.g. lost packed encoding).
                self.assertEqual(before, after)
  def _testRoundtrip(self, in_bufs, message_type, fields):
    """Checks that decode_proto followed by encode_proto reproduces the input.

    Args:
      in_bufs: serialized messages; `.shape` and `.flat` are used, so this is
        presumably a numpy object array.
      message_type: proto message name passed to both ops.
      fields: field specs exposing `.name` and `.dtype`.
    """
    names = [spec.name for spec in fields]
    value_types = [spec.dtype for spec in fields]

    with self.test_session() as sess:
      sizes_t, values_t = decode_proto_op.decode_proto(
          in_bufs,
          message_type=message_type,
          field_names=names,
          output_types=value_types)

      encoded = encode_proto_op.encode_proto(
          sizes_t,
          values_t,
          message_type=message_type,
          field_names=names)

      out_bufs = sess.run([encoded])[0]

      # The re-encoded tensor must have the same shape as the input.
      self.assertEqual(in_bufs.shape, out_bufs.shape)

      for src, dst in zip(in_bufs.flat, out_bufs.flat):
        src_msg = test_example_pb2.RepeatedPrimitiveValue()
        src_msg.ParseFromString(src)
        dst_msg = test_example_pb2.RepeatedPrimitiveValue()
        dst_msg.ParseFromString(dst)

        # Same message content after parsing.
        self.assertEqual(src_msg, dst_msg)
        # Same bytes. Failing here while the check above passes means the
        # serialization changed but still parses: harmless (map ordering?)
        # or a real regression (e.g. loss of packing in the encoding).
        self.assertEqual(src, dst)
  def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
                           message_type, message_format, sanitize,
                           force_disordered=False):
    """Run decode tests on a batch of messages.

    Decodes `batch` with decode_proto, checks the shape and contents of the
    returned sizes tensor, then hands the decoded values to
    `self._compareRepeatedPrimitiveValue`.

    Args:
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      case_sizes: expected sizes, flattened in row-major order
      batch_shape: the shape (list of ints) of the input tensor of serialized
        messages
      batch: list of serialized messages
      message_type: descriptor name for messages
      message_format: format of messages, 'text' or 'binary'
      sanitize: whether to sanitize binary protobuf inputs
      force_disordered: whether to force fields encoded out of order.

    Raises:
      ValueError: if both `force_disordered` and `sanitize` are set.
    """

    if force_disordered:
      # Exercise code path that handles out-of-order fields by prepending extra
      # fields with tag numbers higher than any real field. Note that this won't
      # work with sanitization because that forces reserialization using a
      # trusted decoder and encoder.
      if sanitize:
        # Explicit raise rather than `assert`, which is stripped under -O.
        raise ValueError('force_disordered cannot be combined with sanitize')
      extra_fields = test_example_pb2.ExtraFields()
      extra_fields.string_value = 'IGNORE ME'
      extra_fields.bool_value = False
      extra_msg = extra_fields.SerializeToString()
      batch = [extra_msg + msg for msg in batch]

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.array(batch, dtype=object)
    batch = np.reshape(batch, batch_shape)

    field_names = [f.name for f in fields]
    output_types = [f.dtype for f in fields]

    with self.test_session() as sess:
      sizes, vtensor = decode_proto_op.decode_proto(
          batch,
          message_type=message_type,
          field_names=field_names,
          output_types=output_types,
          message_format=message_format,
          sanitize=sanitize)

      vlist = sess.run([sizes] + vtensor)
      sizes = vlist[0]
      # The remaining fetched values are the decoded arrays, one per field.
      value_tensors = vlist[1:]

      # The sizes output must have shape batch_shape + [num_fields]. Compare
      # as plain lists: elementwise `==` on numpy arrays broadcasts, so an
      # actual shape (2,) would wrongly "match" an expected [2, 2], and a rank
      # mismatch raises ValueError in recent NumPy instead of failing cleanly.
      self.assertEqual(
          list(sizes.shape), list(batch_shape) + [len(field_names)])

      # Check that the decoded sizes match the expected sizes.
      self.assertEqual(len(sizes.flat), len(case_sizes))
      self.assertTrue(
          np.array_equal(sizes.ravel(), np.array(case_sizes, dtype=np.int32)))

      field_dict = dict(zip(field_names, value_tensors))

      self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
                                          field_dict)
Example #7
0
    def _runDecodeProtoTests(self,
                             fields,
                             case_sizes,
                             batch_shape,
                             batch,
                             message_type,
                             message_format,
                             sanitize,
                             force_disordered=False):
        """Run decode tests on a batch of messages.

        Decodes `batch` with decode_proto, checks the shape and contents of
        the returned sizes tensor, then hands the decoded values to
        `self._compareRepeatedPrimitiveValue`.

        Args:
          fields: list of test_example_pb2.FieldSpec (types and expected
            values)
          case_sizes: expected sizes, flattened in row-major order
          batch_shape: the shape (list of ints) of the input tensor of
            serialized messages
          batch: list of serialized messages
          message_type: descriptor name for messages
          message_format: format of messages, 'text' or 'binary'
          sanitize: whether to sanitize binary protobuf inputs
          force_disordered: whether to force fields encoded out of order.

        Raises:
          ValueError: if both `force_disordered` and `sanitize` are set.
        """

        if force_disordered:
            # Exercise code path that handles out-of-order fields by prepending extra
            # fields with tag numbers higher than any real field. Note that this won't
            # work with sanitization because that forces reserialization using a
            # trusted decoder and encoder.
            if sanitize:
                # Explicit raise rather than `assert`, which is stripped under -O.
                raise ValueError(
                    'force_disordered cannot be combined with sanitize')
            extra_fields = test_example_pb2.ExtraFields()
            extra_fields.string_value = 'IGNORE ME'
            extra_fields.bool_value = False
            extra_msg = extra_fields.SerializeToString()
            batch = [extra_msg + msg for msg in batch]

        # Numpy silently truncates the strings if you don't specify dtype=object.
        batch = np.array(batch, dtype=object)
        batch = np.reshape(batch, batch_shape)

        field_names = [f.name for f in fields]
        output_types = [f.dtype for f in fields]

        with self.test_session() as sess:
            sizes, vtensor = decode_proto_op.decode_proto(
                batch,
                message_type=message_type,
                field_names=field_names,
                output_types=output_types,
                message_format=message_format,
                sanitize=sanitize)

            vlist = sess.run([sizes] + vtensor)
            sizes = vlist[0]
            # The remaining fetched values are the decoded arrays, one per field.
            value_tensors = vlist[1:]

            # The sizes output must have shape batch_shape + [num_fields].
            # Compare as plain lists: elementwise `==` on numpy arrays
            # broadcasts, so an actual shape (2,) would wrongly "match" an
            # expected [2, 2], and a rank mismatch raises ValueError in recent
            # NumPy instead of failing cleanly.
            self.assertEqual(
                list(sizes.shape), list(batch_shape) + [len(field_names)])

            # Check that the decoded sizes match the expected sizes.
            self.assertEqual(len(sizes.flat), len(case_sizes))
            self.assertTrue(
                np.array_equal(sizes.ravel(),
                               np.array(case_sizes, dtype=np.int32)))

            field_dict = dict(zip(field_names, value_tensors))

            self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
                                                field_dict)
Example #8
0
    def loop_body(a, unused_b):
        """Loop body for the tf.while_loop op.

        Builds one request/response cycle: fetch features over gRPC, run the
        model on them, and send the policy/value outputs back over gRPC.

        Relies on names from the enclosing scope that are not defined here:
        `config`, `go`, `features_lib`, `dual_net`, `policy_output_size` and
        `value_output_size`.

        Args:
            a: a constant 0
            unused_b: a string placeholder (to satisfy the requirement that a
                      while_loop's condition and body accept the same args as
                      the loop returns).

        Returns:
            A tuple `(a, response[0])`: `a` passed through unchanged, and the
            first element of the `put_outputs` RPC response tensor.
        """

        # Request a batch of features from the server; the request body is
        # empty.
        raw_response = tf.contrib.rpc.rpc(
            address=config.address,
            method=config.get_features_method,
            request="",
            protocol="grpc",
            fail_fast=True,
            timeout_in_ms=0,
            name="get_features")

        # Decode features from a proto to a flat tensor. The sizes output of
        # decode_proto is discarded; only batch_id and the flat features are
        # kept.
        _, (batch_id, flat_features) = decode_proto_op.decode_proto(
            bytes=raw_response,
            message_type='minigo.GetFeaturesResponse',
            field_names=['batch_id', 'features'],
            output_types=[dtypes.int32, dtypes.float32],
            descriptor_source=config.descriptor_path,
            name="decode_raw_features")

        # Reshape flat features to (-1, go.N, go.N, NEW_FEATURES_PLANES).
        # NOTE(review): presumably batch x board rows x board cols x feature
        # planes; confirm against dual_net's expected input layout.
        features = tf.reshape(
            flat_features, [-1, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
            name="unflatten_features")

        # Run inference. The second argument is presumably a "training" flag
        # (False here) -- TODO confirm; the third model output is unused.
        policy_output, value_output, _ = dual_net.model_inference_fn(
            features, False)

        # Flatten model outputs.
        flat_policy = tf.reshape(policy_output, [-1], name="flatten_policy")
        flat_value = value_output  # value_output is already flat.

        # Encode outputs from flat tensors to a proto. This builds a single
        # message; `sizes` gives the number of values for each field, in
        # field_names order (1 batch_id, then policy_output_size and
        # value_output_size values).
        request_tensors = encode_proto_op.encode_proto(
            message_type='minigo.PutOutputsRequest',
            field_names=['batch_id', 'policy', 'value'],
            sizes=[[1, policy_output_size, value_output_size]],
            values=[[batch_id], [flat_policy], [flat_value]],
            descriptor_source=config.descriptor_path,
            name="encode_outputs")

        # Send outputs.
        response = tf.contrib.rpc.rpc(
            address=config.address,
            method=config.put_outputs_method,
            request=request_tensors,
            protocol="grpc",
            fail_fast=True,
            timeout_in_ms=0,
            name="put_outputs")

        # `a` is returned unchanged; the first response element becomes the
        # loop's second variable (matching the `unused_b` placeholder).
        return a, response[0]