def testStackEmptyList(self, max_num_elements):
  """Checks which empty TensorLists can be stacked.

  An empty list whose element_shape is fully defined ([1, 2]) stacks to a
  tensor of shape (0, 1, 2). Partially defined ([None, 2]) or unknown (None)
  element shapes must raise InvalidArgumentError matching
  "non-fully-defined".

  Args:
    max_num_elements: passed through as `max_num_elements` to every
      `list_ops.empty_tensor_list` call in this test.
  """
  # Should be able to stack empty lists with fully defined element_shape.
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[1, 2],
      max_num_elements=max_num_elements)
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

  # Should not be able to stack empty lists with partially defined
  # element_shape. assertRaisesRegex replaces the deprecated
  # assertRaisesRegexp alias (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[None, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)

  # Should not be able to stack empty lists with undefined element_shape.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def testStackEmptyList(self, max_num_elements):
  """Checks which empty TensorLists can be stacked (-1 marks unknown dims).

  An empty list with element_shape [1, 2] stacks to shape (0, 1, 2).
  element_shape [-1, 2] (partially defined) and -1 (unknown rank) must raise
  InvalidArgumentError matching "non-fully-defined".

  Args:
    max_num_elements: passed through as `max_num_elements` to every
      `list_ops.empty_tensor_list` call in this test.
  """
  # Should be able to stack empty lists with fully defined element_shape.
  l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                 element_shape=[1, 2],
                                 max_num_elements=max_num_elements)
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

  # Should not be able to stack empty lists with partially defined
  # element_shape. assertRaisesRegex replaces the deprecated
  # assertRaisesRegexp alias (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                   element_shape=[-1, 2],
                                   max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)

  # Should not be able to stack empty lists with undefined element_shape.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                   element_shape=-1,
                                   max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def testConcat(self):
  """Checks tensor_list_concat_lists on batches of two TensorLists.

  All four pairings of two batches are concatenated elementwise, and each
  result is compared against the expected values. Concatenating with a
  batch of mismatched element shape or dtype must raise.
  """
  c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
  l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
  l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
  l_batch_0 = array_ops.stack([l0, l1])
  l_batch_1 = array_ops.stack([l1, l0])

  l_concat_01 = list_ops.tensor_list_concat_lists(
      l_batch_0, l_batch_1, element_dtype=dtypes.float32)
  l_concat_10 = list_ops.tensor_list_concat_lists(
      l_batch_1, l_batch_0, element_dtype=dtypes.float32)
  l_concat_00 = list_ops.tensor_list_concat_lists(
      l_batch_0, l_batch_0, element_dtype=dtypes.float32)
  l_concat_11 = list_ops.tensor_list_concat_lists(
      l_batch_1, l_batch_1, element_dtype=dtypes.float32)

  expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
  expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
  expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
  expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

  for i, (concat, expected) in enumerate(zip(
      [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
      [expected_00, expected_01, expected_10, expected_11])):
    split_lists = array_ops.unstack(concat)
    split_stacked_ret = self.evaluate(
        (list_ops.tensor_list_stack(split_lists[0], dtypes.float32),
         list_ops.tensor_list_stack(split_lists[1], dtypes.float32)))
    # Report the failing case through the assertion message rather than
    # printing every intermediate result to stdout.
    self.assertAllClose(expected[0], split_stacked_ret[0],
                        msg="concat case %d, list 0" % i)
    self.assertAllClose(expected[1], split_stacked_ret[1],
                        msg="concat case %d, list 1" % i)

  # Concatenating mismatched shapes fails.
  with self.assertRaises((errors.InvalidArgumentError, ValueError)):
    self.evaluate(
        list_ops.tensor_list_concat_lists(
            l_batch_0,
            list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
            element_dtype=dtypes.float32))

  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "element shapes are not identical at index 0"):
    l_batch_of_vec_tls = array_ops.stack(
        [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
    self.evaluate(
        list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                          element_dtype=dtypes.float32))

  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              r"input_b\[0\].dtype != element_dtype."):
    l_batch_of_int_tls = array_ops.stack(
        [list_ops.tensor_list_from_tensor([1], element_shape=scalar_shape())]
        * 2)
    self.evaluate(
        list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                          element_dtype=dtypes.float32))
