def testGatherEmptyList(self, max_num_elements):
  """Checks gathering an empty index list from empty TensorLists.

  With a fully defined element_shape ([1, 2]) the gather succeeds and the
  result has shape (0, 1, 2). With a partially defined ([None, 2]) or an
  undefined (None) element_shape the test expects InvalidArgumentError whose
  message matches "non-fully-defined".

  Args:
    max_num_elements: Forwarded to `list_ops.empty_tensor_list`. Presumably
      supplied by a parameterized-test decorator outside this view.
  """
  # Should be able to gather from empty lists with fully defined
  # element_shape.
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=[1, 2],
      max_num_elements=max_num_elements)
  t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
  self.assertAllEqual((0, 1, 2), self.evaluate(t).shape)

  # Should not be able to gather from empty lists with partially defined
  # element_shape. Building and evaluating both stay inside the context
  # manager so the error is caught whichever step raises it.
  # (assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.)
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[None, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.evaluate(t)

  # Should not be able to gather from empty lists with undefined
  # element_shape.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.evaluate(t)
def testGatherEmptyList(self, max_num_elements):
  """Checks gathering an empty index list from empty TensorLists.

  In this variant -1 marks an unknown dimension ([-1, 2] is "partially
  defined") and -1 as the whole element_shape means "undefined". A fully
  defined element_shape ([1, 2]) lets the gather succeed with a result of
  shape (0, 1, 2). The two non-fully-defined cases are expected to raise
  InvalidArgumentError matching "non-fully-defined".

  Args:
    max_num_elements: Forwarded to `list_ops.empty_tensor_list`. Presumably
      supplied by a parameterized-test decorator outside this view.
  """
  # Should be able to gather from empty lists with fully defined
  # element_shape.
  l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                 element_shape=[1, 2],
                                 max_num_elements=max_num_elements)
  t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
  self.assertAllEqual((0, 1, 2), self.evaluate(t).shape)

  # Should not be able to gather from empty lists with partially defined
  # element_shape. Construction and evaluation are both inside the context
  # manager so the error is caught whichever step raises it.
  # (assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.)
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                   element_shape=[-1, 2],
                                   max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.evaluate(t)

  # Should not be able to gather from empty lists with undefined
  # element_shape.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "non-fully-defined"):
    l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                   element_shape=-1,
                                   max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.evaluate(t)
def testZerosLikeUninitialized(self):
  """zeros_like of a partially filled TensorList yields zeros where set.

  A 3-slot float32 list with element_shape [] is reserved. Slot 0 is set,
  and then slot 2. zeros_like is taken after each step, and the slots that
  were set are gathered from it. Every gathered value is expected to be 0.
  """
  reserved = list_ops.tensor_list_reserve([], 3, element_dtype=dtypes.float32)

  only_first = list_ops.tensor_list_set_item(reserved, 0, 1.)  # [1., _, _]
  zeros_of_first = array_ops.zeros_like(only_first)  # [0., _, _]

  first_and_last = list_ops.tensor_list_set_item(only_first, 2, 2.)  # [1., _, 2.]
  zeros_of_both = array_ops.zeros_like(first_and_last)  # [0., _, 0.]

  # Only indices that were explicitly set are gathered from each zeros list.
  picked_first = list_ops.tensor_list_gather(
      zeros_of_first, [0], element_dtype=dtypes.float32)
  picked_both = list_ops.tensor_list_gather(
      zeros_of_both, [0, 2], element_dtype=dtypes.float32)

  self.assertAllEqual(self.evaluate(picked_first), [0.])
  self.assertAllEqual(self.evaluate(picked_both), [0., 0.])
def testGather(self, input_list, element_shape, indices, output):
  """Builds a TensorList from `input_list` and checks a float32 gather.

  Args:
    input_list: Value passed to `list_ops.tensor_list_from_tensor`
      (presumably its leading-axis slices become the list elements).
    element_shape: Element shape forwarded to `tensor_list_from_tensor`.
    indices: Positions to gather.
    output: Expected result of the gather.

  The arguments are presumably supplied by a parameterized-test decorator
  outside this view.
  """
  with self.session(), self.test_scope():
    gathered = list_ops.tensor_list_gather(
        list_ops.tensor_list_from_tensor(input_list,
                                         element_shape=element_shape),
        indices,
        element_dtype=dtypes.float32)
    self.assertAllEqual(gathered, output)
