def _TestCorruptProtobuf(self, sanitize):
  """Check the error reporting when DecodeToProto is fed garbage.

  Broad coverage of corrupt inputs is handled by fuzzing; this test only
  verifies that an obviously invalid buffer surfaces a DataLossError.
  """
  bad_bytes = 'This is not a binary protobuf'
  # dtype=object keeps numpy from silently truncating the string.
  batch = np.array(bad_bytes, dtype=object)
  with self.test_session() as sess:
    ctensor, vtensor = decode_proto_op.decode_proto(
        batch,
        message_type='tensorflow.contrib.proto.TestCase',
        field_names=['sizes'],
        output_types=[dtypes.int32],
        sanitize=sanitize)
    expected_error = ('Unable to parse binary protobuf'
                      '|Failed to consume entire buffer')
    with self.assertRaisesRegexp(errors.DataLossError, expected_error):
      _ = sess.run([ctensor] + vtensor)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
  """Round-trips a batch of 20 shape vectors through the RPC server."""
  with self.test_session() as sess:
    shape_values = [[i, i + 1, i + 2] for i in range(20)]
    requests = encode_proto_op.encode_proto(
        message_type='tensorflow.contrib.rpc.TestCase',
        field_names=['shape'],
        sizes=[[3]] * 20,
        values=[shape_values])
    raw_responses = self.rpc(
        method=self.get_method_name('IncrementTestShapes'),
        address=self._address,
        request=requests)
    _, (decoded_shape,) = decode_proto_op.decode_proto(
        bytes=raw_responses,
        message_type='tensorflow.contrib.rpc.TestCase',
        field_names=['shape'],
        output_types=[dtypes.int32])
    # Each element is expected to come back incremented by one.
    expected = [[i + 1, i + 2, i + 3] for i in range(20)]
    self.assertAllEqual(expected, sess.run(decoded_shape))
def _testRoundtrip(self, in_bufs, message_type, fields):
  """Decodes in_bufs, re-encodes, and checks for an exact round trip."""
  names = [f.name for f in fields]
  types = [f.dtype for f in fields]
  with self.test_session() as sess:
    sizes, field_tensors = decode_proto_op.decode_proto(
        in_bufs,
        message_type=message_type,
        field_names=names,
        output_types=types)
    reencoded = encode_proto_op.encode_proto(
        sizes,
        field_tensors,
        message_type=message_type,
        field_names=names)
    out_bufs, = sess.run([reencoded])

    # Re-encoding must not change the batch shape.
    self.assertEqual(in_bufs.shape, out_bufs.shape)

    # Compare each input buffer with its round-tripped counterpart.
    for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
      original = test_example_pb2.RepeatedPrimitiveValue()
      original.ParseFromString(in_buf)
      roundtripped = test_example_pb2.RepeatedPrimitiveValue()
      roundtripped.ParseFromString(out_buf)
      # The deserialized objects must be identical.
      self.assertEqual(original, roundtripped)
      # The raw bytes must match as well. A mismatch here while the
      # parsed messages agree could be harmless (a change in map
      # ordering?) or bad (e.g. loss of packing in the encoding).
      self.assertEqual(in_buf, out_buf)
def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
                         message_type, message_format, sanitize,
                         force_disordered=False):
  """Run decode tests on a batch of messages.

  Args:
    fields: list of test_example_pb2.FieldSpec (types and expected values)
    case_sizes: expected sizes array
    batch_shape: the shape of the input tensor of serialized messages
    batch: list of serialized messages
    message_type: descriptor name for messages
    message_format: format of messages, 'text' or 'binary'
    sanitize: whether to sanitize binary protobuf inputs
    force_disordered: whether to force fields encoded out of order.
  """
  if force_disordered:
    # Exercise code path that handles out-of-order fields by prepending extra
    # fields with tag numbers higher than any real field. Note that this won't
    # work with sanitization because that forces reserialization using a
    # trusted decoder and encoder.
    assert not sanitize
    extra_fields = test_example_pb2.ExtraFields()
    extra_fields.string_value = 'IGNORE ME'
    extra_fields.bool_value = False
    extra_msg = extra_fields.SerializeToString()
    batch = [extra_msg + msg for msg in batch]

  # Numpy silently truncates the strings if you don't specify dtype=object.
  batch = np.array(batch, dtype=object)
  batch = np.reshape(batch, batch_shape)

  field_names = [f.name for f in fields]
  output_types = [f.dtype for f in fields]

  with self.test_session() as sess:
    sizes, vtensor = decode_proto_op.decode_proto(
        batch,
        message_type=message_type,
        field_names=field_names,
        output_types=output_types,
        message_format=message_format,
        sanitize=sanitize)

    vlist = sess.run([sizes] + vtensor)
    sizes = vlist[0]
    # Values is a list of tensors, one for each field.
    value_tensors = vlist[1:]

    # Check that the repeat sizes are correct: the sizes tensor carries one
    # trailing dimension with an entry per requested field.
    self.assertTrue(
        np.all(np.array(sizes.shape) == batch_shape + [len(field_names)]))

    # Check that the decoded sizes match the expected sizes.
    self.assertEqual(len(sizes.flat), len(case_sizes))
    self.assertTrue(
        np.all(sizes.flat == np.array(
            case_sizes, dtype=np.int32)))

    # Map each field name to its decoded value tensor for the comparison
    # helper below.
    field_dict = dict(zip(field_names, value_tensors))

    self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
                                        field_dict)