def testAddTensorListsFailsIfLeadingDimsMismatch(self):
  """AddN over a 2-element and a 3-element TensorList must raise."""
  with self.cached_session(), self.test_scope():
    l1 = list_ops.tensor_list_reserve(
        element_shape=[], element_dtype=dtypes.float32, num_elements=2)
    l2 = list_ops.tensor_list_reserve(
        element_shape=[], element_dtype=dtypes.float32, num_elements=3)
    l = math_ops.add_n([l1, l2])
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed from unittest in Python 3.12).
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "TensorList arguments to AddN must all have the same shape"):
      list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval()
def testAddTensorListsFailsIfLeadingDimsMismatch(self):
  """Adding TensorLists of lengths 2 and 3 with AddN is rejected at run time."""
  with self.session(), self.test_scope():
    two_items = list_ops.tensor_list_reserve(
        element_shape=[], element_dtype=dtypes.float32, num_elements=2)
    three_items = list_ops.tensor_list_reserve(
        element_shape=[], element_dtype=dtypes.float32, num_elements=3)
    total = math_ops.add_n([two_items, three_items])
    expected_msg = "TensorList arguments to AddN must all have the same shape"
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_msg):
      stacked = list_ops.tensor_list_stack(total,
                                           element_dtype=dtypes.float32)
      stacked.eval()
def test_tensor_list_empty_list(self):
  """An empty list and an empty tuple both become empty int32 TensorLists."""
  # Both empty containers go through the same checks; the session handle
  # was never used, so it is no longer bound.
  for empty in ([], ()):
    l = special_functions.tensor_list(
        empty, element_dtype=dtypes.int32, element_shape=())
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [])
def test_tensor_list_empty_list(self):
  """Converts [] and () to int32 TensorLists and checks they stack to []."""
  for empty_container in ([], ()):
    tensor_list = special_functions.tensor_list(
        empty_container, element_dtype=dtypes.int32, element_shape=())
    stacked = list_ops.tensor_list_stack(tensor_list,
                                         element_dtype=dtypes.int32)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(stacked), [])
def testStackWithPartiallyDefinedElementShape(self):
  """Stacking with element_shape=[-1] works while pushed elements agree.

  Pushing a length-2 vector after two length-1 vectors must make the stack
  fail with "unequal shapes".
  """
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[-1])
  l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0]))
  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[1.0], [2.0]])

  # Should raise an error when the element tensors do not all have the same
  # shape. assertRaisesRegex replaces the deprecated assertRaisesRegexp
  # alias (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def test_tf_tensor_list_new_empty(self):
  """tf_tensor_list_new on an empty Python list yields a list stacking to []."""
  int_dtype = dtypes.int32
  new_list = data_structures.tf_tensor_list_new(
      [], element_dtype=int_dtype, element_shape=())
  stacked = list_ops.tensor_list_stack(new_list, element_dtype=int_dtype)
  with self.cached_session() as sess:
    self.assertAllEqual(sess.run(stacked), [])
def _tf_tensor_list_stack(list_, opts):
  """Overload of list_stack for TensorList values.

  Args:
    list_: the TensorList to stack, passed straight to
      `list_ops.tensor_list_stack`.
    opts: options object; only its `element_dtype` attribute is read.

  Returns:
    The result of `list_ops.tensor_list_stack(list_, element_dtype=...)`.

  Raises:
    ValueError: if `opts.element_dtype` is None, because stacking needs a
      known element type.
  """
  dtype = opts.element_dtype
  if dtype is not None:
    return list_ops.tensor_list_stack(list_, element_dtype=dtype)
  raise ValueError(
      'cannot stack a list without knowing its element type;'
      ' use set_element_type to annotate it')
def testStackWithUninitializedTensors(self):
  """Stacking a reserved scalar list whose slots were never set yields zeros."""
  with self.cached_session(), self.test_scope():
    reserved = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=3)
    stacked = list_ops.tensor_list_stack(reserved,
                                         element_dtype=dtypes.float32)
    self.assertAllEqual(stacked, [0.] * 3)
