def test_global_annotation(self):
        """Checks that a schema-level annotation ends up in the inferred schema.

        Encodes a `BucketBoundaries` proto, attaches it as a global annotation
        via `schema_inference.annotate`, infers the schema for two int64
        outputs, and verifies that exactly one `extra_metadata` entry exists and
        that it unpacks back to the original boundaries.
        """
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            # Report the test as skipped instead of letting it silently pass.
            self.skipTest('annotations_pb2 is not available in this build.')
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level. The contents are
            # not meaningful as global schema metadata; the test only checks
            # that the annotation is carried through to the schema.
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            # encode_proto takes sizes shaped batch_shape + [len(field_names)];
            # here that is [1, 1], holding the number of boundary values.
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = encode_proto_op.encode_proto(
                sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
                message_type)[0]
            # Type URLs always use '/' as the separator; os.path.join would use
            # the platform separator (e.g. a backslash on Windows).
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                extra_metadata = schema._schema_proto.annotation.extra_metadata
                self.assertLen(extra_metadata, 1)
                # Extract the annotated message and validate its contents.
                message = annotations_pb2.BucketBoundaries()
                extra_metadata[0].Unpack(message)
                self.assertAllClose(list(message.boundaries), [1])
    def testBadInputs(self):
        """Checks that encode_proto reports malformed inputs as op errors.

        Covers an unknown field name, a value dtype that does not match the
        field type, a `sizes` tensor of the wrong shape, and value tensors
        whose leading (batch) dimensions disagree.
        """
        # Invalid field name
        with self.test_session():
            with self.assertRaisesOpError('Unknown field: non_existent_field'):
                encode_proto_op.encode_proto(
                    sizes=[[1]],
                    values=[np.array([[0.0]], dtype=np.int32)],
                    message_type=
                    'tensorflow.contrib.proto.RepeatedPrimitiveValue',
                    field_names=['non_existent_field']).eval()

        # Incorrect types: int32 values supplied for the double_value field.
        with self.test_session():
            with self.assertRaisesOpError(
                    'Incompatible type for field double_value.'):
                encode_proto_op.encode_proto(
                    sizes=[[1]],
                    values=[np.array([[0.0]], dtype=np.int32)],
                    message_type=
                    'tensorflow.contrib.proto.RepeatedPrimitiveValue',
                    field_names=['double_value']).eval()

        # Incorrect shapes of sizes: a [1, 1, 2] sizes tensor is fed, which
        # does not match batch_shape + [len(field_names)].
        with self.test_session():
            with self.assertRaisesOpError(
                    r'sizes should be batch_size \+ \[len\(field_names\)\]'):
                sizes = array_ops.placeholder(dtypes.int32)
                values = array_ops.placeholder(dtypes.float64)
                encode_proto_op.encode_proto(
                    sizes=sizes,
                    values=[values],
                    message_type=
                    'tensorflow.contrib.proto.RepeatedPrimitiveValue',
                    field_names=['double_value']).eval(feed_dict={
                        sizes: [[[0, 0]]],
                        values: [[0.0]]
                    })

        # Inconsistent shapes of values: values1 is fed with one row and
        # values2 with two, so their leading dimensions disagree.
        with self.test_session():
            with self.assertRaisesOpError(
                    'Values must match up to the last dimension'):
                values1 = array_ops.placeholder(dtypes.float64)
                values2 = array_ops.placeholder(dtypes.int32)
                (encode_proto_op.encode_proto(
                    sizes=[[1, 1]],
                    values=[values1, values2],
                    message_type=
                    'tensorflow.contrib.proto.RepeatedPrimitiveValue',
                    field_names=['double_value',
                                 'int32_value']).eval(feed_dict={
                                     values1: [[0.0]],
                                     values2: [[0], [0]]
                                 }))
  def testBadInputs(self):
    """Checks that encode_proto reports malformed inputs as op errors.

    Covers an unknown field name, a value dtype that does not match the field
    type, a `sizes` tensor of the wrong shape, and value tensors whose leading
    (batch) dimensions disagree.
    """
    # Invalid field name
    with self.test_session():
      with self.assertRaisesOpError('Unknown field: non_existent_field'):
        encode_proto_op.encode_proto(
            sizes=[[1]],
            values=[np.array([[0.0]], dtype=np.int32)],
            message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
            field_names=['non_existent_field']).eval()

    # Incorrect types: int32 values supplied for the double_value field.
    with self.test_session():
      with self.assertRaisesOpError(
          'Incompatible type for field double_value.'):
        encode_proto_op.encode_proto(
            sizes=[[1]],
            values=[np.array([[0.0]], dtype=np.int32)],
            message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
            field_names=['double_value']).eval()

    # Incorrect shapes of sizes: a [1, 1, 2] sizes tensor is fed, which does
    # not match batch_shape + [len(field_names)].
    with self.test_session():
      with self.assertRaisesOpError(
          r'sizes should be batch_size \+ \[len\(field_names\)\]'):
        sizes = array_ops.placeholder(dtypes.int32)
        values = array_ops.placeholder(dtypes.float64)
        encode_proto_op.encode_proto(
            sizes=sizes,
            values=[values],
            message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
            field_names=['double_value']).eval(feed_dict={
                sizes: [[[0, 0]]],
                values: [[0.0]]
            })

    # Inconsistent shapes of values: values1 is fed with one row and values2
    # with two, so their leading dimensions disagree.
    with self.test_session():
      with self.assertRaisesOpError(
          'Values must match up to the last dimension'):
        values1 = array_ops.placeholder(dtypes.float64)
        values2 = array_ops.placeholder(dtypes.int32)
        (encode_proto_op.encode_proto(
            sizes=[[1, 1]],
            values=[values1, values2],
            message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
            field_names=['double_value', 'int32_value']).eval(feed_dict={
                values1: [[0.0]],
                values2: [[0], [0]]
            }))
