Example #1
0
 def testBuildTensorInfoOp(self):
     """Builds a TensorInfo from a grouped op and checks its fields.

     Two constants named "x" and "y" are grouped under the name "op_z";
     the TensorInfo produced by `utils.build_tensor_info_from_op` is
     expected to carry that name, a DT_INVALID dtype and no shape dims.
     """
     first = constant_op.constant(1, name="x")
     second = constant_op.constant(2, name="y")
     grouped_op = control_flow_ops.group([first, second], name="op_z")
     info = utils.build_tensor_info_from_op(grouped_op)
     self.assertEqual("op_z", info.name)
     self.assertEqual(types_pb2.DT_INVALID, info.dtype)
     shape_dims = info.tensor_shape.dim
     self.assertEqual(0, len(shape_dims))
Example #2
0
 def testBuildTensorInfoOp(self):
   """Builds a TensorInfo from a grouped op and checks its fields.

   The op named "op_z" groups constants "x" and "y"; its TensorInfo
   should report that name, a DT_INVALID dtype and an empty shape.
   """
   const_x = constant_op.constant(1, name="x")
   const_y = constant_op.constant(2, name="y")
   group_op = control_flow_ops.group([const_x, const_y], name="op_z")
   op_info = utils.build_tensor_info_from_op(group_op)
   self.assertEqual("op_z", op_info.name)
   self.assertEqual(types_pb2.DT_INVALID, op_info.dtype)
   dims = op_info.tensor_shape.dim
   self.assertEqual(0, len(dims))
Example #3
0
    def testBuildTensorInfoDefunOp(self):
        """Builds a TensorInfo from the result of calling a defun.

        `my_init_fn` only stores its arguments on the test instance, so the
        call is passed to `utils.build_tensor_info_from_op`. The resulting
        TensorInfo is checked for its op name, a DT_INVALID dtype and an
        empty shape.
        """
        @function.defun
        def my_init_fn(x, y):
            self.x_var = x
            self.y_var = y

        lhs = constant_op.constant(1, name="x")
        rhs = constant_op.constant(2, name="y")
        call_result = my_init_fn(lhs, rhs)
        info = utils.build_tensor_info_from_op(call_result)
        # NOTE(review): the expected op name for a defun call appears to be
        # TensorFlow-version dependent ("PartitionedCall" vs
        # "PartitionedFunctionCall") — confirm against the version under test.
        self.assertEqual("PartitionedCall", info.name)
        self.assertEqual(types_pb2.DT_INVALID, info.dtype)
        shape_dims = info.tensor_shape.dim
        self.assertEqual(0, len(shape_dims))
Example #4
0
  def testBuildTensorInfoDefunOp(self):
    """Builds a TensorInfo from the result of calling a defun.

    The defun only records its two arguments on the test instance; the
    value returned by calling it is handed to
    `utils.build_tensor_info_from_op`, and the TensorInfo's name, dtype and
    shape are then checked.
    """
    @function.defun
    def my_init_fn(x, y):
      self.x_var = x
      self.y_var = y

    arg_x = constant_op.constant(1, name="x")
    arg_y = constant_op.constant(2, name="y")
    defun_call = my_init_fn(arg_x, arg_y)
    call_info = utils.build_tensor_info_from_op(defun_call)
    # NOTE(review): the expected op name for a defun call appears to be
    # TensorFlow-version dependent ("PartitionedFunctionCall" vs
    # "PartitionedCall") — confirm against the version under test.
    self.assertEqual("PartitionedFunctionCall", call_info.name)
    self.assertEqual(types_pb2.DT_INVALID, call_info.dtype)
    dims = call_info.tensor_shape.dim
    self.assertEqual(0, len(dims))