def testPushBackBatch(self):
        """Tests batched TensorList push_back: results, non-aliasing, and errors.

        Builds a batch of two lists of different lengths, pushes one scalar
        onto each via `tensor_list_push_back_batch`, and checks that (a) the
        pushed batch contains the appended values, (b) the original input
        batch is left unmodified (no aliasing), and (c) mismatched shapes or
        dtypes raise `InvalidArgumentError`.
        """
        c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
        l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
        l1 = list_ops.tensor_list_from_tensor([-1.0],
                                              element_shape=scalar_shape())
        l_batch = array_ops.stack([l0, l1])
        l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
        l_unstack = array_ops.unstack(l_push)
        l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
        l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
        self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
        self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))

        # Read back the ORIGINAL batch after the push has executed, to prove
        # the push produced a new value rather than mutating its input.
        with ops.control_dependencies([l_push]):
            l_unstack_orig = array_ops.unstack(l_batch)
            l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
                                                     dtypes.float32)
            l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
                                                     dtypes.float32)

        # Check that without aliasing, push_back_batch still works; and
        # that it doesn't modify the input.
        l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
            (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
        self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
        self.assertAllClose([-1.0, 4.0], l1_r_v)
        self.assertAllClose([1.0, 2.0], l0_orig_v)
        self.assertAllClose([-1.0], l1_orig_v)

        # Pushing back mismatched shapes fails.
        with self.assertRaises((errors.InvalidArgumentError, ValueError)):
            self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))

        # assertRaisesRegexp was deprecated in Python 3.2 and removed in
        # Python 3.12; assertRaisesRegex is the supported spelling.
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "incompatible shape to a list at index 0"):
            self.evaluate(
                list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))

        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "Invalid data type at index 0"):
            self.evaluate(list_ops.tensor_list_push_back_batch(
                l_batch, [3, 4]))
Example #2
0
  def testPushBackBatch(self):
    """Tests batched TensorList push_back: results, non-aliasing, and errors.

    Builds a batch of two lists of different lengths, pushes one scalar onto
    each via `tensor_list_push_back_batch`, and checks that (a) the pushed
    batch contains the appended values, (b) the original input batch is left
    unmodified (no aliasing), and (c) mismatched shapes or dtypes raise
    `InvalidArgumentError`.
    """
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=[])
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=[])
    l_batch = array_ops.stack([l0, l1])
    l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
    l_unstack = array_ops.unstack(l_push)
    l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
    l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
    self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
    self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))

    # Read back the ORIGINAL batch after the push has executed, to prove the
    # push produced a new value rather than mutating its input.
    with ops.control_dependencies([l_push]):
      l_unstack_orig = array_ops.unstack(l_batch)
      l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
                                               dtypes.float32)
      l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
                                               dtypes.float32)

    # Check that without aliasing, push_back_batch still works; and
    # that it doesn't modify the input.
    l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
        (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
    self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
    self.assertAllClose([-1.0, 4.0], l1_r_v)
    self.assertAllClose([1.0, 2.0], l0_orig_v)
    self.assertAllClose([-1.0], l1_orig_v)

    # Pushing back mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))

    # assertRaisesRegexp was deprecated in Python 3.2 and removed in
    # Python 3.12; assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "incompatible shape to a list at index 0"):
      self.evaluate(
          list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Invalid data type at index 0"):
      self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
Example #3
0
    def add(self, rows, values):
        """Returns an op that appends a single frame of values at `rows`.

        This operation is batch-aware: `rows` may name several locations, in
        which case each tensor in `values` carries a leading batch dimension
        equal to the number of rows.

        Args:
          rows: A list/tensor of location(s) to write values at.
          values: A nest of Tensors to write, structured like this class's
            tensor_spec. With more than one row, each tensor must have a
            leading batch dimension matching the number of rows.

        Returns:
          A grouped op performing every per-slot append.
        """
        row_keys = tf.convert_to_tensor(value=rows, dtype=tf.int64)
        flat_values = tf.nest.flatten(values)
        updates = []
        for slot_name, tensor in zip(self._flattened_slots, flat_values):
            # Look up each slot's current lists once, push the new frame onto
            # every addressed list, then store the updated lists back.
            slot_var = self._slot2variable_map[slot_name]
            appended = list_ops.tensor_list_push_back_batch(
                slot_var.lookup(row_keys), tensor)
            updates.append(slot_var.insert_or_assign(row_keys, appended))
        return tf.group(*updates)