def testZerosLikeUninitialized(self):
  """Gathers set slots from zeros_like of a partially populated list.

  Reserves three float32 slots (element_shape []), fills slot 0 and then
  slot 2, and checks that the filled slots read back as 0 from the
  corresponding zeros_like lists.
  """
  base = list_ops.tensor_list_reserve([], 3, element_dtype=dtypes.float32)
  after_slot0 = list_ops.tensor_list_set_item(base, 0, 1.)  # [1., _, _]
  zeroed_slot0 = array_ops.zeros_like(after_slot0)  # [0., _, _]
  after_slot2 = list_ops.tensor_list_set_item(after_slot0, 2, 2.)  # [1., _, 2.]
  zeroed_slot2 = array_ops.zeros_like(after_slot2)  # [0., _, 0.]

  # Read back only the positions that hold initialized values.
  read_slot0 = list_ops.tensor_list_gather(
      zeroed_slot0, [0], element_dtype=dtypes.float32)
  read_slots02 = list_ops.tensor_list_gather(
      zeroed_slot2, [0, 2], element_dtype=dtypes.float32)

  self.assertAllEqual(self.evaluate(read_slot0), [0.])
  self.assertAllEqual(self.evaluate(read_slots02), [0., 0.])
def testGatherWithPartiallyDefinedElementShape(self):
  """Gathers from a list whose element_shape is [-1] (rank 1, size unknown).

  Three vectors of lengths 1, 2 and 2 are pushed. Gathering [0] and [1, 2]
  (same-shaped elements) succeeds. Gathering [0, 2] mixes lengths 1 and 2
  and is expected to raise InvalidArgumentError matching "unequal shapes".
  """
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=[-1])
  l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([4.0, 5.0]))

  t = list_ops.tensor_list_gather(l, [0], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[1.0]])

  t = list_ops.tensor_list_gather(l, [1, 2], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[2.0, 3.0], [4.0, 5.0]])

  # Should raise an error when the requested tensors do not all have the same
  # shape. (assertRaisesRegex: the assertRaisesRegexp alias was removed in
  # Python 3.12.)
  with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.evaluate(t)
def gather(self, indices, name=None):
  """See TensorArray.

  Gathers the elements at `indices` with `list_ops.tensor_list_gather`.
  If an element shape has been recorded and its rank is known (`dims` is
  not None), the result's static shape is refined to `[None] + dims`.
  """
  gathered = list_ops.tensor_list_gather(
      input_handle=self._flow,
      indices=indices,
      element_dtype=self._dtype,
      name=name)
  # NOTE(review): self._element_shape appears to be a sequence whose first
  # entry is the element TensorShape — confirm against the class definition.
  recorded = self._element_shape
  if not recorded:
    return gathered
  element_dims = recorded[0].dims
  if element_dims is None:
    return gathered
  # The leading dimension (number of gathered elements) is left unknown.
  gathered.set_shape([None] + element_dims)
  return gathered
def gather(self, indices, name=None):
  """See TensorArray.

  Returns the result of `list_ops.tensor_list_gather` on this array's flow.
  When a recorded element shape with known rank exists, the result's static
  shape is set to `[None] + <element dims>`.
  """
  result = list_ops.tensor_list_gather(
      input_handle=self._flow, indices=indices,
      element_dtype=self._dtype, name=name)
  has_known_rank = bool(self._element_shape) and (
      self._element_shape[0].dims is not None)
  if has_known_rank:
    # Leading dim (count of gathered elements) is intentionally left as None.
    result.set_shape([None] + self._element_shape[0].dims)
  return result
