Esempio n. 1
0
  def testGatherEmptyList(self, max_num_elements):
    """Gathering zero indices works only when element_shape is fully defined.

    Args:
      max_num_elements: Forwarded to `list_ops.empty_tensor_list`; presumably
        supplied by a parameterization decorator outside this view.
    """
    # Should be able to gather from empty lists with fully defined
    # element_shape; the result has a leading dimension of 0.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
    self.assertAllEqual((0, 1, 2), self.evaluate(t).shape)

    # Should not be able to gather from empty lists with partially defined
    # element_shape. The list is built outside the assertion context so that
    # only the gather (or its evaluation) can satisfy the expected error.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[None, 2],
        max_num_elements=max_num_elements)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to gather from empty lists with undefined
    # element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
      self.evaluate(t)
Esempio n. 2
0
    def testGatherEmptyList(self, max_num_elements):
        """Gathering zero indices works only with a fully defined element_shape.

        Args:
            max_num_elements: Forwarded to `list_ops.empty_tensor_list`;
                presumably supplied by a parameterization decorator outside
                this view.
        """
        # Should be able to gather from empty lists with fully defined
        # element_shape; the result has a leading dimension of 0.
        l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                       element_shape=[1, 2],
                                       max_num_elements=max_num_elements)
        t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
        self.assertAllEqual((0, 1, 2), self.evaluate(t).shape)

        # Should not be able to gather from empty lists with partially defined
        # element_shape. The list is built outside the assertion context so
        # that only the gather (or its evaluation) can satisfy the expected
        # error.
        l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                       element_shape=[-1, 2],
                                       max_num_elements=max_num_elements)
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "non-fully-defined"):
            t = list_ops.tensor_list_gather(l, [],
                                            element_dtype=dtypes.float32)
            self.evaluate(t)

        # Should not be able to gather from empty lists with undefined
        # element_shape.
        l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                       element_shape=-1,
                                       max_num_elements=max_num_elements)
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "non-fully-defined"):
            t = list_ops.tensor_list_gather(l, [],
                                            element_dtype=dtypes.float32)
            self.evaluate(t)
Esempio n. 3
0
  def testZerosLikeUninitialized(self):
    """zeros_like on a partially written reserved list yields zeros where set.

    Only the slots that were written (index 0, then indices 0 and 2) are
    gathered from the zeros_like results, and each must come back as 0.
    """
    base = list_ops.tensor_list_reserve([], 3, element_dtype=dtypes.float32)
    first_set = list_ops.tensor_list_set_item(base, 0, 1.)  # [1., _, _]
    zeros_first = array_ops.zeros_like(first_set)  # [0., _, _]
    both_set = list_ops.tensor_list_set_item(first_set, 2, 2.)  # [1., _, 2.]
    zeros_both = array_ops.zeros_like(both_set)  # [0., _, 0.]

    # (zeros list, indices holding values, expected gathered result)
    expectations = [
        (zeros_first, [0], [0.]),
        (zeros_both, [0, 2], [0., 0.]),
    ]
    gathered = [
        list_ops.tensor_list_gather(zeros, idx, element_dtype=dtypes.float32)
        for zeros, idx, _ in expectations
    ]
    for result, (_, _, want) in zip(gathered, expectations):
      self.assertAllEqual(self.evaluate(result), want)
 def testGather(self, input_list, element_shape, indices, output):
     """Builds a list from `input_list`, gathers `indices`, checks `output`."""
     with self.session():
         with self.test_scope():
             handle = list_ops.tensor_list_from_tensor(
                 input_list, element_shape=element_shape)
             result = list_ops.tensor_list_gather(
                 handle, indices, element_dtype=dtypes.float32)
             self.assertAllEqual(result, output)
Esempio n. 5
0
  def testZerosLikeUninitialized(self):
    """zeros_like on a partially written reserved list yields zeros where set.

    Only the slots that were written (index 0, then indices 0 and 2) are
    gathered from the zeros_like results, and each must come back as 0.
    """
    base = list_ops.tensor_list_reserve([], 3, element_dtype=dtypes.float32)
    first_set = list_ops.tensor_list_set_item(base, 0, 1.)  # [1., _, _]
    zeros_first = array_ops.zeros_like(first_set)  # [0., _, _]
    both_set = list_ops.tensor_list_set_item(first_set, 2, 2.)  # [1., _, 2.]
    zeros_both = array_ops.zeros_like(both_set)  # [0., _, 0.]

    # (zeros list, indices holding values, expected gathered result)
    expectations = [
        (zeros_first, [0], [0.]),
        (zeros_both, [0, 2], [0., 0.]),
    ]
    gathered = [
        list_ops.tensor_list_gather(zeros, idx, element_dtype=dtypes.float32)
        for zeros, idx, _ in expectations
    ]
    for result, (_, _, want) in zip(gathered, expectations):
      self.assertAllEqual(self.evaluate(result), want)