Example #4
0
 def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
   """Sends a batch of 20 TestCase protos through the RPC op and checks replies.

   Request i carries shape [i, i + 1, i + 2]; the IncrementTestShapes method
   is expected to answer with every entry incremented by one.
   """
   num_requests = 20
   proto_name = 'tensorflow.contrib.rpc.TestCase'
   request_shapes = [list(range(i, i + 3)) for i in range(num_requests)]
   expected_shapes = [[dim + 1 for dim in shape] for shape in request_shapes]

   with self.test_session() as sess:
     # One 'shape' field of length 3 per request.
     encoded_requests = encode_proto_op.encode_proto(
         message_type=proto_name,
         field_names=['shape'],
         sizes=[[3]] * num_requests,
         values=[request_shapes])
     raw_replies = self.rpc(
         method=self.get_method_name('IncrementTestShapes'),
         address=self._address,
         request=encoded_requests)
     _, decoded_fields = decode_proto_op.decode_proto(
         bytes=raw_replies,
         message_type=proto_name,
         field_names=['shape'],
         output_types=[dtypes.int32])
     (reply_shape,) = decoded_fields
     actual_shapes = sess.run(reply_shape)

   self.assertAllEqual(expected_shapes, actual_shapes)
    def _testRoundtrip(self, in_bufs, message_type, fields):
        """Decodes `in_bufs`, re-encodes the result, and checks nothing changed.

        Args:
            in_bufs: batch of serialized protos; its `.shape` and `.flat` are
                used, so presumably a numpy array of bytes (TODO confirm).
            message_type: fully qualified proto message name, passed to both
                the decode and the encode op.
            fields: objects exposing `.name` and `.dtype`, one per field to
                round-trip.

        Note: the comparison always parses buffers as
        `test_example_pb2.RepeatedPrimitiveValue`, whatever `message_type` is.
        """
        requested_names = [field.name for field in fields]
        requested_types = [field.dtype for field in fields]

        with self.test_session() as sess:
            decoded_sizes, decoded_values = decode_proto_op.decode_proto(
                in_bufs,
                message_type=message_type,
                field_names=requested_names,
                output_types=requested_types)
            reencoded = encode_proto_op.encode_proto(
                decoded_sizes,
                decoded_values,
                message_type=message_type,
                field_names=requested_names)
            [out_bufs] = sess.run([reencoded])

            # The re-encoded batch must keep the input's shape.
            self.assertEqual(in_bufs.shape, out_bufs.shape)

            def parse(serialized):
                proto = test_example_pb2.RepeatedPrimitiveValue()
                proto.ParseFromString(serialized)
                return proto

            for original_buf, roundtrip_buf in zip(in_bufs.flat, out_bufs.flat):
                # Semantic check: both buffers parse to equal messages.
                self.assertEqual(parse(original_buf), parse(roundtrip_buf))
                # Byte-level check. A failure here alone means the new bytes
                # still parse but are serialized differently; that may be
                # harmless (map ordering?) or a real regression (e.g. lost
                # packed encoding).
                self.assertEqual(original_buf, roundtrip_buf)
  def _testRoundtrip(self, in_bufs, message_type, fields):
    """Decodes `in_bufs`, re-encodes the result, and checks nothing changed.

    Args:
      in_bufs: batch of serialized protos; its `.shape` and `.flat` are used,
        so presumably a numpy array of bytes (TODO confirm).
      message_type: fully qualified proto message name, passed to both the
        decode and the encode op.
      fields: objects exposing `.name` and `.dtype`, one per field to
        round-trip.

    Note: the comparison always parses buffers as
    `test_example_pb2.RepeatedPrimitiveValue`, whatever `message_type` is.
    """
    requested_names = [field.name for field in fields]
    requested_types = [field.dtype for field in fields]

    with self.test_session() as sess:
      decoded_sizes, decoded_values = decode_proto_op.decode_proto(
          in_bufs,
          message_type=message_type,
          field_names=requested_names,
          output_types=requested_types)
      reencoded = encode_proto_op.encode_proto(
          decoded_sizes,
          decoded_values,
          message_type=message_type,
          field_names=requested_names)
      [out_bufs] = sess.run([reencoded])

      # The re-encoded batch must keep the input's shape.
      self.assertEqual(in_bufs.shape, out_bufs.shape)

      def parse(serialized):
        proto = test_example_pb2.RepeatedPrimitiveValue()
        proto.ParseFromString(serialized)
        return proto

      for original_buf, roundtrip_buf in zip(in_bufs.flat, out_bufs.flat):
        # Semantic check: both buffers parse to equal messages.
        self.assertEqual(parse(original_buf), parse(roundtrip_buf))
        # Byte-level check. A failure here alone means the new bytes still
        # parse but are serialized differently; that may be harmless (map
        # ordering?) or a real regression (e.g. lost packed encoding).
        self.assertEqual(original_buf, roundtrip_buf)
