def SequenceAppendToken(x, x_paddings, token, extend=False): """Appends <token> to sequence `x`. Args: x: A sequence of tokens of shape [batch_size, x_len_max]. x_paddings: The paddings of `x`. token: The token to append (of type integer). extend: Whether to extend `x` along the length dimension, this must be true for any sequence length in `x` that is `x_len_max` or else an invalid sequence will be emitted. Returns: A tuple. - The new sequence, Tensor of shape [batch_size, x_len_max]. - The new paddings, Tensor of shape [batch_size, x_len_max]. """ batch_size = py_utils.GetShape(x)[0] x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32) if extend: x = tf.pad(x, [[0, 0], [0, 1]]) # Mask all invalid entries of `x` to 0. x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype) # Append the <token> based on `x_len`. x += tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1), tf.cast(tf.fill([batch_size], token), x.dtype), py_utils.GetShape(x)) x_paddings = 1 - tf.sequence_mask(x_len + 1, py_utils.GetShape(x)[1], x_paddings.dtype) return x, x_paddings
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0): """Concats sequence `x` with sequence `y`. This function is length aware (based off the paddings). Args: x: A sequence of tokens of shape [batch_size, x_len_max]. x_paddings: The paddings of `x`. y: A sequence of tokens of shape [batch_size, y_len_max]. y_paddings: The paddings of `y`. pad: The <pad> token to fill the concatenated sequence (of type integer). Returns: A tuple. - Concatenation of `x` and `y` of shape [batch_size, x_len_max + y_len_max]. - Paddings of the concatenation of shape [batch_size, x_len_max + y_len_max]. """ # Get the length (w/ eos). x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32) y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32) batch_size = py_utils.GetShape(x)[0] y_len_max = py_utils.GetShape(y)[1] # Pad `x` with necessary <pad>. x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1) # Replace all <pad> with 0. x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0)) # Compute the write indices of `y` in `xy`. indices = tf.stack([ tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]), (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) + tf.expand_dims(x_len, 1)), ], 2) xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x)) # We need to remap all <pad> to `pad`. xy = tf.where( tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0), tf.expand_dims(x_len + y_len, 1)), xy, tf.fill(py_utils.GetShape(xy), pad)) xy_paddings = 1 - tf.sequence_mask(x_len + y_len, py_utils.GetShape(xy)[1], x_paddings.dtype) return xy, xy_paddings
def SparseToDense(grid_shape, locations, feats): """Converts a sparse representation back to the dense grid. Args: grid_shape: (nx, ny, nz). The shape of the grid. locations: [b, p, 3]. Locations of the pillars. feats: [b, p, fdims]. Extracted features for pillars. Returns: grid_feats: [b, nx, ny, nz * fdims]. """ nx, ny, nz = grid_shape b, p, _ = py_utils.GetShape(locations, 3) feats = py_utils.HasShape(feats, [b, p, -1]) _, _, fdims = py_utils.GetShape(feats, 3) indices = tf.concat( [tf.tile(tf.range(b)[:, tf.newaxis, tf.newaxis], [1, p, 1]), locations], axis=2) grid = tf.scatter_nd(indices, feats, [b, nx, ny, nz, fdims]) return tf.reshape(grid, [b, nx, ny, nz * fdims])
def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None): """Computes attention according to a sparsity pattern. We use the following capital letters to denote shape parameters: B = batch size S = length of the source sequence T = length of the target sequence N = number of attention heads H = dimensions of each attention head K = number of clusters W = attention window (K <= S) The 'sparsity_indices' is a tensor of integral type where the last dimension contains W indices (W is the attention window) for each corresponding position along S in 'k' that the query is allowed to attend to. For example, if sparsity_indices[batch_idx, target time step, head_idx] = [1, 7, 8], it means that token in the query attends to values with indices 1, 7, and 8, and the attention window here is 3. The valid values in 'sparsity_indices' are [-1, S-1]. Note that the value -1 is reserved to mean paddings, distinct from the value (S-1). For example, if W=S and 'sparsity_indices' contains range(S) on the last dimension, this degenerates to the original full attention. We require that 'sparsity_indices' does not contain duplicates (except for -1 to indicate paddings), but we do not require 'sparsity_indices' to be sorted. Note that this implementation is flexible and generic but is not optimized for time or space complexity. Please consider grouping queries that attend to the same subset of values first for efficiency. Args: q: (projected) queries, [B, T, N, H]; k: (projected) keys, [B, S, N, H]; v: (projected) values, [B, S, N, H]; sparsity_indices: [B, T, N, W], where W is the attention window; paddings: paddings for keys, [B, S] if not None. Returns: output: the encoded output, [B, T, N, H]. atten_probs: the attention weights, [B, T, N, S]. """ q = tf.convert_to_tensor(q) k = tf.convert_to_tensor(k) v = tf.convert_to_tensor(v) sparsity_indices = tf.convert_to_tensor(sparsity_indices) k = py_utils.HasRank(k, 4) _, source_length, _, dim_per_head = py_utils.GetShape(k, 4) sparsity_indices = py_utils.HasRank(sparsity_indices, 4) batch_size, target_length, num_heads, attention_window = py_utils.GetShape( sparsity_indices, 4) py_utils.assert_less_equal( attention_window, source_length, 'The provided sparsity_indices has attention window ' ' > source length. This is likely an error.') # To prepare for gathering the relevant vectors from 'k', we prepare # gather_idx of shape [B, T, N, W, 3] where the last dimension corresponds to # slices in 'k' indexed by (batch index, source time step, head index), # where the source length index comes from the original W dimension in # 'sparsity_indices'. seq_idx = tf.expand_dims(sparsity_indices, axis=-1) # Overwrite the paddings -1 with valid gather indices (zeros). We will # fix the logits with -inf in these positions later. seq_idx = tf.where(seq_idx < 0, tf.zeros_like(seq_idx), seq_idx) batch_idx = tf.reshape( tf.range(0, batch_size, dtype=sparsity_indices.dtype), [batch_size, 1, 1, 1, 1]) batch_idx = tf.tile(batch_idx, [1, target_length, num_heads, attention_window, 1]) head_idx = tf.reshape( tf.range(0, num_heads, dtype=sparsity_indices.dtype), [1, 1, num_heads, 1, 1]) head_idx = tf.tile(head_idx, [batch_size, target_length, 1, attention_window, 1]) # [B, T, N, W, 3], where last dimension is (batch index, source length index, # head index). gather_idx = tf.concat([batch_idx, seq_idx, head_idx], axis=-1) # Both the gathered k and v have shape [B, T, N, W, H] k = tf.gather_nd(k, gather_idx) v = tf.gather_nd(v, gather_idx) if paddings is None: paddings = tf.zeros([batch_size, source_length]) paddings = tf.convert_to_tensor(paddings) paddings = tf.expand_dims(paddings, axis=-1) # [B, S, N] paddings = tf.tile(paddings, [1, 1, num_heads]) # [B, T, N, W] paddings = tf.gather_nd(paddings, gather_idx) logits = tf.einsum('BTNH, BTNWH -> BTNW', q, k) logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype)) very_negative_logits = ( tf.ones_like(logits) * logits.dtype.max * tf.constant(-0.7, dtype=logits.dtype)) padded_logits = tf.where( tf.math.logical_or(sparsity_indices < 0, paddings > 0.0), very_negative_logits, logits) # [B, T, N, W] atten_probs = tf.nn.softmax(padded_logits, name='attention_weights') atten_probs = tf.where(sparsity_indices < 0, tf.zeros_like(logits), atten_probs) output = tf.einsum('BTNW, BTNWH -> BTNH', atten_probs, v) # Scatter 'atten_probs' back into the original source length. # [B, T, N, W, 1] batch_idx = tf.tile( tf.range(batch_size)[:, None, None, None, None], [1, target_length, num_heads, attention_window, 1]) # [B, T, N, W, 1] target_seq_idx = tf.tile( tf.range(target_length)[None, :, None, None, None], [batch_size, 1, num_heads, attention_window, 1]) # [B, T, N, W, 1] head_idx = tf.tile( tf.range(num_heads)[None, None, :, None, None], [batch_size, target_length, 1, attention_window, 1]) # seq_idx: [B, T, N, W, 1] # [B, T, N, W, 4] scatter_idx = tf.concat([batch_idx, target_seq_idx, head_idx, seq_idx], -1) # [B, T, N, S] scattered_probs = tf.scatter_nd( scatter_idx, atten_probs, [batch_size, target_length, num_heads, source_length]) return output, scattered_probs
def FProp(self, theta, x, x_paddings=None, eos_id=1, force_sample_last_token=True): """Applies SymbolInsertionLayer. We take in a `x`, which represents the groundtruth sequence (i.e., English sequence). We return a sampled rollin (observed) canvas (i.e., random subset of the English sequence), as well as the target (indices) for an insertion-based model (i.e., the targets given the random observed subset). Args: theta: Ignored, this can be None. x: The symbol ids of shape `[batch_size, time_dim]`. x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where 0 is valid and 1 is invalid. eos_id: The <eos> token id to represent end-of-slot. force_sample_last_token: Set True to force sample the last token of `x`. Returns: A `NestedMap`. - canvas: The canvas (based off of the `rollin_policy`) of shape [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be equal. - canvas_indices: The canvas indices (into `x`). - canvas_paddings: The paddings of `canvas_indices`. - target_indices: The target indices of shape [num_targets, 3]. `num_targets` is the number of total targets in the entire batch. [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2] captures the token. Each row [batch, slot, vocab] represents the indices of the target -- i.e., the batch, slot and vocab combination of the target. Typical usage of these indices is to tf.gather_nd the log-probs (from the softmax layer). - target_weights: The target weights. Raises: ValueError: If invalid params. """ p = self.params batch_size = py_utils.GetShape(x)[0] time_dim = py_utils.GetShape(x)[1] if x_paddings is None: x_paddings = tf.zeros([batch_size, time_dim], tf.float32) oracle_policy = p.oracle_policy rollin_policy = (oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy) if rollin_policy != 'uniform': raise ValueError('Unknown or unsupported rollin policy: %s' % rollin_policy) if oracle_policy != 'uniform': raise ValueError('Unknown or unsupported oracle policy: %s' % oracle_policy) x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32) # Compute the desired length per example in the batch. ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed) if force_sample_last_token: c_len = tf.minimum( tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1 else: c_len = tf.minimum( tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len) # Compute the maximum length across the batch. c_len_max = tf.reduce_max(c_len) # Grab subset of random valid indices per example. z_logits = tf.cast( tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1), tf.float32) * -1e9 if force_sample_last_token: # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can # accomplish this by add +LARGE_NUMBER to the logits. z_logits += tf.cast( tf.equal(tf.expand_dims(tf.range(time_dim), 0), tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9 # Gumbel-max trick to sample (we only sample valid positions per sample in # the batch). z = -tf.math.log(-tf.math.log( tf.random.uniform([batch_size, time_dim], seed=p.random_seed))) unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim) # Trim everything > c_len_max. c_indices = c_indices[:, :c_len_max] # Invalidate any indices >= c_len, we use the last index as the default # invalid index. c_indices = tf.where( tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1), c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1)) # Materialize the canvas. c_indices = tf.sort(c_indices) c = tf.gather_nd( x, tf.stack([ tf.reshape( tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]), [-1]), tf.reshape(c_indices, [-1]) ], 1)) c = tf.reshape(c, [batch_size, c_len_max]) # Compute the paddings. c_paddings = 1 - tf.sequence_mask( c_len, c_len_max, dtype=x_paddings.dtype) c *= tf.cast(1 - c_paddings, tf.int32) indices = tf.concat([ tf.reshape( tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]), [batch_size * c_len_max, 1]), tf.reshape(c_indices, [batch_size * c_len_max, 1]) ], 1) x_token_is_observed = tf.scatter_nd( indices, tf.ones([batch_size * c_len_max], tf.int32), py_utils.GetShape(x)) # `x_segments` captures which slot each `x` belongs to (both observed and # tokens that need to be observed). x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True) x_token_is_observed = tf.cast(x_token_is_observed, tf.bool) prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True) x_token_is_observed = tf.reshape(x_token_is_observed, [-1]) prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1]) x_is_valid = tf.cast(1 - x_paddings, tf.bool) x_is_valid = tf.reshape(x_is_valid, [-1]) # Remap all the observed to <eos>, note some of these need a zero weight # (or else there would be <eos> and valid token in the same slot). target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32) target_indices = tf.where( x_token_is_observed, tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices) # TODO(williamchan): We give uniform 1.0 weight, however, math suggests # we may want to weigh this term by the original sequence length. target_weights = tf.ones_like(target_indices, tf.float32) # We need to set all the weights for <eos> which actually have valid tokens # in the slot to zero. target_weights = tf.where( x_token_is_observed & ~prev_x_token_is_observed, tf.zeros_like(target_weights), target_weights) # TODO(williamchan): Consider dropping the entries w/ weight zero. # Add the batch and slot indices. target_indices = tf.concat([ tf.reshape( tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]), [batch_size * time_dim, 1]), tf.reshape(x_segments, [-1, 1]), target_indices ], 1) # Select only the valid indices. The selected valid ones include slots w/ # <eos>. target_indices = target_indices[x_is_valid] target_weights = target_weights[x_is_valid] return py_utils.NestedMap(canvas=c, canvas_indices=c_indices, canvas_paddings=c_paddings, target_indices=target_indices, target_weights=target_weights)
def _replicate_rows(tensor, multiple): tensor_shape = tensor.shape.as_list() expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]] indices = tf.constant(_generate_indices(tensor_shape[0], multiple)) return tf.scatter_nd(indices, _tile_rows(tensor, multiple), expanded_shape)