def testSetStackReservedUnknownElementShape(self):
  """Sets one slot of a reserved list with unknown element_shape, then stacks.

  The untouched slot is expected to come back as zeros with the same shape
  as the slot that was set.
  """
  with self.cached_session(), self.test_scope():
    reserved = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=None, num_elements=2)
    updated = list_ops.tensor_list_set_item(reserved, 0, [3.0, 4.0])
    stacked = list_ops.tensor_list_stack(updated,
                                         element_dtype=dtypes.float32)
    expected = [[3.0, 4.0], [0., 0.]]
    self.assertAllEqual(stacked, expected)
def test_tensor_list_from_elements(self):
  """special_functions.tensor_list builds a TensorList from given tensors."""
  elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

  l = special_functions.tensor_list(elements)
  sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
  # The session handle was never used, so it is no longer bound.
  with self.cached_session():
    self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])
def testAddN(self):
  """AddN over three TensorLists sums them elementwise."""
  inputs = [
      list_ops.tensor_list_from_tensor(values, element_shape=[])
      for values in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])
  ]
  summed = math_ops.add_n(tuple(inputs))
  summed_stacked = list_ops.tensor_list_stack(summed,
                                              element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(summed_stacked), [9., 12.])
def testStack(self):
  """Pushing 1.0 then 2.0 onto an empty scalar list stacks to [1.0, 2.0]."""
  base = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                    element_shape=scalar_shape())
  one_item = list_ops.tensor_list_push_back(base, constant_op.constant(1.0))
  two_items = list_ops.tensor_list_push_back(one_item,
                                             constant_op.constant(2.0))
  self.assertAllEqual(
      list_ops.tensor_list_stack(two_items, element_dtype=dtypes.float32),
      [1.0, 2.0])
def testPruning(self):
  """Counts Enter nodes in the optimized graph as fetches are added.

  With only the loop counter registered as a train op, the graph returned by
  tf_optimizer.OptimizeGraph is expected to contain one Enter node. After
  the stacked TensorList is also registered, two are expected.
  """
  start = constant_op.constant(1)
  empty_list = list_ops.empty_tensor_list(
      element_dtype=start.dtype, element_shape=start.shape)

  def cond_fn(counter, unused_list):
    return counter < 5

  def body_fn(counter, acc):
    return counter + 1, list_ops.tensor_list_push_back(acc, counter)

  outputs = while_loop_v1(cond_fn, body_fn, [start, empty_list])
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(outputs[0])

  def count_enter_nodes():
    meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    config = rewriter_config_pb2.RewriterConfig(
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta)
    return sum(1 for node in optimized.node if node.op == "Enter")

  self.assertEqual(count_enter_nodes(), 1)

  stacked = list_ops.tensor_list_stack(outputs[1], element_dtype=start.dtype)
  train_op.append(stacked)
  self.assertEqual(count_enter_nodes(), 2)
def test_list_pop(self):
  """A converted `l.pop()` returns the last element and the shortened list."""

  def test_fn():
    l = [1, 2, 3]
    utils.set_element_type(l, dtypes.int32, ())
    s = l.pop()
    return s, l

  node = self.parse_and_analyze(
      test_fn,
      {
          'utils': utils,
          'dtypes': dtypes
      },
      include_type_analysis=True,
  )
  node = lists.transform(node, self.ctx)

  with self.compiled(node) as result:
    result.utils = utils
    result.dtypes = dtypes
    # cached_session replaces the deprecated test_session.
    with self.cached_session() as sess:
      ts, tl = result.test_fn()
      r = list_ops.tensor_list_stack(tl, dtypes.int32)
      self.assertAllEqual(sess.run(r), [1, 2])
      self.assertAllEqual(sess.run(ts), 3)
