Exemplo n.º 1
0
 def testElementShape(self):
   """Checks the element shape reported for a reserved tensor list.

   Reserves a float32 list of 20 elements whose element shape is
   (placeholder, 15), then fetches that shape as int32 and as int64 while
   feeding a different value for the placeholder on each run.
   """
   with self.cached_session() as sess, self.test_scope():
     leading_dim = array_ops.placeholder(dtypes.int32)
     reserved = list_ops.tensor_list_reserve(
         element_shape=(leading_dim, 15),
         num_elements=20,
         element_dtype=dtypes.float32)
     shape_i32 = list_ops.tensor_list_element_shape(
         reserved, shape_type=dtypes.int32)
     shape_i64 = list_ops.tensor_list_element_shape(
         reserved, shape_type=dtypes.int64)
     # Each shape op is run once with its own fed value for the placeholder.
     for shape_op, fed_value in ((shape_i32, 10), (shape_i64, 7)):
       actual = sess.run(shape_op, {leading_dim: fed_value})
       self.assertAllEqual(actual, (fed_value, 15))
Exemplo n.º 2
0
 def testElementShape(self):
   """Verifies tensor_list_element_shape for int32 and int64 shape types.

   The list is reserved with element shape (placeholder, 15); the fetched
   shape is expected to contain whatever value was fed for the placeholder.
   """
   with self.cached_session() as sess, self.test_scope():
     dim_ph = array_ops.placeholder(dtypes.int32)
     tensor_list = list_ops.tensor_list_reserve(
         element_shape=(dim_ph, 15),
         num_elements=20,
         element_dtype=dtypes.float32)
     int32_shape = list_ops.tensor_list_element_shape(
         tensor_list, shape_type=dtypes.int32)
     int64_shape = list_ops.tensor_list_element_shape(
         tensor_list, shape_type=dtypes.int64)

     # int32 shape, placeholder fed with 10.
     got_int32 = sess.run(int32_shape, {dim_ph: 10})
     self.assertAllEqual(got_int32, (10, 15))

     # int64 shape, placeholder fed with 7.
     got_int64 = sess.run(int64_shape, {dim_ph: 7})
     self.assertAllEqual(got_int64, (7, 15))
Exemplo n.º 3
0
 def testSerializeListWithUnknownRank(self):
   """Element shape of a list built with element_shape=None evaluates to -1.

   The list is created on the worker job, passed through an identity on the
   ps job where its element shape is queried, and the shape is passed through
   another identity on the worker job before being evaluated.
   """
   cluster = test_util.create_local_cluster(num_workers=1, num_ps=1)
   worker_server = cluster[0][0]
   with ops.Graph().as_default():
     with session.Session(target=worker_server.target):
       with ops.device("/job:worker"):
         values = constant_op.constant([[1.0], [2.0]])
         worker_list = list_ops.tensor_list_from_tensor(
             values, element_shape=None)
       with ops.device("/job:ps"):
         ps_list = array_ops.identity(worker_list)
         ps_shape = list_ops.tensor_list_element_shape(
             ps_list, shape_type=dtypes.int32)
       with ops.device("/job:worker"):
         returned_shape = array_ops.identity(ps_shape)
       self.assertEqual(self.evaluate(returned_shape), -1)
Exemplo n.º 4
0
 def testSerializeListWithUnknownRank(self):
     """Element shape of a list built with element_shape=-1 evaluates to -1.

     The list is created on the worker job, copied via identity onto the ps
     job where its element shape is read, and that shape is routed back
     through an identity on the worker job before the assertion.
     """
     cluster = test_util.create_local_cluster(num_workers=1, num_ps=1)
     worker_server = cluster[0][0]
     with ops.Graph().as_default(), session.Session(
             target=worker_server.target):
         with ops.device("/job:worker"):
             source = constant_op.constant([[1.0], [2.0]])
             list_on_worker = list_ops.tensor_list_from_tensor(
                 source, element_shape=-1)
         with ops.device("/job:ps"):
             list_on_ps = array_ops.identity(list_on_worker)
             shape_on_ps = list_ops.tensor_list_element_shape(
                 list_on_ps, shape_type=dtypes.int32)
         with ops.device("/job:worker"):
             shape_on_worker = array_ops.identity(shape_on_ps)
         evaluated = self.evaluate(shape_on_worker)
         self.assertEqual(evaluated, -1)
Exemplo n.º 5
0
 def testElementShape(self):
   """An empty float32 list created with element_shape=None reports -1."""
   empty_list = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=None)
   shape_op = list_ops.tensor_list_element_shape(
       empty_list, shape_type=dtypes.int32)
   evaluated = self.evaluate(shape_op)
   self.assertEqual(evaluated, -1)
Exemplo n.º 6
0
 def testElementShape(self):
     """An empty float32 list created with element_shape=-1 reports -1."""
     empty_list = list_ops.empty_tensor_list(
         element_dtype=dtypes.float32,
         element_shape=-1,
     )
     reported_shape = list_ops.tensor_list_element_shape(
         empty_list, shape_type=dtypes.int32)
     self.assertEqual(self.evaluate(reported_shape), -1)