def append(self, row, values):
  """Returns ops for appending multiple time values at the given row.

  Each leaf of `values` is turned into a TensorList (via
  `tensor_list_from_tensor`) and concatenated onto the end of the TensorList
  currently stored for `row` in the matching slot's map. The result is then
  written back with `insert_or_assign`.

  Args:
    row: A scalar location at which to append values.
    values: A nest of Tensors to append. The outermost dimension of each
      tensor is treated as a time axis, and these must all be equal.

  Returns:
    Ops for appending values at the given row.

  Raises:
    ValueError: If `values` does not flatten to exactly one leaf per slot of
      the table.
  """
  row = tf.convert_to_tensor(value=row, dtype=tf.int64)
  flattened_values = tf.nest.flatten(values)
  # zip() below stops at the shortest input, so a mismatched nest would
  # otherwise silently skip slots (or drop extra values) instead of failing.
  if len(flattened_values) != len(self._flattened_slots):
    raise ValueError(
        'Expected values to flatten to %d leaves (one per slot), got %d.' %
        (len(self._flattened_slots), len(flattened_values)))
  append_ops = []
  for spec, slot, value in zip(self._flattened_specs, self._flattened_slots,
                               flattened_values):
    slot_table = self._slot2variable_map[slot]
    existing_list = slot_table.lookup(row)
    # NOTE(review): assumes spec.shape is fully defined; a None entry from
    # as_list() could not be cast to int64 - TODO confirm for all specs.
    value_as_tl = list_ops.tensor_list_from_tensor(
        value, element_shape=tf.cast(spec.shape.as_list(), dtype=tf.int64))
    new_value = list_ops.tensor_list_concat_lists(
        existing_list, value_as_tl, element_dtype=spec.dtype)
    append_ops.append(slot_table.insert_or_assign(row, new_value))
  return tf.group(*append_ops)
def extend(self, rows, episode_lists):
  """Builds ops that concatenate the given TensorLists onto a batch of rows.

  The current lists for `rows` are fetched first (via `get_episode_lists`).
  Then, slot by slot, the matching entry of `episode_lists` is concatenated
  after the fetched list, and the result is written back with
  `insert_or_assign`.

  Args:
    rows: A batch of row locations to extend.
    episode_lists: Nested batch of TensorLists, must have the same batch
      dimension as rows.

  Returns:
    Ops for extending the table.
  """
  tf.nest.assert_same_structure(self.slots, episode_lists)
  row_ids = tf.convert_to_tensor(value=rows, dtype=tf.int64)
  # Every current list is read before any write op is built.
  current = tf.nest.flatten(self.get_episode_lists(row_ids))
  incoming = tf.nest.flatten(episode_lists)
  per_slot = zip(self._flattened_specs, self._flattened_slots, current,
                 incoming)
  return tf.group(*[
      self._slot2variable_map[slot].insert_or_assign(
          row_ids,
          list_ops.tensor_list_concat_lists(
              old_list, new_list, element_dtype=spec.dtype))
      for spec, slot, old_list, new_list in per_slot
  ])
def testConcat(self):
  """Checks tensor_list_concat_lists on batches of TensorLists, incl. errors."""
  c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
  l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
  l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
  l_batch_0 = array_ops.stack([l0, l1])
  l_batch_1 = array_ops.stack([l1, l0])

  l_concat_01 = list_ops.tensor_list_concat_lists(
      l_batch_0, l_batch_1, element_dtype=dtypes.float32)
  l_concat_10 = list_ops.tensor_list_concat_lists(
      l_batch_1, l_batch_0, element_dtype=dtypes.float32)
  l_concat_00 = list_ops.tensor_list_concat_lists(
      l_batch_0, l_batch_0, element_dtype=dtypes.float32)
  l_concat_11 = list_ops.tensor_list_concat_lists(
      l_batch_1, l_batch_1, element_dtype=dtypes.float32)

  # Concatenation is element-wise over the batch: entry k of the result is
  # entry k of the first batch followed by entry k of the second.
  expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
  expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
  expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
  expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

  for i, (concat, expected) in enumerate(zip(
      [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
      [expected_00, expected_01, expected_10, expected_11])):
    splitted = array_ops.unstack(concat)
    splitted_stacked_ret = self.evaluate(
        (list_ops.tensor_list_stack(splitted[0], dtypes.float32),
         list_ops.tensor_list_stack(splitted[1], dtypes.float32)))
    # Identify the failing case via the assertion message rather than
    # printing to stdout on every iteration.
    self.assertAllClose(expected[0], splitted_stacked_ret[0],
                        msg="concat case %d, batch entry 0" % i)
    self.assertAllClose(expected[1], splitted_stacked_ret[1],
                        msg="concat case %d, batch entry 1" % i)

  # A batch of lists cannot be concatenated with a single (unbatched) list.
  with self.assertRaises((errors.InvalidArgumentError, ValueError)):
    self.evaluate(
        list_ops.tensor_list_concat_lists(
            l_batch_0,
            list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
            element_dtype=dtypes.float32))

  # Element shapes must match between the two batches.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "element shapes are not identical at index 0"):
    l_batch_of_vec_tls = array_ops.stack(
        [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
    self.evaluate(
        list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                          element_dtype=dtypes.float32))

  # Element dtypes of the second batch must match element_dtype.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              r"input_b\[0\].dtype != element_dtype."):
    l_batch_of_int_tls = array_ops.stack(
        [list_ops.tensor_list_from_tensor([1], element_shape=scalar_shape())]
        * 2)
    self.evaluate(
        list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                          element_dtype=dtypes.float32))