Example #1
0
    def append(self, row, values):
        """Builds ops that append time-major values to a single table row.

    Args:
      row: A scalar location at which to append values.
      values: A nest of Tensors to append.  Each tensor's outermost
        dimension is treated as a time axis, and these must all be equal.

    Returns:
      A grouped op performing all per-slot appends at the given row.
    """
        row = tf.convert_to_tensor(value=row, dtype=tf.int64)
        append_ops = []
        # Walk the flattened spec/slot/value triples in lockstep; each slot
        # holds one TensorList variable keyed by row.
        triples = zip(self._flattened_specs, self._flattened_slots,
                      tf.nest.flatten(values))
        for spec, slot, value in triples:
            slot_var = self._slot2variable_map[slot]
            current_list = slot_var.lookup(row)
            element_shape = tf.cast(spec.shape.as_list(), dtype=tf.int64)
            incoming_list = list_ops.tensor_list_from_tensor(
                value, element_shape=element_shape)
            combined = list_ops.tensor_list_concat_lists(
                current_list, incoming_list, element_dtype=spec.dtype)
            append_ops.append(slot_var.insert_or_assign(row, combined))
        return tf.group(*append_ops)
Example #2
0
    def extend(self, rows, episode_lists):
        """Builds ops extending a batch of rows by the given TensorLists.

    Args:
      rows: A batch of row locations to extend.
      episode_lists: Nested batch of TensorLists, must have the same batch
        dimension as rows.

    Returns:
      A grouped op performing all per-slot extensions.
    """
        tf.nest.assert_same_structure(self.slots, episode_lists)

        rows = tf.convert_to_tensor(value=rows, dtype=tf.int64)
        # Flatten both the stored lists for these rows and the incoming
        # lists so they line up with the flattened specs/slots.
        current_lists = tf.nest.flatten(self.get_episode_lists(rows))
        incoming_lists = tf.nest.flatten(episode_lists)

        write_ops = [
            self._slot2variable_map[slot].insert_or_assign(
                rows,
                list_ops.tensor_list_concat_lists(
                    current, incoming, element_dtype=spec.dtype))
            for spec, slot, current, incoming in zip(
                self._flattened_specs, self._flattened_slots,
                current_lists, incoming_lists)
        ]
        return tf.group(*write_ops)
Example #3
0
  def testConcat(self):
    """Exercises tensor_list_concat_lists on batched TensorLists.

    Checks element-wise concatenation of two stacked TensorList batches in
    all four orderings, then verifies the error paths for mismatched list
    batch, mismatched element shapes, and mismatched element dtypes.
    """
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
    l_batch_0 = array_ops.stack([l0, l1])
    l_batch_1 = array_ops.stack([l1, l0])

    l_concat_01 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_1, element_dtype=dtypes.float32)
    l_concat_10 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_0, element_dtype=dtypes.float32)
    l_concat_00 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_0, element_dtype=dtypes.float32)
    l_concat_11 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_1, element_dtype=dtypes.float32)

    expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
    expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
    expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
    expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

    for i, (concat, expected) in enumerate(zip(
        [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
        [expected_00, expected_01, expected_10, expected_11])):
      # Unstack the batched result back into per-example TensorLists so each
      # can be stacked into a dense tensor and compared.
      splitted = array_ops.unstack(concat)
      splitted_stacked_ret = self.evaluate(
          (list_ops.tensor_list_stack(splitted[0], dtypes.float32),
           list_ops.tensor_list_stack(splitted[1], dtypes.float32)))
      print("Test concat %d: %s, %s, %s, %s"
            % (i, expected[0], splitted_stacked_ret[0],
               expected[1], splitted_stacked_ret[1]))
      self.assertAllClose(expected[0], splitted_stacked_ret[0])
      self.assertAllClose(expected[1], splitted_stacked_ret[1])

    # Concatenating mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(
          list_ops.tensor_list_concat_lists(
              l_batch_0,
              list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
              element_dtype=dtypes.float32))

    # NOTE: assertRaisesRegexp was deprecated (removed in Python 3.12);
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "element shapes are not identical at index 0"):
      l_batch_of_vec_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                            element_dtype=dtypes.float32))

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r"input_b\[0\].dtype != element_dtype."):
      l_batch_of_int_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([1], element_shape=scalar_shape())]
          * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                            element_dtype=dtypes.float32))