def testElementShape(self):
  """Checks the reported element shape for both int32 and int64 shape types.

  The reserved list has a partially-known element shape whose leading
  dimension comes from a placeholder fed at run time; the trailing
  dimension is the constant 15.
  """
  with self.cached_session() as sess, self.test_scope():
    leading_dim = array_ops.placeholder(dtypes.int32)
    reserved = list_ops.tensor_list_reserve(
        element_shape=(leading_dim, 15),
        num_elements=20,
        element_dtype=dtypes.float32)
    shape_i32 = list_ops.tensor_list_element_shape(
        reserved, shape_type=dtypes.int32)
    shape_i64 = list_ops.tensor_list_element_shape(
        reserved, shape_type=dtypes.int64)
    # Each fetch feeds a different leading dimension, which should be echoed
    # back alongside the static trailing dimension.
    for fetched, fed_dim in ((shape_i32, 10), (shape_i64, 7)):
      self.assertAllEqual(
          sess.run(fetched, {leading_dim: fed_dim}), (fed_dim, 15))
def testSerializeListWithUnknownRank(self):
  """Moves a list across devices and reads its element shape remotely.

  The list is built on the worker with element_shape=None, then copied to the
  parameter server, where its element shape is queried. That shape tensor is
  copied back to the worker. The test expects the fetched shape to be -1.
  """
  workers = test_util.create_local_cluster(num_workers=1, num_ps=1)[0]
  worker = workers[0]
  graph = ops.Graph()
  with graph.as_default(), session.Session(target=worker.target):
    with ops.device("/job:worker"):
      source = constant_op.constant([[1.0], [2.0]])
      worker_list = list_ops.tensor_list_from_tensor(
          source, element_shape=None)
    with ops.device("/job:ps"):
      # Copying the list variant to another job forces it to be serialized.
      ps_list = array_ops.identity(worker_list)
      ps_shape = list_ops.tensor_list_element_shape(
          ps_list, shape_type=dtypes.int32)
    with ops.device("/job:worker"):
      shape_on_worker = array_ops.identity(ps_shape)
    self.assertEqual(self.evaluate(shape_on_worker), -1)
def testSerializeListWithUnknownRank(self):
  """Moves a list across devices and reads its element shape remotely.

  The list is built on the worker with element_shape=-1, then copied to the
  parameter server, where its element shape is queried. That shape tensor is
  copied back to the worker. The test expects the fetched shape to be -1.
  """
  workers = test_util.create_local_cluster(num_workers=1, num_ps=1)[0]
  worker = workers[0]
  graph = ops.Graph()
  with graph.as_default(), session.Session(target=worker.target):
    with ops.device("/job:worker"):
      source = constant_op.constant([[1.0], [2.0]])
      worker_list = list_ops.tensor_list_from_tensor(
          source, element_shape=-1)
    with ops.device("/job:ps"):
      # Copying the list variant to another job forces it to be serialized.
      ps_list = array_ops.identity(worker_list)
      ps_shape = list_ops.tensor_list_element_shape(
          ps_list, shape_type=dtypes.int32)
    with ops.device("/job:worker"):
      shape_on_worker = array_ops.identity(ps_shape)
    self.assertEqual(self.evaluate(shape_on_worker), -1)
def testElementShape(self):
  """An empty float32 list built with element_shape=None reports shape -1."""
  empty = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=None)
  reported = self.evaluate(
      list_ops.tensor_list_element_shape(empty, shape_type=dtypes.int32))
  self.assertEqual(reported, -1)
def testElementShape(self):
  """An empty float32 list built with element_shape=-1 reports shape -1."""
  empty = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=-1)
  reported = self.evaluate(
      list_ops.tensor_list_element_shape(empty, shape_type=dtypes.int32))
  self.assertEqual(reported, -1)