def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
                         message_type, message_format, sanitize,
                         force_disordered=False):
  """Run decode tests on a batch of messages.

  Args:
    fields: list of test_example_pb2.FieldSpec (types and expected values)
    case_sizes: expected sizes array
    batch_shape: the shape of the input tensor of serialized messages
    batch: list of serialized messages
    message_type: descriptor name for messages
    message_format: format of messages, 'text' or 'binary'
    sanitize: whether to sanitize binary protobuf inputs
    force_disordered: whether to force fields encoded out of order.
  """
  if force_disordered:
    # Exercise the out-of-order-field code path by prepending a message
    # whose tag numbers are higher than any real field. This cannot be
    # combined with sanitization, which reserializes the input with a
    # trusted decoder and encoder.
    assert not sanitize
    disordered = test_example_pb2.ExtraFields()
    disordered.string_value = 'IGNORE ME'
    disordered.bool_value = False
    prefix = disordered.SerializeToString()
    batch = [prefix + msg for msg in batch]

  # dtype=object prevents numpy from silently truncating the strings.
  batch = np.reshape(np.array(batch, dtype=object), batch_shape)

  field_names = [f.name for f in fields]
  output_types = [f.dtype for f in fields]

  with self.test_session() as sess:
    sizes, vtensor = decode_proto_op.decode_proto(
        batch,
        message_type=message_type,
        field_names=field_names,
        output_types=output_types,
        message_format=message_format,
        sanitize=sanitize)

    outputs = sess.run([sizes] + vtensor)
    # First output is the sizes tensor; the rest are one value tensor per
    # requested field.
    sizes, value_tensors = outputs[0], outputs[1:]

    # The sizes tensor appends one trailing dimension, an entry per field.
    self.assertTrue(
        np.all(
            np.array(sizes.shape) == batch_shape + [len(field_names)]))

    # Decoded repeat counts must match the expected counts exactly.
    self.assertEqual(len(sizes.flat), len(case_sizes))
    self.assertTrue(
        np.all(sizes.flat == np.array(case_sizes, dtype=np.int32)))

    self._compareRepeatedPrimitiveValue(
        batch_shape, sizes, fields, dict(zip(field_names, value_tensors)))
def loop_body(a, unused_b): """Loop body for the tf.while_loop op. Args: a: a constant 0 unused_b: a string placeholder (to satisfy the requirement that a while_loop's condition and body accept the same args as the loop returns). Returns: A TensorFlow subgraph. """ # Request features features. raw_response = tf.contrib.rpc.rpc( address=config.address, method=config.get_features_method, request="", protocol="grpc", fail_fast=True, timeout_in_ms=0, name="get_features") # Decode features from a proto to a flat tensor. _, (batch_id, flat_features) = decode_proto_op.decode_proto( bytes=raw_response, message_type='minigo.GetFeaturesResponse', field_names=['batch_id', 'features'], output_types=[dtypes.int32, dtypes.float32], descriptor_source=config.descriptor_path, name="decode_raw_features") # Reshape flat features. features = tf.reshape( flat_features, [-1, go.N, go.N, features_lib.NEW_FEATURES_PLANES], name="unflatten_features") # Run inference. policy_output, value_output, _ = dual_net.model_inference_fn( features, False) # Flatten model outputs. flat_policy = tf.reshape(policy_output, [-1], name="flatten_policy") flat_value = value_output # value_output is already flat. # Encode outputs from flat tensors to a proto. request_tensors = encode_proto_op.encode_proto( message_type='minigo.PutOutputsRequest', field_names=['batch_id', 'policy', 'value'], sizes=[[1, policy_output_size, value_output_size]], values=[[batch_id], [flat_policy], [flat_value]], descriptor_source=config.descriptor_path, name="encode_outputs") # Send outputs. response = tf.contrib.rpc.rpc( address=config.address, method=config.put_outputs_method, request=request_tensors, protocol="grpc", fail_fast=True, timeout_in_ms=0, name="put_outputs") return a, response[0]