def FProp(self, theta, inputs, paddings):
  """Applies causal pooling to inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor. It is expected to be of shape [batch,
      time].

  Returns:
    outputs, out_paddings pair.
      - outputs: has the same shape as inputs.
      - out_paddings: has the same shape as paddings.
  """
  p = self.params
  if p.left_context is None:
    raise ValueError('left_context must be set.')
  window_size = p.left_context
  left_pad_size = window_size - 1
  large_negative = p.dtype.max * tf.constant(-0.7, dtype=p.dtype)
  # For max pooling, use a large negative padding value such that the max
  # element is almost always from a non-padding position.
  pad_value = 0 if p.pooling_type == 'AVG' else large_negative
  inputs = tf.pad(
      inputs, [[0, 0], [left_pad_size, 0], [0, 0], [0, 0]],
      constant_values=pad_value)

  out_feature = tf.nn.pool(
      inputs,
      window_shape=(window_size, 1),
      pooling_type=p.pooling_type,
      padding='VALID')

  if p.pooling_type == 'AVG':
    # Count the fraction of non-padding elements inside each pooling window.
    in_mask = tf.pad(1.0 - paddings, [[0, 0], [left_pad_size, 0]])
    non_padding_ratio = tf.nn.pool(
        in_mask[:, :, tf.newaxis],
        window_shape=(window_size,),
        pooling_type='AVG',
        padding='VALID')
    # Divide by non-padding ratios to eliminate the effect of padded zeros.
    out_feature *= tf.math.reciprocal_no_nan(
        non_padding_ratio[..., tf.newaxis])
  out_feature *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
  return out_feature, paddings
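# Minimal usage sketch (assumes eager TF2; the `CausalPoolingLayer` name and
# param setup are assumed from the surrounding library conventions):
#   p = CausalPoolingLayer.Params().Set(
#       name='causal_pool', pooling_type='AVG', left_context=3)
#   layer = p.Instantiate()
#   inputs = tf.random.normal([2, 5, 4, 1])  # [batch, time, frequency, chan]
#   paddings = tf.constant([[0., 0., 0., 1., 1.],
#                           [0., 0., 0., 0., 0.]])
#   out, out_paddings = layer.FProp(layer.theta, inputs, paddings)
#   # out has shape [2, 5, 4, 1]; output position t pools over input frames
#   # max(0, t - left_context + 1) .. t only, never over future frames.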
def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension; this must be
      true for any sequence length in `x` that is `x_len_max` or else an
      invalid sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max].
      - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
  batch_size = py_utils.GetShape(x)[0]
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  # Mask all invalid entries of `x` to 0.
  x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
  # Append the <token> based on `x_len`.
  x += tf.scatter_nd(
      tf.stack([tf.range(batch_size), x_len], axis=1),
      tf.cast(tf.fill([batch_size], token), x.dtype), py_utils.GetShape(x))
  x_paddings = 1 - tf.sequence_mask(
      x_len + 1, py_utils.GetShape(x)[1], x_paddings.dtype)
  return x, x_paddings
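# Minimal usage sketch (assumes eager TF2; token id 2 is illustrative):
#   x = tf.constant([[3, 4, 0, 0],
#                    [5, 6, 7, 0]], tf.int32)
#   x_paddings = tf.constant([[0., 0., 1., 1.],
#                             [0., 0., 0., 1.]])
#   y, y_paddings = SequenceAppendToken(x, x_paddings, token=2)
#   # y          => [[3, 4, 2, 0], [5, 6, 7, 2]]
#   # y_paddings => [[0., 0., 0., 1.], [0., 0., 0., 0.]]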
def ComputeConvOutputPadding(paddings, window, stride,
                             padding_algorithm='SAME'):
  """Computes paddings for convolution and pooling output.

  out_padding[i] == 1 iff any in_padding corresponding to that output is 1.

  Args:
    paddings: The paddings tensor. It is expected to be of shape [batch,
      time].
    window: The size of the windows.
    stride: The time-stride between adjacent windows.
    padding_algorithm: 'SAME' or 'VALID'.

  Returns:
    out_padding, the new padding tensor of size [batch, ceil(time / stride)].
  """
  if stride == 1:
    return paddings

  # Pad so input_length divides stride.
  input_length = py_utils.GetShape(paddings)[1]
  pad_len = (input_length + stride - 1) // stride * stride - input_length
  paddings = tf.pad(paddings, [[0, 0], [0, pad_len]], constant_values=1.0)
  out_padding = tf.nn.pool(
      tf.expand_dims(paddings, -1),
      [window],
      'MAX',
      padding=padding_algorithm,
      strides=[stride],
  )
  return tf.squeeze(out_padding, -1)
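# Minimal usage sketch (assumes eager TF2):
#   paddings = tf.constant([[0., 0., 0., 1.]])
#   ComputeConvOutputPadding(paddings, window=2, stride=2)
#   # => [[0., 1.]]: the second window covers frames 2 and 3, and frame 3 is
#   # padding, so (via MAX pooling) its output position is marked padded.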
def _RescaleBoundary(self, out, in_paddings):
  # Rescale every output position by:
  #   (# input positions) / (# non-padding input positions)
  # where (# input positions) = filter_size.
  p = self.params
  in_mask = 1.0 - in_paddings

  # Compute the left and right implicit padding size used in 'SAME' mode.
  filter_t = p.filter_shape[0]
  effective_filter_size = (filter_t - 1) * p.dilation_rate[0] + 1
  left_pad_size = (effective_filter_size - 1) // 2
  right_pad_size = effective_filter_size // 2

  # Compute the rescaling factor.
  # This expanded mask has 1 on all valid positions and 0 on all padded ones,
  # which include both the explicit padding provided by `in_paddings` and the
  # implicit padding on the boundaries.
  in_mask_padded = tf.pad(in_mask, [[0, 0], [left_pad_size, right_pad_size]])
  # (# non-padding input positions) / (# input positions)
  factor_inverse = tf.nn.pool(
      in_mask_padded[:, :, tf.newaxis],
      window_shape=(filter_t,),
      pooling_type='AVG',
      strides=(p.filter_stride[0],),
      padding='VALID',
      dilations=(p.dilation_rate[0],))
  factor = tf.math.reciprocal_no_nan(factor_inverse)
  return out * factor[..., tf.newaxis]
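# Minimal numeric sketch of the rescaling (a hypothetical standalone
# equivalent with filter_t=3, stride=1, dilation=1):
#   in_mask = tf.constant([[1., 1., 1., 0.]])   # last frame is padding
#   in_mask_padded = tf.pad(in_mask, [[0, 0], [1, 1]])  # 'SAME' boundaries
#   factor_inverse = tf.nn.pool(in_mask_padded[:, :, tf.newaxis],
#                               window_shape=(3,), pooling_type='AVG',
#                               padding='VALID')
#   # factor_inverse => [[[2/3], [1.], [2/3], [1/3]]]; multiplying the conv
#   # output by the reciprocal compensates for zeros contributed by implicit
#   # boundary padding and explicit input padding.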
def _EvaluateConvKernel(self, theta, inputs):
  """Apply convolution to inputs."""
  # Same as CausalDepthwiseConv2DLayer.
  p = self.params
  assert p.filter_shape[1] == 1, 'Only 1D causal convolutions supported.'
  padding_algorithm = 'VALID'
  causal_pad_size = (p.filter_shape[0] - 1) * p.dilation_rate[0]
  inputs = tf.pad(inputs, [[0, 0], [causal_pad_size, 0], [0, 0], [0, 0]])
  filter_w = self._GetWeight(theta)
  return tf.nn.depthwise_conv2d(
      inputs,
      filter_w,
      strides=[1, p.filter_stride[0], p.filter_stride[1], 1],
      dilations=p.dilation_rate,
      data_format='NHWC',
      padding=padding_algorithm)
def RelShift(x):
  """Performs relative shift on 4D tensor (first 2 axes are batching dims).

  Given input of shape [?, ?, W, W], this does "relative shifting" for the
  last two dims, s.t. output[b, n, i, j] = input[b, n, i, j-i] for i <= j.
  Entries below the diagonal hold shifted leftovers from neighboring rows
  (only j == i-1 is guaranteed to be 0) and are typically masked out by the
  caller.

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as
    described above).
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, w, w])
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  x = tf.reshape(x, [b, n, w + 1, w])
  x = x[:, :, :w, :]
  return x
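# Minimal usage sketch (assumes eager TF2), with W=3 and batching dims of 1:
#   x = tf.reshape(tf.range(9, dtype=tf.float32), [1, 1, 3, 3])
#   # x[0, 0] = [[0., 1., 2.],
#   #            [3., 4., 5.],
#   #            [6., 7., 8.]]
#   y = RelShift(x)
#   # y[0, 0] = [[0., 1., 2.],
#   #            [0., 3., 4.],
#   #            [5., 0., 6.]]
#   # Row i is shifted right by i on/above the diagonal; the below-diagonal
#   # leftover (the 5.) is expected to be masked by the caller.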
def _EvaluateConvKernel(self, theta, inputs):
  """Apply convolution to inputs."""
  p = self.params
  assert p.filter_shape[1] == 1, 'Only 1D causal convolutions supported.'
  # Use VALID padding and shift the inputs to the right to ensure that the
  # first output only depends on the first input and so on. The output is
  # the same size as the input, as if the convolution used SAME padding.
  padding_algorithm = 'VALID'
  # The effective spatial filter width for dilated convolutions is
  # (kernel_width - 1) * dilation_rate + 1, as per
  # https://www.tensorflow.org/api_docs/python/tf/nn/convolution.
  causal_pad_size = (p.filter_shape[0] - 1) * p.dilation_rate[0]
  inputs = tf.pad(inputs, [[0, 0], [causal_pad_size, 0], [0, 0], [0, 0]])
  filter_w = self._GetWeight(theta)
  return tf.nn.depthwise_conv2d(
      inputs,
      filter_w,
      strides=[1, p.filter_stride[0], p.filter_stride[1], 1],
      dilations=p.dilation_rate,
      data_format='NHWC',
      padding=padding_algorithm)
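# Minimal sketch of the causal-padding arithmetic (hypothetical standalone
# equivalent, kernel width 3, dilation 1):
#   inputs = tf.random.normal([1, 10, 1, 4])   # [batch, time, 1, channels]
#   padded = tf.pad(inputs, [[0, 0], [2, 0], [0, 0], [0, 0]])
#   # A 'VALID' depthwise conv of height 3 over `padded` yields 10 output
#   # frames, and output[t] sees only inputs[t-2 .. t] -- no future frames.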
def PadToTargetSeqLen(tensor, constant):
  # `p` (the params holding `beam_search.target_seq_len`) comes from the
  # enclosing scope.
  length = tf.shape(tensor)[1]
  pad = tf.maximum(0, p.beam_search.target_seq_len - length)
  return tf.pad(tensor, [[0, 0], [0, pad]], constant_values=constant)
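# Minimal usage sketch (assumes p.beam_search.target_seq_len == 5 in the
# enclosing scope):
#   ids = tf.constant([[7, 8, 9]])
#   PadToTargetSeqLen(ids, constant=0)   # => [[7, 8, 9, 0, 0]]
#   # Sequences already at or beyond target_seq_len are returned unchanged,
#   # since `pad` is clamped at 0.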
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., a random
  subset of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd the
        log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab a subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                 tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len; we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                      [1, c_len_max]), [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)

  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed tokens to <eos>; note some of these need a zero
  # weight (or else there would be an <eos> and a valid token in the same
  # slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests we
  # may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
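# Minimal usage sketch (hypothetical setup; the Params fields shown are
# assumed from the policy checks above):
#   p = SymbolInsertionLayer.Params().Set(
#       name='insertion', rollin_policy='oracle', oracle_policy='uniform')
#   layer = p.Instantiate()
#   x = tf.constant([[11, 12, 13, 14, 2]], tf.int32)   # 2 == <eos>
#   out = layer.FProp(None, x)
#   # out.canvas         : a random subset of `x` (the last token is always
#   #                      kept since force_sample_last_token=True).
#   # out.target_indices : [num_targets, 3] rows of (batch, slot, token),
#   #                      e.g. for tf.gather_nd over softmax log-probs.
#   # out.target_weights : 1.0, except for <eos> targets in slots that still
#   #                      contain unobserved tokens, which get weight 0.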