Esempio n. 6
0
  def testGatherWithPartiallyDefinedElementShape(self):
    """Gather from a list whose element_shape is [-1].

    Elements of length 1, 2 and 2 are pushed. Gathering elements of equal
    shape succeeds; gathering elements of differing shape must raise.
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[-1])
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([4.0, 5.0]))

    t = list_ops.tensor_list_gather(l, [0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0]])

    t = list_ops.tensor_list_gather(l, [1, 2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[2.0, 3.0], [4.0, 5.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape. (`assertRaisesRegexp` is a deprecated alias removed in Python
    # 3.12; `assertRaisesRegex` is the supported name.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)
Esempio n. 7
0
 def gather(self, indices, name=None):
     """See TensorArray.

     When an element shape has been recorded and its `dims` is not None, the
     gathered value is given that static shape with an unknown leading
     dimension.
     """
     gathered = list_ops.tensor_list_gather(
         input_handle=self._flow, indices=indices,
         element_dtype=self._dtype, name=name)
     if not self._element_shape:
         return gathered
     known_dims = self._element_shape[0].dims
     if known_dims is not None:
         gathered.set_shape([None] + known_dims)
     return gathered
 def gather(self, indices, name=None):
   """See TensorArray.

   When an element shape has been recorded and its `dims` is not None, the
   gathered value is given that static shape with an unknown leading
   dimension.
   """
   gathered = list_ops.tensor_list_gather(
       input_handle=self._flow, indices=indices,
       element_dtype=self._dtype, name=name)
   if not self._element_shape:
     return gathered
   known_dims = self._element_shape[0].dims
   if known_dims is not None:
     gathered.set_shape([None] + known_dims)
   return gathered
Esempio n. 9
0
  def testGatherWithUnknownElementShape(self, max_num_elements):
    """Gather from a list created with element_shape=None.

    Two scalars and one length-2 vector are pushed. Gathering same-shaped
    elements succeeds; mixing the scalar and the vector must raise.

    Args:
      max_num_elements: Forwarded to `list_ops.empty_tensor_list`; presumably
        supplied by a parameterization decorator outside this view.
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))

    t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [2.0, 1.0])

    t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape. (`assertRaisesRegexp` is a deprecated alias removed in Python
    # 3.12; `assertRaisesRegex` is the supported name.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)
Esempio n. 10
0
  def testGatherWithUnknownElementShape(self, max_num_elements):
    """Gather from a list created with element_shape=None.

    Two scalars and one length-2 vector are pushed. Gathering same-shaped
    elements succeeds; mixing the scalar and the vector must raise.

    Args:
      max_num_elements: Forwarded to `list_ops.empty_tensor_list`; presumably
        supplied by a parameterization decorator outside this view.
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))

    t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [2.0, 1.0])

    t = list_ops.tensor_list_gather(l, [2], element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[3.0, 4.0]])

    # Should raise an error when the requested tensors do not all have the same
    # shape. (`assertRaisesRegexp` is a deprecated alias removed in Python
    # 3.12; `assertRaisesRegex` is the supported name.)
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      t = list_ops.tensor_list_gather(l, [0, 2], element_dtype=dtypes.float32)
      self.evaluate(t)
Esempio n. 11
0
 def gather(self, indices, name=None):
   """See TensorArray.

   Forwards the first recorded element shape to
   `list_ops.tensor_list_gather`, or an unknown shape when none has been
   recorded.
   """
   shape_hint = (self._element_shape[0] if self._element_shape
                 else tensor_shape.unknown_shape(None))
   return list_ops.tensor_list_gather(
       input_handle=self._flow,
       indices=indices,
       element_dtype=self._dtype,
       element_shape=shape_hint,
       name=name)
Esempio n. 12
0
 def testGatherGrad(self):
   """Gradients flow through tensor_list_gather back to a pushed element.

   With the list [c0=1.0, 2.0] gathered as t = [2.0, c0], the gradient of
   (t[0] + t[1]) * (t[0] + t[1]) w.r.t. c0 is 2 * 3 = 6.
   """
   with backprop.GradientTape() as tape:
     handle = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                         element_shape=scalar_shape())
     watched = constant_op.constant(1.0)
     tape.watch(watched)
     handle = list_ops.tensor_list_push_back(handle, watched)
     handle = list_ops.tensor_list_push_back(handle,
                                             constant_op.constant(2.0))
     gathered = list_ops.tensor_list_gather(handle, [1, 0],
                                            element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(gathered), [2.0, 1.0])
     loss = (gathered[0] + gathered[1]) * (gathered[0] + gathered[1])
   grad = tape.gradient(loss, watched)
   self.assertAllEqual(self.evaluate(grad), 6.0)
Esempio n. 13
0
 def testGatherGrad(self):
   """Gradients flow through tensor_list_gather back to a pushed element.

   With the list [c0=1.0, 2.0] gathered as t = [2.0, c0], the gradient of
   (t[0] + t[1]) * (t[0] + t[1]) w.r.t. c0 is 2 * 3 = 6.
   """
   with backprop.GradientTape() as tape:
     handle = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                         element_shape=scalar_shape())
     watched = constant_op.constant(1.0)
     tape.watch(watched)
     handle = list_ops.tensor_list_push_back(handle, watched)
     handle = list_ops.tensor_list_push_back(handle,
                                             constant_op.constant(2.0))
     gathered = list_ops.tensor_list_gather(handle, [1, 0],
                                            element_dtype=dtypes.float32)
     self.assertAllEqual(self.evaluate(gathered), [2.0, 1.0])
     loss = (gathered[0] + gathered[1]) * (gathered[0] + gathered[1])
   grad = tape.gradient(loss, watched)
   self.assertAllEqual(self.evaluate(grad), 6.0)