示例#1
0
  def testGlobalDispatcher(self):
    """Checks that a globally registered op dispatcher intercepts TF ops.

    Registers a `TensorTracerOpDispatcher`, then calls `math_ops` functions
    and `decode_proto` on `TensorTracer` inputs and asserts on the string form
    of the returned traces. The global dispatcher registry is restored in
    `finally` so the registration does not leak into other tests.
    """
    # Snapshot a *copy* of the registry rather than a bare reference. If
    # `register()` appends to the existing list in place, a bare reference
    # would already contain the new dispatcher and the restore below would be
    # a no-op. With a copy, the restore is correct whether `register()`
    # mutates or rebinds the registry.
    # NOTE(review): assumes the registry is a list-like iterable; confirm in
    # `dispatch`.
    original_global_dispatchers = list(dispatch._GLOBAL_DISPATCHERS)
    try:
      TensorTracerOpDispatcher().register()

      x = TensorTracer("x")
      y = TensorTracer("y")
      trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3)
      self.assertEqual(
          str(trace), "math.reduce_sum(math.add(math.abs(x), y), axis=3)")

      proto_val = TensorTracer("proto")
      trace = decode_proto(proto_val, "message_type", ["field"], ["float32"])
      # Only the op name and the first argument are pinned; the remainder of
      # the trace string is not checked.
      self.assertIn("io.decode_proto(bytes=proto,", str(trace))

    finally:
      # Clean up: drop any dispatchers registered by this test.
      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
 def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
     """Round-trips a batch of 20 `TestCase` protos through the `Increment` RPC.

     Example `k` is encoded with a 3-element `values` field `[k, k + 1, k + 2]`.
     After the RPC, the decoded `values` for example `k` must equal
     `[k + 1, k + 2, k + 3]`.
     """
     batch = 20
     proto_type = 'tensorflow.contrib.rpc.TestCase'
     with self.cached_session() as sess:
         payload = [[k, k + 1, k + 2] for k in range(batch)]
         # A single field ('values') per example, each holding three entries.
         encoded = proto_ops.encode_proto(
             message_type=proto_type,
             field_names=['values'],
             sizes=[[3]] * batch,
             values=[payload])
         replies = self.rpc(
             method=self.get_method_name('Increment'),
             address=self._address,
             request=encoded)
         # The first output of decode_proto is discarded; the second holds one
         # tensor per requested field, and only 'values' was requested.
         _, (value_tensor,) = proto_ops.decode_proto(
             bytes=replies,
             message_type=proto_type,
             field_names=['values'],
             output_types=[dtypes.int32])
         incremented = sess.run(value_tensor)
     expected = [[k + 1, k + 2, k + 3] for k in range(batch)]
     self.assertAllEqual(expected, incremented)
 def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
   """Sends 20 `TestCase` protos to `Increment` and checks each value is +1.

   Request `n` carries `values == [n, n + 1, n + 2]`; the decoded response for
   request `n` is expected to be `[n + 1, n + 2, n + 3]`.
   """
   message_type = 'tensorflow.contrib.rpc.TestCase'
   field = 'values'
   count = 20
   with self.cached_session() as sess:
     request_batch = proto_ops.encode_proto(
         message_type=message_type,
         field_names=[field],
         sizes=[[3]] * count,
         values=[[[n, n + 1, n + 2] for n in range(count)]])
     raw_responses = self.rpc(
         method=self.get_method_name('Increment'),
         address=self._address,
         request=request_batch)
     # Ignore the first decode output and unpack the single decoded field.
     _, (values_tensor,) = proto_ops.decode_proto(
         bytes=raw_responses,
         message_type=message_type,
         field_names=[field],
         output_types=[dtypes.int32])
     actual = sess.run(values_tensor)
   self.assertAllEqual([[n + 1, n + 2, n + 3] for n in range(count)], actual)