def test_tensor_list_empty_list(self):
    """Empty list/tuple initializers with explicit dtype and shape stack to []."""
    # The list and the tuple form are exercised in the same order as before.
    for empty_initializer in ([], ()):
        tensor_list = special_functions.tensor_list(
            empty_initializer, element_dtype=dtypes.int32, element_shape=())
        stacked = list_ops.tensor_list_stack(
            tensor_list, element_dtype=dtypes.int32)
        with self.cached_session():
            self.assertAllEqual(self.evaluate(stacked), [])
def test_tensor_list_empty_list(self):
    """Check that empty initializers yield a tensor list that stacks to [].

    Both an empty list and an empty tuple are passed to
    special_functions.tensor_list, each with element_dtype=int32 and
    element_shape=().
    """
    from_list = special_functions.tensor_list(
        [], element_dtype=dtypes.int32, element_shape=())
    stacked_from_list = list_ops.tensor_list_stack(
        from_list, element_dtype=dtypes.int32)
    with self.cached_session():
        # self.evaluate is used instead of sess.run so the check does not
        # depend on having a graph-mode session object.
        self.assertAllEqual(self.evaluate(stacked_from_list), [])

    from_tuple = special_functions.tensor_list(
        (), element_dtype=dtypes.int32, element_shape=())
    stacked_from_tuple = list_ops.tensor_list_stack(
        from_tuple, element_dtype=dtypes.int32)
    with self.cached_session():
        self.assertAllEqual(self.evaluate(stacked_from_tuple), [])
def test_tensor_list_array_from_elements(self):
    """With use_tensor_array=True, stack() returns the initial elements."""
    first_row = constant_op.constant([1, 2])
    second_row = constant_op.constant([3, 4])
    tensor_array = special_functions.tensor_list(
        [first_row, second_row], use_tensor_array=True)
    stacked = tensor_array.stack()
    expected = [[1, 2], [3, 4]]
    with self.cached_session():
        self.assertAllEqual(self.evaluate(stacked), expected)
def test_tensor_list_array_from_elements(self):
    """With use_tensor_array=True, stack() returns the initial elements."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
    l = special_functions.tensor_list(elements, use_tensor_array=True)
    sl = l.stack()
    # cached_session replaces the deprecated test_session; self.evaluate
    # replaces sess.run so no explicit session handle is needed.
    with self.cached_session():
        self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
def test_tensor_list_from_elements(self):
    """A tensor list built from two int tensors stacks back to those rows."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    # cached_session replaces the deprecated test_session; self.evaluate
    # replaces sess.run so no explicit session handle is needed.
    with self.cached_session():
        self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
def test_tensor_list_from_elements(self):
    """Stacking a tensor list made from two constants gives both rows back."""
    row_values = ([1, 2], [3, 4])
    # Constants are created in the same order as the original list literal.
    initial_elements = [constant_op.constant(row) for row in row_values]
    tensor_list = special_functions.tensor_list(initial_elements)
    stacked = list_ops.tensor_list_stack(
        tensor_list, element_dtype=dtypes.int32)
    with self.cached_session():
        self.assertAllEqual(self.evaluate(stacked), [[1, 2], [3, 4]])
def test_tensor_list_empty_list_no_type(self):
    """An empty initializer without dtype/shape arguments raises ValueError."""
    # assertRaisesRegexp is a deprecated unittest alias (removed in
    # Python 3.12); assertRaisesRegex is the supported name.
    with self.assertRaisesRegex(
            ValueError, 'element_dtype and element_shape are required'):
        special_functions.tensor_list([])
def test_tensor_list_unsupported_initializer(self):
    """Passing a NumPy array as the initializer raises ValueError."""
    # assertRaisesRegexp is a deprecated unittest alias (removed in
    # Python 3.12); assertRaisesRegex is the supported name.
    with self.assertRaisesRegex(ValueError, 'unknown type'):
        special_functions.tensor_list(np.array([1, 2, 3]))
def test_tensor_list_tensor(self):
    """A tensor list built from an empty int32 constant stacks to []."""
    empty_tensor = constant_op.constant([], dtype=dtypes.int32)
    tensor_list = special_functions.tensor_list(empty_tensor)
    stacked = list_ops.tensor_list_stack(
        tensor_list, element_dtype=dtypes.int32)
    with self.cached_session():
        self.assertAllEqual(self.evaluate(stacked), [])
# Creates a list from [1, 2, 3] with special_functions.tensor_list, calls
# pop() on it, and returns the pop() result together with the list object.
# NOTE(review): presumably a fixture handed to a code-conversion test, so the
# statement layout is left exactly as written -- confirm against the caller.
def test_fn():
    l = special_functions.tensor_list([1, 2, 3])
    s = l.pop()
    return s, l
# Creates a list from [1] with special_functions.tensor_list, calls append()
# twice (with 2, then 3), and returns the list object.
# NOTE(review): presumably a fixture handed to a code-conversion test, so the
# statement layout is left exactly as written -- confirm against the caller.
def test_fn():
    l = special_functions.tensor_list([1])
    l.append(2)
    l.append(3)
    return l
def test_tensor_list_tensor(self):
    """A tensor list built from an empty int32 constant stacks to []."""
    l = special_functions.tensor_list(
        constant_op.constant([], dtype=dtypes.int32))
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.cached_session():
        # self.evaluate is used instead of sess.run so the check does not
        # depend on having a graph-mode session object.
        self.assertAllEqual(self.evaluate(sl), [])
# Creates a list from [1, 2, 3], passes it to directives.set_element_type with
# dtype=int32 and shape=(), then calls pop() and returns the pop() result
# together with the list object.
# NOTE(review): presumably a fixture handed to a code-conversion test, so the
# statement layout is left exactly as written -- confirm against the caller.
def f():
    l = special_functions.tensor_list([1, 2, 3])
    directives.set_element_type(l, dtype=dtypes.int32, shape=())
    s = l.pop()
    return s, l