def testPruning(self):
  """Checks how many Enter nodes survive graph optimization.

  The test expects one Enter node while only the loop counter is a train
  op, and two once the stacked loop list is appended to the train ops too.
  """
  one = constant_op.constant(1)
  initial_list = list_ops.empty_tensor_list(element_dtype=one.dtype,
                                            element_shape=one.shape)

  def keep_going(i, unused_tl):
    return i < 5

  def step(i, tl):
    return i + 1, list_ops.tensor_list_push_back(tl, i)

  loop_out = while_loop_v1(keep_going, step, [one, initial_list])
  fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  fetches.append(loop_out[0])

  def optimize():
    meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    cfg = rewriter_config_pb2.RewriterConfig(
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    return tf_optimizer.OptimizeGraph(cfg, meta)

  enter_ops = [n.op for n in optimize().node if n.op == "Enter"]
  self.assertEqual(len(enter_ops), 1)

  fetches.append(
      list_ops.tensor_list_stack(loop_out[1], element_dtype=one.dtype))
  enter_ops = [n.op for n in optimize().node if n.op == "Enter"]
  self.assertEqual(len(enter_ops), 2)
def build_graph(parameters):
  """Build the TensorListSetItem op testing graph.

  A while loop writes `(item + item) + 1` into each slot of a reserved
  TensorList, and the list is then stacked.

  Args:
    parameters: dict providing "element_dtype", "element_shape" and
      "num_elements".

  Returns:
    A pair `([item placeholder], [stacked output])`.
  """
  dtype = parameters["element_dtype"]
  item = tf.placeholder(dtype=dtype, shape=parameters["element_shape"])
  count = parameters["num_elements"]
  reserved = list_ops.tensor_list_reserve(
      element_shape=None, num_elements=count, element_dtype=dtype)

  def keep_looping(index, _):
    return index < count

  def write_slot(index, current_list):
    doubled = tf.add(item, item)
    one = tf.constant(value=1, dtype=dtype)
    updated = list_ops.tensor_list_set_item(current_list, index,
                                            tf.add(doubled, one))
    return index + 1, updated

  _, final_list = tf.while_loop(keep_looping, write_slot, (0, reserved))
  stacked = list_ops.tensor_list_stack(
      final_list, num_elements=count, element_dtype=dtype)
  return [item], [stacked]
def testStack(self):
  """Stacking two pushed scalars evaluates to [1.0, 2.0]."""
  values = [1.0, 2.0]
  tl = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                  element_shape=scalar_shape())
  for v in values:
    tl = list_ops.tensor_list_push_back(tl, constant_op.constant(v))
  result = self.evaluate(
      list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32))
  self.assertAllEqual(result, values)
def test_tensor_list_from_elements(self):
  """special_functions.tensor_list builds a TensorList from given tensors."""
  elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

  l = special_functions.tensor_list(elements)
  sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
  # cached_session replaces the deprecated test_session.
  with self.cached_session() as sess:
    self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def testZerosLike(self):
  """zeros_like on empty and two-element scalar TensorLists stacks to zeros."""
  empty = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                     element_shape=scalar_shape())
  empty_zeros_stacked = list_ops.tensor_list_stack(
      array_ops.zeros_like(empty), element_dtype=dtypes.float32)

  full = empty
  for value in (1.0, 2.0):
    full = list_ops.tensor_list_push_back(full, constant_op.constant(value))
  full_zeros_stacked = list_ops.tensor_list_stack(
      array_ops.zeros_like(full), element_dtype=dtypes.float32)

  self.assertAllEqual(self.evaluate(empty_zeros_stacked), [])
  self.assertAllEqual(self.evaluate(full_zeros_stacked), [0.0, 0.0])
def stack(self, name=None):
  """Returns all elements stacked into one Tensor; see TensorArray.stack.

  Args:
    name: optional name scope for the created ops.
  """
  with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
    stacked = list_ops.tensor_list_stack(
        input_handle=self._flow, element_dtype=self._dtype)
    if self._element_shape:
      elem_shape = self._element_shape[0]
      if elem_shape.dims is not None:
        # Static shape: unknown leading dimension followed by the recorded
        # element dims.
        stacked.set_shape([None] + elem_shape.dims)
    return stacked
def test_append_tensor_list(self):
  """list_append on a new list stores the appended tensor as one element."""
  l = data_structures.new_list()
  x = constant_op.constant([1, 2, 3])
  l = data_structures.list_append(l, x)

  t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
  # cached_session replaces the deprecated test_session.
  with self.cached_session() as sess:
    self.assertAllEqual(sess.run(t), [[1, 2, 3]])
def testStackWithUnknownElementShape(self, max_num_elements):
  """Stacking with element_shape=None works while pushed elements agree.

  After two scalars, pushing a length-2 vector must make the stack fail
  with "unequal shapes".

  Args:
    max_num_elements: passed through as `max_num_elements` to
      `list_ops.empty_tensor_list`.
  """
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=None,
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))

  t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

  # Should raise an error when the element tensors do not all have the same
  # shape. assertRaisesRegex replaces the deprecated assertRaisesRegexp
  # alias (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
    l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.evaluate(t)