def testGatherWithUnknownElementShape(self, max_num_elements):
  """Gathers from a list created with element_shape=None.

  Two scalars and one length-2 vector are pushed. Gathering [1, 0] (two
  scalars) and [2] (the vector) succeeds. Gathering [0, 2] mixes a scalar
  with a vector and is expected to raise InvalidArgumentError matching
  "unequal shapes".

  Args:
    max_num_elements: Forwarded to `list_ops.empty_tensor_list`. Presumably
      supplied by a parameterized-test decorator outside this view.
  """
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=None,
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))

  t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [2.0, 1.0])

  t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])

  # Should raise an error when the requested tensors do not all have the same
  # shape. (assertRaisesRegex: the assertRaisesRegexp alias was removed in
  # Python 3.12.)
  with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.evaluate(t)
def testGatherWithUnknownElementShape(self, max_num_elements):
  """Gathers from a list created with element_shape=None.

  Pushes 1.0, 2.0 and [3.0, 4.0]. Same-shaped gathers ([1, 0] and [2])
  succeed. Gathering [0, 2] (a scalar and a vector) is expected to raise
  InvalidArgumentError matching "unequal shapes".

  Args:
    max_num_elements: Forwarded to `list_ops.empty_tensor_list`. Presumably
      supplied by a parameterized-test decorator outside this view.
  """
  l = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32,
      element_shape=None,
      max_num_elements=max_num_elements)
  l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
  l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))

  t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [2.0, 1.0])

  t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
  self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])

  # Should raise an error when the requested tensors do not all have the same
  # shape. (assertRaisesRegex: the assertRaisesRegexp alias was removed in
  # Python 3.12.)
  with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
    t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
    self.evaluate(t)
def gather(self, indices, name=None):
  """See TensorArray.

  Delegates to `list_ops.tensor_list_gather`, passing the first recorded
  element shape when one exists. Otherwise it passes
  `tensor_shape.unknown_shape(None)`.
  """
  # NOTE(review): self._element_shape appears to be a possibly-empty sequence
  # of TensorShapes; only its first entry is used — confirm with the class.
  shape_hint = (self._element_shape[0] if self._element_shape
                else tensor_shape.unknown_shape(None))
  return list_ops.tensor_list_gather(
      input_handle=self._flow,
      indices=indices,
      element_dtype=self._dtype,
      element_shape=shape_hint,
      name=name)
def testGatherGrad(self):
  """Checks that gradients flow back through tensor_list_gather.

  A watched constant 1.0 and a constant 2.0 are pushed onto a list and
  gathered in reverse order. The test forms (t[0] + t[1]) * (t[0] + t[1])
  and expects its derivative with respect to the watched value to be
  2 * (2.0 + 1.0) = 6.0.
  """
  with backprop.GradientTape() as tape:
    handle = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                        element_shape=scalar_shape())
    source = constant_op.constant(1.0)
    tape.watch(source)
    handle = list_ops.tensor_list_push_back(handle, source)
    handle = list_ops.tensor_list_push_back(handle,
                                            constant_op.constant(2.0))
    gathered = list_ops.tensor_list_gather(handle, [1, 0],
                                           element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(gathered), [2.0, 1.0])
    loss = (gathered[0] + gathered[1]) * (gathered[0] + gathered[1])
  grad = tape.gradient(loss, source)
  self.assertAllEqual(self.evaluate(grad), 6.0)
def testGatherGrad(self):
  """Differentiates through a reversed tensor_list_gather.

  The list holds [watched=1.0, 2.0]. Gathering [1, 0] gives [2.0, 1.0], and
  the squared sum of that pair should have gradient 6.0 with respect to the
  watched value.
  """
  with backprop.GradientTape() as tape:
    tl = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                    element_shape=scalar_shape())
    watched = constant_op.constant(1.0)
    tape.watch(watched)
    tl = list_ops.tensor_list_push_back(tl, watched)
    tl = list_ops.tensor_list_push_back(tl, constant_op.constant(2.0))
    reordered = list_ops.tensor_list_gather(
        tl, [1, 0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(reordered), [2.0, 1.0])
    squared_sum = (reordered[0] + reordered[1]) * (reordered[0] + reordered[1])
  d_watched = tape.gradient(squared_sum, watched)
  self.assertAllEqual(self.evaluate(d_watched), 6.0)