Example #7
0
def _encode_proto(values_dict, message_type):
    """A wrapper around encode_proto_op.encode_proto.

    Fields are handled in sorted key order. For each field, a value and a
    per-row size are built as follows:

    * `tf.SparseTensor` values are densified, padded with
      `_DEFAULT_VALUE_BY_DTYPE[value.dtype]`. The size is the sum, along
      axis 1, of a ones-valued copy of the sparse tensor.
    * Any other value is reshaped to `[tf.shape(value)[0], -1]`. Every row
      reports the resulting column count as its size.

    Args:
        values_dict: mapping from proto field name to a tensor or
            `tf.SparseTensor`. Dimension 0 is presumably the batch dimension
            (TODO confirm).
        message_type: fully qualified name of the proto message to encode.

    Returns:
        The result of `encode_proto_op.encode_proto` for the stacked sizes and
        the per-field values.
    """
    field_names, values, sizes = [], [], []
    for field_name in sorted(values_dict):
        value = values_dict[field_name]
        if isinstance(value, tf.SparseTensor):
            # Replace every stored value with 1 so that the row sum counts the
            # entries present.
            entry_indicator = tf.SparseTensor(
                value.indices,
                tf.ones_like(value.values, dtype=tf.int32),
                value.dense_shape)
            size = tf.sparse_reduce_sum(entry_indicator, axis=1)
            value = tf.sparse_tensor_to_dense(
                value, _DEFAULT_VALUE_BY_DTYPE[value.dtype])
        else:
            value = tf.reshape(value, [tf.shape(value)[0], -1])
            # Dense rows are fully populated: size equals the column count.
            size = tf.fill((tf.shape(value)[0],), tf.shape(value)[1])
        field_names.append(field_name)
        values.append(value)
        sizes.append(size)

    stacked_sizes = tf.stack(sizes, axis=1)
    return encode_proto_op.encode_proto(stacked_sizes, values, field_names,
                                        message_type)
Example #8
0
    def loop_body(a, unused_b):
        """Runs one fetch-features / inference / put-outputs round trip.

        Used as the body of a tf.while_loop, so it accepts and returns the same
        loop variables.

        Args:
            a: a constant 0, returned unchanged.
            unused_b: a string placeholder, ignored. It exists only because a
                while_loop's condition and body must accept the same args as
                the loop returns.

        Returns:
            A tuple `(a, put_response[0])`, where `put_response` is the output
            of the "put_outputs" RPC.
        """
        # Ask the server for the next batch of input features.
        features_reply = tf.contrib.rpc.rpc(
            address=config.address,
            method=config.get_features_method,
            request="",
            protocol="grpc",
            fail_fast=True,
            timeout_in_ms=0,
            name="get_features")

        # Pull the batch id and the flat feature tensor out of the reply proto.
        _, decoded_fields = decode_proto_op.decode_proto(
            bytes=features_reply,
            message_type="minigo.GetFeaturesResponse",
            field_names=["batch_id", "features"],
            output_types=[dtypes.int32, dtypes.float32],
            descriptor_source=config.descriptor_path,
            name="decode_raw_features")
        batch_id, flat_features = decoded_fields

        # Restore the board layout expected by the model.
        board_features = tf.reshape(
            flat_features, [-1, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
            name="unflatten_features")

        policy_output, value_output, _ = dual_net.model_inference_fn(
            board_features, False)

        # Flatten the policy. value_output is passed through as-is because the
        # original code treats it as already flat.
        flat_policy = tf.reshape(policy_output, [-1], name="flatten_policy")

        # Pack the model outputs into a PutOutputsRequest proto.
        outputs_request = encode_proto_op.encode_proto(
            message_type="minigo.PutOutputsRequest",
            field_names=["batch_id", "policy", "value"],
            sizes=[[1, policy_output_size, value_output_size]],
            values=[[batch_id], [flat_policy], [value_output]],
            descriptor_source=config.descriptor_path,
            name="encode_outputs")

        # Hand the outputs back to the server.
        put_response = tf.contrib.rpc.rpc(
            address=config.address,
            method=config.put_outputs_method,
            request=outputs_request,
            protocol="grpc",
            fail_fast=True,
            timeout_in_ms=0,
            name="put_outputs")

        return a, put_response[0]