def testGetSetItem(self):
  """get_item reads element 0; set_item replaces it in the stacked result."""
  source = constant_op.constant([1.0, 2.0])
  tensor_list = list_ops.tensor_list_from_tensor(
      source, element_shape=scalar_shape())
  head = list_ops.tensor_list_get_item(tensor_list, 0,
                                       element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(head), 1.0)

  tensor_list = list_ops.tensor_list_set_item(tensor_list, 0, 3.0)
  stacked = list_ops.tensor_list_stack(tensor_list,
                                       element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(stacked), [3.0, 2.0])
def testGetSetItem(self):
  """Reads element 0 of a scalar list, overwrites it, and restacks."""
  f32 = dtypes.float32
  tl = list_ops.tensor_list_from_tensor(constant_op.constant([1.0, 2.0]),
                                        element_shape=[])
  first = list_ops.tensor_list_get_item(tl, 0, element_dtype=f32)
  self.assertAllEqual(self.evaluate(first), 1.0)
  tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
  self.assertAllEqual(
      self.evaluate(list_ops.tensor_list_stack(tl, element_dtype=f32)),
      [3.0, 2.0])
def test_initialized_list(self):
  """A converted list literal becomes a TensorList holding its elements."""

  def test_fn():
    return [1, 2, 3]

  with self.converted(test_fn, lists, {}) as result:
    # cached_session replaces the deprecated test_session.
    with self.cached_session() as sess:
      tl = result.test_fn()
      r = list_ops.tensor_list_stack(tl, dtypes.int32)
      self.assertAllEqual(sess.run(r), [1, 2, 3])
def test_set_item_tensor_list(self):
  """slices.set_item on a TensorList replaces the element at index 0."""
  initial_list = constant_op.constant([[1, 2], [3, 4]])
  elem_shape = constant_op.constant([2])
  l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
  l = slices.set_item(l, 0, [5, 6])

  # The session handle was never used, so it is no longer bound.
  with self.cached_session():
    t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
    self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])
def testStackFromTensorGradients(self):
  """Gradients flow through tensor_list_from_tensor followed by stack."""
  with backprop.GradientTape() as tape:
    source = constant_op.constant([1.0, 2.0])
    tape.watch(source)
    as_list = list_ops.tensor_list_from_tensor(
        source, element_shape=scalar_shape())
    restacked = list_ops.tensor_list_stack(as_list,
                                           element_dtype=dtypes.float32)
    doubled = restacked * 2.0
  grads = tape.gradient(doubled, [source])
  self.assertAllEqual(grads[0], [2.0, 2.0])
def testGetSet(self):
  """Reads element 0, overwrites it with 3.0, and checks the stacked list."""
  with self.cached_session(), self.test_scope():
    initial = constant_op.constant([1.0, 2.0])
    tl = list_ops.tensor_list_from_tensor(initial, element_shape=[])
    first = list_ops.tensor_list_get_item(tl, 0, element_dtype=dtypes.float32)
    self.assertAllEqual(first, 1.0)
    tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
    self.assertAllEqual(
        list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32),
        [3.0, 2.0])
def _testStack(self, max_num_elements):
  """Helper: pushes 1.0 and 2.0 onto an empty scalar list and stacks them.

  Args:
    max_num_elements: forwarded to `list_ops.empty_tensor_list`.
  """
  tensor_list = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=scalar_shape(),
      max_num_elements=max_num_elements)
  for scalar in (1.0, 2.0):
    tensor_list = list_ops.tensor_list_push_back(
        tensor_list, constant_op.constant(scalar))
  stacked = list_ops.tensor_list_stack(tensor_list,
                                       element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(stacked), [1.0, 2.0])
def testGetSetReserved(self):
  """Reserved scalar slots read back as 0.0 until one is set."""
  with self.cached_session(), self.test_scope():
    reserved = list_ops.tensor_list_reserve(
        element_dtype=dtypes.float32, element_shape=[], num_elements=2)
    slot0 = list_ops.tensor_list_get_item(reserved, 0,
                                          element_dtype=dtypes.float32)
    self.assertAllEqual(slot0, 0.0)
    reserved = list_ops.tensor_list_set_item(reserved, 0, 3.0)
    stacked = list_ops.tensor_list_stack(reserved,
                                         element_dtype=dtypes.float32)
    self.assertAllEqual(stacked, [3.0, 0.0])
def testStackFromTensorGradients(self):
  """d(2 * stack(from_tensor(c))) / dc is 2 for every element."""
  with backprop.GradientTape() as tape:
    values = constant_op.constant([1.0, 2.0])
    tape.watch(values)
    stacked = list_ops.tensor_list_stack(
        list_ops.tensor_list_from_tensor(values, element_shape=scalar_shape()),
        element_dtype=dtypes.float32)
    scaled = stacked * 2.0
  (grad_values,) = tape.gradient(scaled, [values])
  self.assertAllEqual(grad_values, [2.0, 2.0])
def testGraphStack(self):
  """Stacks a one-element int32 list in graph mode."""
  # cached_session replaces the deprecated test_session.
  with context.graph_mode(), self.cached_session():
    tl = list_ops.empty_tensor_list(
        element_shape=constant_op.constant([1], dtype=dtypes.int32),
        element_dtype=dtypes.int32)
    tl = list_ops.tensor_list_push_back(tl, [1])
    self.assertAllEqual(
        list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
        [[1]])
def testSetDoesNotUpdatePushIndex(self):
  """set_item must not advance the index used by push_back.

  After setting slot 1 and then pushing two values, the stacked list is
  expected to hold exactly the two pushed values.
  """
  with self.cached_session(), self.test_scope():
    tl = list_ops.empty_tensor_list(
        element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
    # SetItem should not change the push index.
    tl = list_ops.tensor_list_set_item(tl, 1, 3.)
    for pushed in (5., 7.):
      tl = list_ops.tensor_list_push_back(tl, pushed)
    self.assertAllEqual(
        list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32),
        [5., 7.])
def testStackFromTensorGradients(self):
  """Gradient of 2 * stack(from_tensor(c), num_elements=2) w.r.t. c is 2."""
  with backprop.GradientTape() as tape:
    watched = constant_op.constant([1.0, 2.0])
    tape.watch(watched)
    tl = list_ops.tensor_list_from_tensor(watched, element_shape=[])
    out = list_ops.tensor_list_stack(
        tl, element_dtype=dtypes.float32, num_elements=2) * 2.0
  [grad] = tape.gradient(out, [watched])
  self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])
def testGraphStack(self):
  """Stacks a list holding a single pushed int32 vector [1]."""
  with self.cached_session():
    int32 = dtypes.int32
    shape = constant_op.constant([1], dtype=int32)
    tl = list_ops.empty_tensor_list(element_shape=shape, element_dtype=int32)
    tl = list_ops.tensor_list_push_back(tl, [1])
    stacked = list_ops.tensor_list_stack(tl, element_dtype=int32)
    self.assertAllEqual(self.evaluate(stacked), [[1]])
def testZerosLikeForTensorList(self):
  """zeros_like on a TensorList with one pushed scalar, then stacked.

  The stacked result is expected to have static shape [None] and value
  [0.0, 0.0].
  """
  with self.session(), self.test_scope():
    source = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                        element_shape=[],
                                        max_num_elements=2)
    source = list_ops.tensor_list_push_back(source, constant_op.constant(1.0))
    zeros = list_ops.tensor_list_stack(array_ops.zeros_like(source),
                                       element_dtype=dtypes.float32)
    self.assertAllEqual(zeros.shape.as_list(), [None])
    self.assertAllEqual(zeros, [0.0, 0.0])
def testPushInEmptyListWithUnknownElementShape(self):
  """Pushing a scalar after a length-2 vector (element_shape=None) raises."""
  with self.cached_session(), self.test_scope():
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
    l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
    # Pushing an element with a different shape should raise an error.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed from unittest in Python 3.12).
    with self.assertRaisesRegex(errors.InvalidArgumentError, "Shape"):
      l = list_ops.tensor_list_push_back(l, 5.)
      self.evaluate(
          list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
def testPushInEmptyListWithUnknownElementShape(self):
  """Pushing a scalar after a length-2 vector (element_shape=None) raises.

  This variant expects an InternalError whose message matches "shape".
  """
  with self.cached_session(), self.test_scope():
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
    l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
    # Pushing an element with a different shape should raise an error.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
    # (removed from unittest in Python 3.12).
    with self.assertRaisesRegex(errors.InternalError, "shape"):
      l = list_ops.tensor_list_push_back(l, 5.)
      self.evaluate(
          list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))
def testPushBackBatch(self):
  """Checks tensor_list_push_back_batch on a batch of two TensorLists.

  Pushing [3.0, 4.0] appends one value to each list. Re-reading the original
  batch after the push must still give the unmodified lists. Mismatched
  element shapes or dtypes must raise.
  """
  c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
  l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
  l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
  l_batch = array_ops.stack([l0, l1])
  l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
  l_unstack = array_ops.unstack(l_push)
  l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
  l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
  self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
  self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))

  with ops.control_dependencies([l_push]):
    l_unstack_orig = array_ops.unstack(l_batch)
    l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
                                             dtypes.float32)
    l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
                                             dtypes.float32)

  # Check that without aliasing, push_back_batch still works; and
  # that it doesn't modify the input.
  l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
      (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
  self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
  self.assertAllClose([-1.0, 4.0], l1_r_v)
  self.assertAllClose([1.0, 2.0], l0_orig_v)
  self.assertAllClose([-1.0], l1_orig_v)

  # Pushing back mismatched shapes fails.
  with self.assertRaises((errors.InvalidArgumentError, ValueError)):
    self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))

  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "incompatible shape to a list at index 0"):
    self.evaluate(
        list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))

  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Invalid data type at index 0"):
    self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
def testZerosLikeForTensorList(self):
  """Stacks zeros_like of a TensorList that had one scalar pushed.

  The expected static shape is [None] and the expected value [0.0, 0.0].
  """
  with self.cached_session(), self.test_scope():
    pushed = list_ops.tensor_list_push_back(
        list_ops.empty_tensor_list(
            element_dtype=dtypes.float32, element_shape=[],
            max_num_elements=2),
        constant_op.constant(1.0))
    zeroed = array_ops.zeros_like(pushed)
    stacked = list_ops.tensor_list_stack(zeroed, element_dtype=dtypes.float32)
    self.assertAllEqual(stacked.shape.as_list(), [None])
    self.assertAllEqual(stacked, [0.0, 0.0])
def testResourceVariableScatterGather(self):
  """Stores TensorLists in a resource variable, then reads and updates them.

  Row 0 is read both by indexing and by sparse_read. Rows 3 and 5 are then
  replaced with scatter_update, and every row is stacked and checked.
  """
  values = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
  base_list = list_ops.tensor_list_from_tensor(values, element_shape=[])
  var = vs.get_variable("var", initializer=[base_list] * 10,
                        use_resource=True)
  dense_row0 = list_ops.tensor_list_stack(var[0], dtypes.float32)
  self.evaluate(var.initializer)
  self.assertAllEqual([1.0, 2.0], self.evaluate(dense_row0))
  sparse_row0 = list_ops.tensor_list_stack(var.sparse_read(0), dtypes.float32)
  self.assertAllEqual([1.0, 2.0], self.evaluate(sparse_row0))

  row3 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
  row5 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
  updated = state_ops.scatter_update(var, [3, 5], [row3, row5])
  updated_stacked = [
      list_ops.tensor_list_stack(row, dtypes.float32)
      for row in array_ops.unstack(updated)
  ]
  expected = [[1.0, 2.0]] * 10
  expected[3] = [3.0, 4.0]
  expected[5] = [5.0, 6.0]
  self.assertAllEqual(self.evaluate(updated_stacked), expected)
def _testStack(self, max_num_elements):
  """Helper: pushes two scalars and checks the stacked static shape and value.

  In graph mode the stacked tensor's static shape is expected to be [None].
  In either mode its value must be [1.0, 2.0].

  Args:
    max_num_elements: forwarded to `list_ops.empty_tensor_list`.
  """
  tl = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[],
      max_num_elements=max_num_elements)
  for value in (1.0, 2.0):
    tl = list_ops.tensor_list_push_back(tl, constant_op.constant(value))
  stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
  if not context.executing_eagerly():
    # Static shape check only applies when building a graph.
    self.assertAllEqual(stacked.shape.as_list(), [None])
  self.assertAllEqual(self.evaluate(stacked), [1.0, 2.0])