def _True(anchor, bboxes):
  """True branch when num of bboxes is non-zero.

  Sorts the boxes by the signed angle between `anchor` and each box centroid,
  measured around the origin.

  Args:
    anchor: A 2-element point (it is padded to a 3-vector for the cross
      product below).
    bboxes: Boxes accepted by `BBoxesCentroid`; dim 0 is the box count n.

  Returns:
    An [n] tensor of box indices ordered by descending score.
  """
  num_boxes = tf.shape(bboxes)[0]
  centers = BBoxesCentroid(bboxes)

  # Per-centroid dot product with the anchor: [n, 2] x [2, 1] -> [n].
  anchor_column = tf.expand_dims(anchor, 1)
  dot = tf.squeeze(tf.matmul(centers, anchor_column), axis=1)

  # Cosine of the angle anchor--O--centroid; 0 wherever either norm is 0.
  norm = tf.norm(anchor) * tf.norm(centers, axis=1)
  cosine = tf.where(
      tf.greater(norm, 0), dot / norm, tf.zeros([num_boxes], norm.dtype))

  # tf.linalg.cross needs 3-vectors and does not broadcast: lift both operands
  # to z=0 and tile the anchor to [n, 3]. The z component of the result tells
  # which side of O-->anchor each centroid lies on.
  anchor_xyz = tf.tile(
      tf.pad(tf.expand_dims(anchor, 0), [[0, 0], [0, 1]]), [num_boxes, 1])
  centers_xyz = tf.pad(centers, [[0, 0], [0, 1]])
  positive_side = tf.greater(tf.linalg.cross(anchor_xyz, centers_xyz), 0)[:, 2]

  # Positive z maps the cosine into [-2, 0], otherwise into [0, 2]. Per the
  # original authors, the car dataset is scanned counter-clockwise and this
  # descending order follows the same spin.
  score = tf.where(positive_side, -1 - cosine, 1 + cosine)
  _, order = tf.nn.top_k(score, num_boxes, sorted=True)
  return order
def FProp(self, theta, inputs, paddings):
  """Applies causal pooling to inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor. It is expected to be of shape [batch,
      time].

  Returns:
    outputs, out_paddings pair.
     - outputs: has the same shape as inputs.
     - out_paddings: has the same shape as paddings (it is `paddings` itself).

  Raises:
    NotImplementedError: cumulative (left_context == -1) pooling with a
      pooling_type other than 'AVG'.
  """
  p = self.params
  is_avg = p.pooling_type == 'AVG'

  if p.left_context == -1:
    # Unbounded left context: running mean over all frames so far.
    if not is_avg:
      raise NotImplementedError('Cumulative max pooling not implemented.')
    num_frames = py_utils.GetShape(inputs)[1]
    counts = 1.0 + tf.range(num_frames, dtype=p.dtype)
    running_mean = tf.math.cumsum(inputs, axis=1) / counts[
        tf.newaxis, :, tf.newaxis, tf.newaxis]
    running_mean *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
    return running_mean, paddings

  window = p.left_context
  history = window - 1
  # Left-pad the time axis so every window ends at the current frame. Max
  # pooling pads with a very negative value so the padded frames almost never
  # win the max.
  if is_avg:
    fill = 0
  else:
    fill = p.dtype.max * tf.constant(-0.7, dtype=p.dtype)
  padded = tf.pad(
      inputs, [[0, 0], [history, 0], [0, 0], [0, 0]], constant_values=fill)
  pooled = tf.nn.pool(
      padded,
      window_shape=(window, 1),
      pooling_type=p.pooling_type,
      padding='VALID')

  if is_avg:
    # Fraction of real (non-padding) frames in each window; dividing by it
    # removes the dilution from the zeros added above.
    valid = tf.pad(1.0 - paddings, [[0, 0], [history, 0]])
    valid_fraction = tf.nn.pool(
        valid[:, :, tf.newaxis],
        window_shape=(window,),
        pooling_type='AVG',
        padding='VALID')
    pooled *= tf.math.reciprocal_no_nan(valid_fraction[..., tf.newaxis])

  pooled *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
  return pooled, paddings
def RelPositionBias(self, content, abs_pos_emb, skip_term_b=False):
  """Compute relative position bias.

  This is a subroutine used by variants of self-attentions with relative
  positional embedding.

  output[b][n][i][j] = content[b][i][n] x abs_pos_emb[i-j+T-1][n]

  Padding should be masked by the caller of this function.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    tensors of the following shapes:
    content:         [N, H] if skip_term_b else [B, T, N, H]
    abs_pos_emb:     [2T - 1, N, H], the absolute positional embedding.
      abs_pos_emb[i] is the emb of relative distance i - (T-1).
    skip_term_b:     If to skip term_b in section 3.3 equation.

  Returns:
    The attention logits tensor. [N, T, T] if skip_term_b else [B, N, T, T].
  """
  if skip_term_b:
    n, h = py_utils.GetShape(content)
    l = py_utils.GetShape(abs_pos_emb)[0]
    t = (l + 1) // 2
    content, abs_pos_emb = self.ToAqtActActInputs(content, abs_pos_emb)
    # [N, L=2T-1]
    logits = self.FromAqtActActMatmul(
        tf.einsum('NH,LNH->NL', content, abs_pos_emb))
    # Every query row uses the same logits: replicate to [N, T, L].
    logits = tf.tile(tf.expand_dims(logits, axis=1), [1, t, 1], name='tile')
    flat = tf.reshape(logits, [n, t * l])
    lead_dims = [n]
  else:
    b, t, n, h = py_utils.GetShape(content)
    l = 2 * t - 1
    abs_pos_emb = py_utils.HasShape(abs_pos_emb, [l, n, h])
    content, abs_pos_emb = self.ToAqtActActInputs(content, abs_pos_emb)
    # [B, N, T, L=2T-1]
    logits = self.FromAqtActActMatmul(
        tf.einsum('BTNH,LNH->BNTL', content, abs_pos_emb))
    flat = tf.reshape(logits, [b, n, t * l], name='flatten')
    lead_dims = [b, n]

  # Skew trick: append T zeros to the flattened [T * L] axis and view it as
  # [T, L + 1]. Row i is then shifted left by i relative to the original rows,
  # so reading columns T-1, T-2, ..., 0 yields abs_pos_emb index i - j + T - 1.
  flat = tf.pad(flat, [[0, 0]] * len(lead_dims) + [[0, t]])
  skewed = tf.reshape(flat, lead_dims + [t, l + 1], name='restore')
  return skewed[..., t - 1::-1]
def _PadInput(self, inputs, paddings, num_frames):
  """Appends `num_frames` padded frames to the end of the time axis (dim 1).

  Args:
    inputs: Rank-3 tensor when `self.input_rank == 3`, otherwise a rank-4
      tensor; dim 1 is time.
    paddings: [batch, time] tensor.
    num_frames: Number of frames to append.

  Returns:
    (inputs, paddings); the new input frames are zeros and the new padding
    entries are 1.
  """
  trailing_dims = 1 if self.input_rank == 3 else 2
  inputs = tf.pad(inputs, [[0, 0], [0, num_frames]] + [[0, 0]] * trailing_dims)
  paddings = tf.pad(paddings, [[0, 0], [0, num_frames]], constant_values=1)
  return inputs, paddings
def Mask(self, seq_ids, weights, actual_seq_len):
  """Builds MASS-style source/target tensors from `seq_ids`.

  When `p.mask_ratio` is 0 nothing is masked: the source is the sequence
  itself and the target inputs are the sequence shifted right by one position
  (with 1 in the first slot). Otherwise `ops.mass` produces all four tensors.

  Returns:
    A NestedMap with `src.ids` and `tgt.{ids, labels, weights}`.
  """
  p = self.params
  if p.mask_ratio != 0.:
    src_ids, tgt_ids, tgt_labels, tgt_weights = ops.mass(
        seq_ids,
        weights,
        actual_seq_len,
        mask_id=p.mask_id,
        mask_ratio=p.mask_ratio,
        mask_minlen=p.mask_minlen,
        span_len=p.span_len,
        random_start_prob=p.random_start_prob,
        keep_prob=p.keep_prob,
        rand_prob=p.rand_prob,
        mask_prob=p.mask_prob,
        mask_target=p.mask_target,
        vocab_size=p.vocab_size,
        first_unreserved_id=p.first_unreserved_id)
  else:
    tf.logging.info('ATTENTION! mask ratio is set to 0, no mask is applied')
    src_ids = seq_ids
    tgt_ids = tf.pad(seq_ids, [[0, 0], [1, 0]], constant_values=1)[:, :-1]
    tgt_labels = seq_ids
    tgt_weights = weights

  return py_utils.NestedMap(
      src=py_utils.NestedMap(ids=src_ids),
      tgt=py_utils.NestedMap(
          ids=tgt_ids, labels=tgt_labels, weights=tgt_weights))
def ComputeConvOutputPadding(paddings,
                             window,
                             stride,
                             padding_algorithm='SAME'):
  """Computes paddings for convolution and pooling output.

  out_padding[i] == 1 iff any in_padding corresponding to that output is 1.
  Exception: with stride 1 and 'SAME', the output is length-preserving and
  `paddings` is returned unchanged.

  Args:
    paddings: The paddings tensor. It is expected to be of shape [batch, time].
    window: The size of the windows.
    stride: The time-stride between adjacent windows.
    padding_algorithm: 'SAME' or 'VALID'.

  Returns:
    out_padding, the new padding tensor. For 'SAME' its shape is
    [batch, ceil(time / stride)]. For 'VALID' its length is that of VALID
    max-pooling (window `window`, stride `stride`) over `paddings`
    right-padded with 1s to a multiple of `stride`.
  """
  # Only 'SAME' preserves the time length at stride 1. 'VALID' shrinks it to
  # time - window + 1, so it must go through the pooling below.
  if stride == 1 and padding_algorithm == 'SAME':
    return paddings

  # Pad so input_length divides stride; the extra frames count as padding.
  input_length = py_utils.GetShape(paddings)[1]
  pad_len = (input_length + stride - 1) // stride * stride - input_length
  paddings = tf.pad(paddings, [[0, 0], [0, pad_len]], constant_values=1.0)
  out_padding = tf.nn.pool(
      tf.expand_dims(paddings, -1),
      [window],
      'MAX',
      padding_algorithm,
      strides=[stride],
  )
  return tf.squeeze(out_padding, -1)
def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Appends <token> right after the last valid element of each row of `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x` (1 marks padding).
    token: The token to append (of type integer).
    extend: Whether to extend `x` by one column along the length dimension.
      This must be true whenever some row of `x` already has length
      `x_len_max`, otherwise an invalid sequence will be emitted.

  Returns:
    A tuple (new_x, new_paddings). Both have shape [batch_size, x_len_max],
    or [batch_size, x_len_max + 1] when `extend` is True. Entries of new_x
    past the appended token are 0.
  """
  batch_size = py_utils.GetShape(x)[0]
  lengths = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  width = py_utils.GetShape(x)[1]

  # Zero out everything past each row's valid length.
  x *= tf.sequence_mask(lengths, width, x.dtype)
  # Drop `token` at column `lengths[b]` of every row b.
  positions = tf.stack([tf.range(batch_size), lengths], axis=1)
  token_column = tf.cast(tf.fill([batch_size], token), x.dtype)
  x += tf.scatter_nd(positions, token_column, py_utils.GetShape(x))

  new_paddings = 1 - tf.sequence_mask(lengths + 1, width, x_paddings.dtype)
  return x, new_paddings
def _CombinerEmbLookup(self, sparse_ids: tf.SparseTensor,
                       partition_strategy: str) -> tf.Tensor:
  """Combiner embedding lookup.

  Args:
    sparse_ids: A single [batch, ...] SparseTensor of ids (not a dict);
      `dense_shape[0]` is taken as the batch size.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    A Tensor of shape [batch, 1, embedding_dim], with the dtype of
    `self.theta.wm`. Rows with no ids are zeros.
  """
  p = self.params
  embs = tf.nn.embedding_lookup_sparse(
      self.theta.wm,
      sparse_ids,
      None,  # sp_weights
      combiner=p.combiner,
      partition_strategy=partition_strategy)
  batch_size = sparse_ids.dense_shape[0]
  # For tf.nn.embedding_lookup_sparse, output.dim0 might be different from
  # sparse_ids.dense_shape.dim0 (trailing rows without ids are dropped).
  # Explicitly pad results to maintain dim0=batch.
  dim0_padlen = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
  embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])
  # [batch, 1, embedding_dim]
  embs = py_utils.HasShape(embs, [batch_size], ndims=1)
  return tf.expand_dims(embs, 1)
def Pad(key, t):
  """Pads `t` along dim 0 up to `self.params.batch_size` rows.

  The new rows are filled with 0, except for floating-point tensors whose key
  ends with '.paddings'. Those are filled with 1.0 so the added rows are
  marked as padding.

  Args:
    key: Name of the tensor; only its '.paddings' suffix is inspected.
    t: Tensor whose leading dimension is the (partial) batch. Any static rank
      >= 1 is supported; only dim 0 grows.

  Returns:
    `t` padded to `self.params.batch_size` rows.
  """
  constant_v = 0
  if t.dtype.is_floating and key.endswith('.paddings'):
    constant_v = 1.0
  need = self.params.batch_size - py_utils.GetShape(t)[0]
  # Pad only the batch dimension. When the static rank is unknown, use the
  # rank-2 spec (the layout this helper originally supported).
  rank = t.shape.rank
  if rank is None:
    rank = 2
  pad_spec = [[0, need]] + [[0, 0]] * (rank - 1)
  padded = tf.pad(t, pad_spec, 'CONSTANT', constant_v)
  return padded
def _RescaleBoundary(self, out, in_paddings):
  """Rescales conv outputs to undo dilution from padded input positions.

  Each output position is multiplied by
    (# input positions in its window) / (# non-padding input positions),
  where padding covers both the explicit `in_paddings` and the implicit
  zero padding at the sequence boundaries.

  Args:
    out: Conv output; its first two dims are [batch, output_time].
    in_paddings: [batch, time] input paddings.

  Returns:
    `out` times the per-position rescaling factor (0 where a window holds no
    real input).
  """
  p = self.params
  kernel_t = p.filter_shape[0]
  dilation_t = p.dilation_rate[0]
  span = (kernel_t - 1) * dilation_t + 1
  # Implicit boundary padding, split as (span - 1) // 2 on the left and
  # span // 2 on the right.
  # NOTE(review): this matches TF 'SAME' padding for stride 1; for larger
  # filter strides TF may split the padding differently — confirm.
  valid = tf.pad(1.0 - in_paddings,
                 [[0, 0], [(span - 1) // 2, span // 2]])
  # Fraction of real inputs under each (dilated, strided) window.
  valid_fraction = tf.nn.pool(
      valid[:, :, tf.newaxis],
      window_shape=(kernel_t,),
      pooling_type='AVG',
      strides=(p.filter_stride[0],),
      padding='VALID',
      dilations=(dilation_t,))
  scale = tf.math.reciprocal_no_nan(valid_fraction)
  return out * scale[..., tf.newaxis]
def testCausalConv2DLayerStridedWithPaddingFPropV2(self, seq_len):
  """Check strided convs get the same values for different length dim."""
  batch_size = 5
  expected_seq_len = 3
  with self.session(use_gpu=True):
    conv_layer = conv_layers.CausalConv2DLayerWithPadding.Params().Set(
        name='conv',
        v2_padding=True,
        weight_norm=False,
        filter_stride=[2, 2],
        filter_shape=[3, 1, 1, 1],
        params_init=py_utils.WeightInit.Constant(1.0)).Instantiate()

    # Over the first 5 frames, row r has its last r frames padded; every frame
    # past 5 is padded for all rows.
    base_padding = tf.constant(
        [[0.] * (5 - r) + [1.] * r for r in range(batch_size)], tf.float32)
    in_padding = tf.pad(
        base_padding, [[0, 0], [0, seq_len - 5]], constant_values=1.0)

    # Frame t holds the value t + 1 in every batch row and frequency bin.
    ramp = tf.reshape(tf.range(seq_len, dtype=tf.float32), [1, seq_len, 1, 1])
    inputs = 1.0 + tf.tile(ramp, [batch_size, 1, 3, 1])
    inputs = py_utils.ApplyPadding(
        tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)
    inputs = py_utils.Debug(inputs)

    output, out_padding = conv_layer.FPropDefaultTheta(inputs, in_padding)
    output = py_utils.Debug(output)
    out_padding = py_utils.Debug(out_padding)

    self.evaluate(tf.global_variables_initializer())
    output, out_padding = self.evaluate([output, out_padding])

    self.assertEqual((batch_size, expected_seq_len, 2, 1), output.shape)
    self.assertAllClose([
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 1],
        [0, 1, 1],
        [0, 1, 1],
    ], out_padding)
    self.assertAllClose(
        [
            [[[1], [1]], [[6], [6]], [[12], [12]]],
            [[[1], [1]], [[6], [6]], [[7], [7]]],
            [[[1], [1]], [[6], [6]], [[3], [3]]],  # NOTE: not padded.
            [[[1], [1]], [[3], [3]], [[0], [0]]],
            [[[1], [1]], [[1], [1]], [[0], [0]]],
        ],
        output)
def CpuEmbLookup(self, ids_map, partition_strategy):
  """CPU evaluation embedding lookup.

  Args:
    ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
      -1 is used as a padding id.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    An activations dict of string -> float32 Tensor.
    For non-sequence embeddings: [batch, 1, embedding_dim]
    For sequence embeddings: [batch, max_sequence_length, embedding_dim]
  """
  p = self.params
  rets = py_utils.NestedMap()

  if self.max_sequence_length > 0:
    # "Sequence embedding": one vector per id, no combiner.
    for key, ids in ids_map.items():
      flat_embs = tf.nn.embedding_lookup(
          self.theta.wm,
          tf.reshape(ids, [-1]),
          partition_strategy=partition_strategy)
      target_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
      rets[key] = tf.reshape(flat_embs, target_shape)
    return rets

  # Non-"sequence embedding": combine all ids of a row into one vector.
  for key, ids in ids_map.items():
    # Dense -> sparse, dropping the -1 padding ids.
    dense_shape = tf.shape(ids, out_type=tf.int64)
    kept_positions = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
    kept_ids = tf.cast(tf.gather_nd(ids, kept_positions), tf.int64)
    sparse_ids = tf.SparseTensor(
        indices=kept_positions, values=kept_ids, dense_shape=dense_shape)
    # [?, embedding_dim]. Here '?' is the smallest prefix of rows starting at
    # index 0 that covers every non-empty row, so it can be < batch.
    combined = tf.nn.embedding_lookup_sparse(
        self.theta.wm,
        sparse_ids,
        None,  # sp_weights
        combiner=p.combiner,
        partition_strategy=partition_strategy)
    batch_size = dense_shape[0]
    # Pad trailing empty rows back in so dim0 == batch.
    missing_rows = tf.cast(batch_size, tf.int32) - tf.shape(combined)[0]
    combined = tf.pad(combined, [[0, missing_rows], [0, 0]])
    combined = py_utils.HasShape(combined, [batch_size], ndims=1)
    # [batch, 1, embedding_dim]
    rets[key] = tf.expand_dims(combined, 1)
  return rets
def _MakeTransformTestRotationMatrices(self, batch_size):
  """Returns a [batch_size, 4, 4] stack of random z-axis rotation transforms.

  Only the first angle passed to `geometry._MakeRotationMatrix` is random; the
  other two are 0. Each 3x3 rotation is embedded in the top-left of a 4x4
  homogeneous matrix.
  """

  def _OneRotation():
    # Op creation order matches the original loop body (random angle first),
    # which keeps graph-level seeded randomness unchanged.
    rotation_3x3 = geometry._MakeRotationMatrix(tf.random_uniform([]), 0., 0.)
    return (tf.pad(rotation_3x3, [[0, 1], [0, 1]]) +
            tf.diag([0, 0, 0, 1.]))

  return tf.stack([_OneRotation() for _ in range(batch_size)], axis=0)
def _MakeTransformTestTranslationMatrices(self, batch_size):
  """Returns a [batch_size, 4, 4] stack of random translation transforms.

  Each matrix is the 4x4 identity with a uniform-random [3, 1] translation in
  rows 0-2 of the last column.
  """

  def _OneTranslation():
    offset = tf.random_uniform([3, 1])
    # [3, 1] -> [4, 4]: the offset lands in column 3, rows 0..2.
    homogeneous = tf.pad(offset, [[0, 1], [3, 0]])
    homogeneous += tf.diag([1., 1., 1., 1.])
    return homogeneous

  return tf.stack([_OneTranslation() for _ in range(batch_size)], axis=0)
def _EvaluateConvKernel(self, theta, inputs):
  """Apply convolution to inputs.

  Same as CausalDepthwiseConv2DLayer: the time axis is left-padded by the
  dilated kernel's history and then convolved with 'VALID' padding. Each
  output therefore depends only on current and earlier frames.
  """
  p = self.params
  assert p.filter_shape[1] == 1, 'Only 1D causal convolutions supported.'
  kernel_t = p.filter_shape[0]
  history = (kernel_t - 1) * p.dilation_rate[0]
  padded = tf.pad(inputs, [[0, 0], [history, 0], [0, 0], [0, 0]])
  conv_strides = [1, p.filter_stride[0], p.filter_stride[1], 1]
  return tf.nn.depthwise_conv2d(
      padded,
      self._GetWeight(theta),
      strides=conv_strides,
      rate=p.dilation_rate,
      data_format='NHWC',
      padding='VALID')
def _CreateTargetLambdas(self,
                         atten_probs,
                         source_lambdas_pair,
                         source_paddings_pair,
                         target_paddings_pair,
                         smooth=0):
  """Compute target interpolation ratios.

  Args:
    atten_probs: A list containing two attention matrices.
    source_lambdas_pair: A list containing two source interpolation ratios.
    source_paddings_pair: A list containing two source paddings.
    target_paddings_pair: A list containing two target paddings.
    smooth: A real value added to the unnormalized target interpolation
      ratios before normalization.

  Returns:
    source_lambdas_pair: Source interpolation ratios (returned unchanged).
    input_lambdas: Interpolation ratios for target input embeddings.
    label_lambdas: Interpolation ratios for target labels.
  """

  def _TargetMass(probs, src_lambdas, src_paddings, tgt_paddings):
    # Attention-weighted sum of the (unpadded) source lambdas per target step.
    weighted = probs * tf.expand_dims(src_lambdas * (1.0 - src_paddings), 1)
    return (tf.reduce_sum(weighted, -1) + smooth) * (1.0 - tgt_paddings)

  # Attention is used as a fixed weighting; no gradient flows into it.
  probs_0 = tf.stop_gradient(atten_probs[0])
  probs_1 = tf.stop_gradient(atten_probs[1])
  mass_0 = _TargetMass(probs_0, source_lambdas_pair[0],
                       source_paddings_pair[0], target_paddings_pair[0])
  mass_1 = _TargetMass(probs_1, source_lambdas_pair[1],
                       source_paddings_pair[1], target_paddings_pair[1])

  share_0 = mass_0 / (mass_0 + mass_1 + 1e-9)
  label_lambdas = [share_0, (1.0 - share_0)]

  # Input ratios are the label ratios shifted right by one position, with
  # 1.0 in the first position.
  shifted_0 = tf.pad(share_0, [[0, 0], [1, 0]], constant_values=1.)[:, :-1]
  input_lambdas = [
      shifted_0 * (1. - target_paddings_pair[0]),
      (1.0 - shifted_0) * (1. - target_paddings_pair[1]),
  ]
  return source_lambdas_pair, input_lambdas, label_lambdas
def RelShift(x):
  """Performs relative shift on 4D tensor (first 2 axis are batching dims).

  Given input of shape [?, ?, W, W], this does "relative shifting" for the
  last two dims:

    output[b, n, i, j] = input[b, n, i, j - i]              if j >= i
    output[b, n, i, j] = 0                                  if j == i - 1
    output[b, n, i, j] = input[b, n, i - 1, W + 1 + j - i]  if j < i - 1

  Only the upper triangle (j >= i) is meaningful. The rest is NOT all zeros
  (e.g. for W=3, output[.., 2, 0] == input[.., 1, 2]), so callers must mask
  it.

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as
    described above).
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, w, w])
  # Append one zero column ([W, W+1]) and reinterpret the same buffer as
  # [W+1, W]. Row i then starts i elements later, which shifts it by i.
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  x = tf.reshape(x, [b, n, w + 1, w])
  x = x[:, :, :w, :]
  return x
def _EvaluateConvKernel(self, theta, inputs):
  """Apply convolution to inputs.

  Causal 1D depthwise convolution along time. The input is shifted right by
  the dilated kernel's history and convolved with 'VALID' padding, so the
  first output depends only on the first input, and so on. The output
  length therefore matches what 'SAME' padding would give (before striding).
  """
  p = self.params
  assert p.filter_shape[1] == 1, 'Only 1D causal convolutions supported.'
  # A dilated kernel of width k spans (k - 1) * rate + 1 frames (see
  # tf.nn.convolution), so (k - 1) * rate frames of history are needed.
  dilation_t = p.dilation_rate[0]
  history = (p.filter_shape[0] - 1) * dilation_t
  time_shifted = tf.pad(inputs, [[0, 0], [history, 0], [0, 0], [0, 0]])
  depthwise_filter = self._GetWeight(theta)
  return tf.nn.depthwise_conv2d(
      time_shifted,
      depthwise_filter,
      strides=[1, p.filter_stride[0], p.filter_stride[1], 1],
      rate=p.dilation_rate,
      data_format='NHWC',
      padding='VALID')
def RelPositionBias(content, abs_pos_emb):
  """Compute relative position bias.

  This is a subroutine used by variants of self-attentions with relative
  positional embedding.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  output[b][n][i][j] = content[b][i][n] x abs_pos_emb[i-j+T-1][n]

  Padding must be masked by the caller of this function.

  Args:
    tensors of the following shapes:
    content:         [B, T, N, H]
    abs_pos_emb:     [2T - 1, N, H], the absolute positional embedding.
      abs_pos_emb[i] is the emb of relative distance i - (T-1).

  Returns:
    The attention logits tensor. [B, N, T, T]
  """
  b, t, n, h = py_utils.GetShape(content)
  rel_len = 2 * t - 1
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [rel_len, n, h])

  # Score every query against every relative offset: [B, N, T, L=2T-1].
  logits = tf.einsum('BTNH,LNH->BNTL', content, abs_pos_emb)

  # Skew trick: flatten to [B, N, T*L], append T zeros, and view the result as
  # [B, N, T, L+1]. Row i is then offset by i, so reading columns
  # T-1, T-2, ..., 0 picks relative index i - j + T - 1.
  flat = tf.reshape(logits, [b, n, t * rel_len], name='flatten')
  flat = tf.pad(flat, [[0, 0], [0, 0], [0, t]])
  skewed = tf.reshape(flat, [b, n, t, rel_len + 1], name='restore')
  return skewed[..., t - 1::-1]
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random subset
  of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where 0
      is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
        equal. Positions past each example's canvas length are 0.
      - canvas_indices: The canvas indices (into `x`). Positions past each
        example's canvas length hold `time_dim - 1`; use `canvas_paddings`.
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by add +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  # NOTE(review): tf.random.uniform can return exactly 0.0, which makes z
  # -inf at that position and would defeat the +1e9 forcing above. This is
  # rare, but consider a small positive minval.
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas. The fill value time_dim - 1 is >= every valid
  # index, so after sorting the first c_len entries of each row are exactly
  # the valid ones.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(
                  tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
              [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  # Cast to the canvas' own dtype so non-int32 `x` is supported.
  c *= tf.cast(1 - c_paddings, c.dtype)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  # Scatter only the valid canvas entries. The invalid ones all point at
  # time_dim - 1, and scattering them would wrongly mark that position as
  # observed (scatter_nd also sums the duplicates).
  c_is_valid = tf.reshape(
      tf.cast(1 - c_paddings, tf.int32), [batch_size * c_len_max])
  x_token_is_observed = tf.scatter_nd(indices, c_is_valid,
                                      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(x_token_is_observed & ~prev_x_token_is_observed,
                            tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def FProp(self, theta, batch, state0=None):
  """Encodes source as represented by 'inputs' and 'paddings'.

  The pipeline is: optional SpecAugment -> optional trailing pad frames ->
  conv layers -> conv-LSTM (rnn + cnn) pairs -> flatten to [batch, time, -1]
  -> stacked LSTM layers (with optional residual/highway skips, projections
  and a stacking layer).

  Args:
    theta: A NestedMap object containing weights' values of this
      layer and its children layers.
    batch: A NestedMap with fields:
      - src_inputs - The inputs tensor. It is expected to be of shape [batch,
        time, feature_dim, channels].
      - paddings - The paddings tensor. It is expected to be of shape [batch,
        time].
    state0: Recurrent input state. Not supported/ignored by this encoder.

  Returns:
    A NestedMap containing

    - 'encoded': a feature tensor of shape [time, batch, depth]
    - 'padding': a 0/1 tensor of shape [time, batch]
    - 'state': the updated recurrent state (always an empty NestedMap here)
    - '${layer_type}_${layer_index}': The per-layer encoder output. Each one is
      a NestedMap containing 'encoded' and 'padding' similar to regular final
      outputs, except that 'encoded' from conv or conv_lstm layers are of shape
      [time, batch, depth, channels]. Only present when
      p.extra_per_layer_outputs is set.
  """
  p = self.params
  inputs, paddings = batch.src_inputs, batch.paddings
  outputs = py_utils.NestedMap()
  with tf.name_scope(p.name):
    # Adding specAugmentation.
    if p.use_specaugment and not self.do_eval:
      inputs, paddings = self.specaugment.FProp(theta.specaugment, inputs,
                                                paddings)
    # Add a few extra padded timesteps at the end. This is for ensuring the
    # correctness of the conv-layers at the edges.
    if p.pad_steps > 0:
      # inplace_update() is not supported by TPU for now. Since we have done
      # padding on the input_generator, we may avoid this additional padding.
      assert not py_utils.use_tpu()
      # Same shape as the input except dim 1 (time) == p.pad_steps.
      inputs_pad = tf.zeros(
          inplace_ops.inplace_update(tf.shape(inputs), 1, p.pad_steps),
          inputs.dtype)
      paddings_pad = tf.ones(
          inplace_ops.inplace_update(tf.shape(paddings), 1, p.pad_steps),
          paddings.dtype)
      inputs = tf.concat([inputs, inputs_pad], 1, name='inputs')
      paddings = tf.concat([paddings, paddings_pad], 1)

    # Summary plots, collected in pipeline order and reversed when plotted.
    plots = [
        summary_utils.PrepareSequenceForPlot(
            tf.transpose(inputs, [0, 1, 3, 2]), paddings, 'inputs')
    ]

    conv_out = inputs
    out_padding = paddings
    for i, conv_layer in enumerate(self.conv):
      conv_out, out_padding = conv_layer.FProp(theta.conv[i], conv_out,
                                               out_padding)
      if p.extra_per_layer_outputs:
        # Zero padded frames before exposing the per-layer output.
        conv_out *= (1.0 - out_padding[:, :, tf.newaxis, tf.newaxis])
        outputs['conv_%d' % i] = py_utils.NestedMap(
            encoded=tf.transpose(conv_out, [1, 0, 2, 3]),  # to [t, b, d, c]
            padding=tf.transpose(out_padding))
      plots.append(
          summary_utils.PrepareSequenceForPlot(
              tf.transpose(conv_out, [0, 1, 3, 2]), out_padding,
              'conv_%d_out' % i))

    # Swaps dims 0 and 1 of a tensor of any rank >= 2 by collapsing the
    # trailing dims, transposing, and restoring them.
    def TransposeFirstTwoDims(t):
      first_dim = tf.shape(t)[0]
      second_dim = tf.shape(t)[1]
      t_new = tf.transpose(
          tf.reshape(t, [first_dim, second_dim, -1]), [1, 0, 2])
      t_shape_new = tf.concat([[second_dim], [first_dim], tf.shape(t)[2:]], 0)
      return tf.reshape(t_new, t_shape_new)

    # Now the conv-lstm part.
    conv_lstm_out = conv_out
    conv_lstm_out_padding = out_padding
    for i, (rnn, cnn) in enumerate(
        zip(self.conv_lstm_rnn, self.conv_lstm_cnn)):
      conv_lstm_in = conv_lstm_out
      # Move time dimension to be the first.
      conv_lstm_in = TransposeFirstTwoDims(conv_lstm_in)
      conv_lstm_in = tf.expand_dims(conv_lstm_in, 2)
      conv_lstm_in_padding = tf.expand_dims(
          tf.transpose(conv_lstm_out_padding), 2)
      lstm_out = rnn.FProp(theta.conv_lstm_rnn[i], conv_lstm_in,
                           conv_lstm_in_padding)
      # Move time dimension to be the second.
      cnn_in = TransposeFirstTwoDims(lstm_out)
      cnn_in = tf.squeeze(cnn_in, 2)
      cnn_in_padding = conv_lstm_out_padding
      cnn_out, cnn_out_padding = cnn.FProp(theta.conv_lstm_cnn[i], cnn_in,
                                           cnn_in_padding)
      conv_lstm_out, conv_lstm_out_padding = cnn_out, cnn_out_padding
      if p.extra_per_layer_outputs:
        conv_lstm_out *= (
            1.0 - conv_lstm_out_padding[:, :, tf.newaxis, tf.newaxis])
        outputs['conv_lstm_%d' % i] = py_utils.NestedMap(
            encoded=tf.transpose(conv_lstm_out, [1, 0, 2, 3]),  # to [t, b, d, c]
            padding=tf.transpose(conv_lstm_out_padding))
      plots.append(
          summary_utils.PrepareSequenceForPlot(conv_lstm_out,
                                               conv_lstm_out_padding,
                                               'conv_lstm_%d_out' % i))

    # Need to do a reshape before starting the rnn layers.
    # Collapse the trailing two dims: [batch, time, d, c] -> [batch, time, d*c].
    conv_lstm_out = py_utils.HasRank(conv_lstm_out, 4)
    conv_lstm_out_shape = tf.shape(conv_lstm_out)
    new_shape = tf.concat([conv_lstm_out_shape[:2], [-1]], 0)
    conv_lstm_out = tf.reshape(conv_lstm_out, new_shape)
    if self._first_lstm_input_dim_pad:
      # Zero-pad the feature dim up to the first LSTM's expected input dim.
      conv_lstm_out = tf.pad(
          conv_lstm_out,
          [[0, 0], [0, 0], [0, self._first_lstm_input_dim_pad]])

    conv_lstm_out = py_utils.HasShape(conv_lstm_out,
                                      [-1, -1, self._first_lstm_input_dim])

    # Transpose to move the time dimension to be the first.
    rnn_in = tf.transpose(conv_lstm_out, [1, 0, 2])
    rnn_padding = tf.expand_dims(tf.transpose(conv_lstm_out_padding), 2)
    # rnn_in is of shape [time, batch, depth]
    # rnn_padding is of shape [time, batch, 1]

    # Now the rnn layers.
    num_skips = 0
    for i in range(p.num_lstm_layers):
      rnn_out = self.rnn[i].FProp(theta.rnn[i], rnn_in, rnn_padding)
      # Skips are enabled from layer p.residual_start - 1 onward. Within each
      # group of p.residual_stride layers, the group's input is captured at
      # the first layer and combined with the output at the last layer.
      residual_index = i - p.residual_start + 1
      if p.residual_start > 0 and residual_index >= 0:
        if residual_index % p.residual_stride == 0:
          residual_in = rnn_in
        if residual_index % p.residual_stride == p.residual_stride - 1:
          # Highway skip connection.
          if p.highway_skip:
            rnn_out = self.highway_skip[num_skips].FProp(
                theta.highway_skip[num_skips], residual_in, rnn_out)
            num_skips += 1
          else:
            # Residual skip connection.
            rnn_out += py_utils.HasShape(residual_in, tf.shape(rnn_out))
      if p.project_lstm_output and (i < p.num_lstm_layers - 1):
        # Projection layers.
        rnn_out = self.proj[i].FProp(theta.proj[i], rnn_out, rnn_padding)
      if i == p.num_lstm_layers - 1:
        # Zero padded frames in the final encoder output.
        rnn_out *= (1.0 - rnn_padding)
      if p.extra_per_layer_outputs:
        rnn_out *= (1.0 - rnn_padding)
        outputs['rnn_%d' % i] = py_utils.NestedMap(
            encoded=rnn_out, padding=tf.squeeze(rnn_padding, [2]))
      # Stacking layer connection.
      if p.layer_index_before_stacking == i:
        # Stacking layer expects input tensor shape as [batch, time, feature].
        # So transpose the tensors before and after the layer.
        rnn_out, rnn_padding = self.stacking.FProp(
            tf.transpose(rnn_out, [1, 0, 2]),
            tf.transpose(rnn_padding, [1, 0, 2]))
        rnn_out = tf.transpose(rnn_out, [1, 0, 2])
        rnn_padding = tf.transpose(rnn_padding, [1, 0, 2])

      plots.append(
          summary_utils.PrepareSequenceForPlot(
              tf.transpose(rnn_out, [1, 0, 2]),
              tf.transpose(rnn_padding, [1, 0, 2]),
              'rnn_%d_out' % i))
      rnn_in = rnn_out
    final_out = rnn_in

    summary_utils.PlotSequenceFeatures(
        list(reversed(plots)), 'encoder_example', xlabel='Time')

    outputs['encoded'] = final_out
    outputs['padding'] = tf.squeeze(rnn_padding, [2])
    outputs['state'] = py_utils.NestedMap()
    return outputs
def test_max_assign_batch_version(self):
  """Batched max_assignment on a padded batch matches per-example results.

  Example 1 (2x2) is padded to 3x3 and batched with example 2 (3x3); the
  batched assignments must agree with the two unbatched runs.
  """

  def _WithBatchDim(*tensors):
    # Prepends a leading batch dimension of size 1 to every tensor.
    return [x[tf.newaxis] for x in tensors]

  def _Assign(score, upper_bound, row_sums, col_sums):
    # All three solver runs share the same epsilon / iteration budget.
    return differentiable_assignment.max_assignment(
        score,
        elementwise_upper_bound=upper_bound,
        row_sums=row_sums,
        col_sums=col_sums,
        epsilon=0.01,
        num_iterations=200)

  # Example 1: a 2x2 problem.
  raw_score1 = tf.convert_to_tensor([[0.5, 1.0], [0.2, 0.6]])
  score1, row_sums1, col_sums1, upper_bound1 = _WithBatchDim(
      raw_score1,
      tf.convert_to_tensor([1.0, 1.0]),
      tf.convert_to_tensor([1.0, 1.0]),
      tf.ones_like(raw_score1))

  # Example 2: a 3x3 problem.
  raw_score2 = tf.convert_to_tensor([[1.0, 0, 0], [0, 1.0, 0], [0, 0, 1.0]])
  score2, row_sums2, col_sums2, upper_bound2 = _WithBatchDim(
      raw_score2,
      tf.convert_to_tensor([1.0, 1.0, 1.0]),
      tf.convert_to_tensor([1.0, 1.0, 1.0]),
      tf.ones_like(raw_score2))

  # To batch the two examples, example 1 is padded up to 3x3: padded scores
  # get a very large negative value, padded sums and upper bounds are zero.
  padded1 = (
      tf.pad(score1, [[0, 0], [0, 1], [0, 1]], constant_values=-1e+20),
      tf.pad(upper_bound1, [[0, 0], [0, 1], [0, 1]]),
      tf.pad(row_sums1, [[0, 0], [0, 1]]),
      tf.pad(col_sums1, [[0, 0], [0, 1]]),
  )
  unpadded2 = (score2, upper_bound2, row_sums2, col_sums2)
  batched = [tf.concat([a, b], axis=0) for a, b in zip(padded1, unpadded2)]

  results1 = _Assign(score1, upper_bound1, row_sums1, col_sums1)
  results2 = _Assign(score2, upper_bound2, row_sums2, col_sums2)
  results3 = _Assign(*batched)
  assignment1, assignment2, assignment3 = (
      results1[0], results2[0], results3[0])

  print("")
  print("Test case - batched:")
  print("Used iter:", results1[1], results2[1], results3[1])
  print("Delta:", results1[-1], results2[-1], results3[-1])
  print("Assignments:")
  print(assignment1[0])
  print(assignment2[0])
  print(assignment3)

  self.assertNDArrayNear(assignment1[0], assignment3[0, :2, :2], err=1e-4)
  self.assertNDArrayNear(assignment2[0], assignment3[1], err=1e-4)
def _AddNoise(self, batch):
  """Adds denoising noise to the source side of `batch` (arXiv:1711.00043).

  Three kinds of noise are built from the hyperparameters in
  `self.params.denoise` (called `p` below):
    1) a bounded local shuffle of the tokens, controlled by
       `p.shuffle_tok_range`;
    2) dropping tokens with probability `p.drop_tok_prob`;
    3) replacing tokens by `p.blank_id` with probability `p.blank_tok_prob`.
  For each example, the noisy source replaces the original with probability
  `p.noise_sent_prob`, and only if the example's task id is in `p.task_ids`.

  `batch.src.ids`, `batch.src.paddings`, `batch.src.ids_indicator` and
  `batch.src.weights` are overwritten in place; nothing is returned.

  Args:
    batch: a `.NestedMap` of the input batch.
  """

  def IsSpecialExample(task_ids, special_task_ids):
    """Returns whether each example's task id is one of `special_task_ids`.

    Args:
      task_ids: Task ids for the input batch. Tensor of shape [batch].
      special_task_ids: A list of specified task ids.

    Returns:
      A boolean tensor of shape [batch], True where the example belongs to
      one of the specified tasks.
    """
    batch_size = py_utils.GetShape(task_ids)[0]
    return tf.reduce_any(
        tf.equal(
            tf.expand_dims(task_ids, -1),
            tf.cast(
                tf.broadcast_to(special_task_ids,
                                [batch_size, len(special_task_ids)]),
                tf.int32)), -1)

  p = self.params.denoise
  batch_size = tf.shape(batch.src.ids)[0]
  source_max_len = tf.shape(batch.src.ids)[1]

  # 1) Shuffle: add U[0, shuffle_tok_range + 1) noise to every position index
  # and sort, so each token moves by a bounded amount.
  noise = tf.random.uniform([batch_size, source_max_len], 0,
                            p.shuffle_tok_range + 1)
  # Positions followed by padding (the last real token, presumably EOS) and
  # padding positions get a fixed noise of shuffle_tok_range, which keeps
  # them in place after sorting.
  shuffle_tok_range = tf.fill([batch_size, source_max_len],
                              float(p.shuffle_tok_range))
  shifted_paddings = tf.pad(
      batch.src.paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  noise = tf.where(tf.equal(shifted_paddings, 0), noise, shuffle_tok_range)
  indices = tf.broadcast_to(
      tf.range(source_max_len, dtype=tf.int32), [batch_size, source_max_len])
  noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
  permutations = tf.argsort(noisy_indices)
  # Row-wise gather: denoise_src_ids[b, i] = ids[b, permutations[b, i]].
  denoise_src_ids = tf.gather(batch.src.ids, permutations, batch_dims=1)

  # 2) Drop: keep a token when its uniform draw exceeds drop_tok_prob; the
  # EOS token is always kept. Survivors are left-aligned and refilled to the
  # original shape (ids with 0, paddings with 1).
  random_drop_tok = tf.random.uniform([batch_size, source_max_len])
  is_keep_tok = tf.math.logical_or(
      tf.greater(random_drop_tok, p.drop_tok_prob),
      tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
  denoise_src_ids = tf.ragged.boolean_mask(
      denoise_src_ids, is_keep_tok).to_tensor(
          default_value=0, shape=tf.shape(batch.src.ids))
  denoise_src_paddings = tf.ragged.boolean_mask(
      batch.src.paddings, is_keep_tok).to_tensor(
          default_value=1, shape=tf.shape(batch.src.ids))

  # 3) Blank: replace a token by blank_id with probability blank_tok_prob,
  # only where the next position is not padding, so the last real token
  # (presumably EOS) is never blanked.
  random_blank_tok = tf.random.uniform([batch_size, source_max_len])
  shifted_paddings = tf.pad(
      denoise_src_paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  is_blank_tok = tf.math.logical_and(
      tf.less(random_blank_tok, p.blank_tok_prob),
      tf.equal(shifted_paddings, 0))
  blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
  denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

  # Select denoising task examples with probability=p.noise_sent_prob
  random_uniform_sent = tf.random.uniform([batch_size])
  is_denoise_sent = tf.math.logical_and(
      tf.less(random_uniform_sent, p.noise_sent_prob),
      IsSpecialExample(
          self._GetTaskIds(batch.src.source_ids[:, 0]), p.task_ids))
  # NOTE(review): the [batch]-shaped condition below relies on tf.where
  # selecting whole rows (TF1 `where` semantics); confirm for the `tf` alias
  # used by this file.
  batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids, batch.src.ids)
  batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                batch.src.paddings)
  batch.src.ids_indicator = 1 - batch.src.paddings
  batch.src.weights = batch.src.ids_indicator
def PadToTargetSeqLen(tensor, constant):
  """Right-pads rank-2 `tensor` on axis 1 up to p.beam_search.target_seq_len.

  `p` comes from the enclosing scope. Inputs that are already at least the
  target length receive no padding.

  Args:
    tensor: rank-2 tensor to pad.
    constant: value used for the padded positions.

  Returns:
    The padded tensor.
  """
  shortfall = p.beam_search.target_seq_len - tf.shape(tensor)[1]
  right_pad_axis1 = [[0, 0], [0, tf.maximum(0, shortfall)]]
  return tf.pad(tensor, right_pad_axis1, constant_values=constant)
def testPackSingleSequence(self, input_lengths, max_packed_length,
                           require_sequential_order, expected_packed_idxs):
  """Checks pack_single_sequence and its compatibility with apply_packing.

  Args:
    input_lengths: lengths of the input sequences to pack.
    max_packed_length: capacity of each packed row.
    require_sequential_order: forwarded to ops.pack_single_sequence.
    expected_packed_idxs: for each packed row, the input indices expected in
      that row, in order.
  """
  with self.session() as sess:
    np.random.seed(12345)
    segment_ids, indices_in_input = sess.run(
        ops.pack_single_sequence(
            input_lengths=input_lengths,
            max_packed_length=max_packed_length,
            require_sequential_order=require_sequential_order))
    self.assertLen(expected_packed_idxs, segment_ids.shape[0])

    # Test the output is compatible with apply_packing.
    inputs = [
        np.random.randint(100000, size=[length, 2, 2], dtype=np.int32)
        for length in input_lengths
    ]
    # Right-pad every input to max_packed_length so they can be stacked.
    padded_inputs = tf.stack([
        tf.pad(x, [[0, max_packed_length - x.shape[0]], [0, 0], [0, 0]])
        for x in inputs
    ])
    outputs = sess.run(
        ops.apply_packing(
            input=padded_inputs,
            padding=0,
            segment_ids=segment_ids,
            indices_in_input=indices_in_input))

    for segment_id, idxs, output, expected_idxs in zip(
        segment_ids, indices_in_input, outputs, expected_packed_idxs):
      # Build the expected results from the provided expected_packed_idxs.
      # Segment ids are 1-based in packing order.
      expected_segment_ids = []
      expected_idxs_vec = []
      expected_outputs = []
      for i, idx in enumerate(expected_idxs):
        expected_segment_ids += [i + 1] * input_lengths[idx]
        expected_idxs_vec += [idx] * input_lengths[idx]
        expected_outputs.append(inputs[idx])
      expected_outputs = np.concatenate(expected_outputs)
      expected_packed_length = len(expected_outputs)
      self.assertLessEqual(expected_packed_length, max_packed_length)
      self.assertLen(expected_segment_ids, expected_packed_length)
      self.assertLen(expected_idxs_vec, expected_packed_length)

      # Check indices_in_input is non-decreasing.
      if expected_packed_length > 1:
        self.assertAllGreaterEqual(
            idxs[1:expected_packed_length] - idxs[:expected_packed_length - 1],
            0)

      # Pad to max_packed_length: segment id 0 and index -1 mark padding.
      pad_len = max_packed_length - expected_packed_length
      expected_segment_ids += [0] * pad_len
      expected_idxs_vec += [-1] * pad_len
      expected_outputs = np.pad(
          expected_outputs, [(0, pad_len), (0, 0), (0, 0)], mode='constant')

      self.assertAllEqual(expected_idxs_vec, idxs)
      self.assertAllEqual(expected_segment_ids, segment_id)
      self.assertAllEqual(expected_outputs, output)
def PadToTargetSeqLen(tensor, constant):
  """Right-pads rank-2 `tensor` on axis 1 up to p.target_seq_len.

  `p` comes from the enclosing scope. Inputs that are already at least the
  target length receive no padding.

  Args:
    tensor: rank-2 tensor to pad.
    constant: value used for the padded positions.

  Returns:
    The padded tensor.
  """
  length = tf.shape(tensor)[1]
  # Clamp at zero: tf.pad rejects negative pad amounts, so an input longer
  # than the target would otherwise fail at run time.
  pad = tf.maximum(0, p.target_seq_len - length)
  return tf.pad(tensor, [[0, 0], [0, pad]], constant_values=constant)
def testConv2DLayerStridedWithPaddingFProp(self, seq_len):
  """Check strided convs get the same values for different length dim."""
  # TODO(isaace): THIS TEST SHOWS THAT THERE IS A BUG IN THE CODE.
  # Expected conv outputs for each supported input length. They differ
  # between lengths even though they should be the same (see the note near
  # the final assertion).
  expected_outputs = {
      5: [
          [[[6], [6]], [[18], [18]], [[18], [18]]],
          [[[6], [6]], [[18], [18]], [[8], [8]]],
          [[[6], [6]], [[10], [10]], [[0], [0]]],
      ],
      6: [
          [[[12], [12]], [[24], [24]], [[10], [10]]],
          [[[12], [12]], [[14], [14]], [[0], [0]]],
          [[[12], [12]], [[6], [6]], [[0], [0]]],
      ],
  }
  if seq_len not in expected_outputs:
    # Fail fast with a readable message before building the graph.
    raise ValueError(f'Test does not handle length {seq_len}')

  with self.session(use_gpu=True):
    batch_size = 3
    expected_seq_len = 3

    params = conv_layers.Conv2DLayerWithPadding.Params()
    params.weight_norm = False
    params.filter_stride = [2, 2]
    params.name = 'conv'
    params.filter_shape = [3, 3, 1, 1]
    params.params_init = py_utils.WeightInit.Constant(1.0)
    conv_layer = params.Instantiate()

    # Set up the padding for the sequence length. (starting at 5).
    in_padding = tf.constant([
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 1],
    ], tf.float32)
    in_padding = tf.pad(
        in_padding, [[0, 0], [0, seq_len - 5]], constant_values=1.0)
    # Inputs are 1, 2, ..., seq_len along time, tiled to [batch, time, 3, 1].
    inputs = 1.0 + tf.tile(
        tf.reshape(tf.range(seq_len, dtype=tf.float32), [1, seq_len, 1, 1]),
        [batch_size, 1, 3, 1])
    inputs = py_utils.ApplyPadding(
        tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)

    inputs = py_utils.Debug(inputs)

    output, out_padding = conv_layer.FPropDefaultTheta(inputs, in_padding)

    output = py_utils.Debug(output)
    out_padding = py_utils.Debug(out_padding)

    self.evaluate(tf.global_variables_initializer())
    output, out_padding = self.evaluate([output, out_padding])

    self.assertEqual((batch_size, expected_seq_len, 2, 1), output.shape)
    self.assertAllClose([
        [0, 0, 1],
        [0, 0, 1],
        [0, 1, 1],
    ], out_padding)

    # This here shows a bug in the implementation; the output should be the
    # same. Also there are bugs with the output not having the correct
    # padding.
    self.assertAllClose(expected_outputs[seq_len], output)
def PadOne(inp):
  """Appends one trailing 1.0 entry to the last axis of a rank-3 tensor.

  `inp` is first passed through py_utils.HasShape with [-1, -1, 3]
  (presumably a shape assertion), so the result's last dimension is 4.
  """
  checked = py_utils.HasShape(inp, [-1, -1, 3])
  append_one_on_last_axis = [[0, 0], [0, 0], [0, 1]]
  return tf.pad(checked, append_one_on_last_axis, constant_values=1.0)
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
  """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback, called as
      dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t)
      and returning (logits, dec_state).
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold. The loop stops once, for every batch
      row, the lowest n-best score exceeds the best live hyp score by at
      least beam_gap. None to disable.
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]; prefix_max must
      be a multiple of beam_size.
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size; 0 disables the
      extension buffer.
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
  assert beam_size > 0
  assert batch_size > 0
  assert max_steps > 0

  buf_size = beam_size * max_steps
  output_len = max_steps

  if prefix is None:
    assert prefix_len is None
    # Default prefix: one block of beam_size slots holding only bos_id.
    prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
    prefix_len = tf.ones([batch_size], dtype=tf.int32)
  else:
    assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
    assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                     prefix_len.shape)
    output_len += int(prefix.shape[1])

  if debug:
    tpu_summary.tensor('prefix', prefix)
    tpu_summary.tensor('prefix_len', prefix_len)

  with tf.name_scope('init_state'):
    t = tf.constant(0)
    # NOTE: tgt_id, tgt_pos, tgt_mask and hyp_score are recomputed in
    # 'init_prefix' below; only t (= 0) is used before that.
    tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_id += bos_id
    tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
    # penalize all hyps except the first
    hyp_score -= tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)
    nbest_size = nbest_size or beam_size
    nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
    nbest_score -= 1e9
    nbest_score_norm = nbest_score
    nbest_mask = tf.zeros([batch_size, nbest_size, buf_size], dtype=fprop_dtype)

  with tf.name_scope('init_ext'):
    # Initialize the extension buffer.
    #
    # Extension buffer stores a (potentially large) set of 'extensions',
    # which consist of a hypothesis (represented by ext_mask) and next token
    # (represented by ext_id). At each decoder iteration, top_k extensions
    # from each hypothesis are added to the buffer and sorted by score.
    #
    # Then top beam_size extensions are removed from the buffer and used
    # in the next decoder iteration. And top 'ext_size' remaining extensions
    # are carried over to be possibly evaluated at a later step.
    #
    # As a result of this manipulation, the decoder is no longer restricted
    # to always compare hyps of the same token length at each iteration.
    # In particular, for a fixed length N it can generate more than beam_size
    # terminated hyps.
    #
    # Setting ext_size = 0 disables this feature.
    if ext_size:
      ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
      ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
      ext_score -= 1e9
      ext_mask = tf.zeros([batch_size, ext_size, buf_size], dtype=fprop_dtype)
    else:
      ext_size = ext_id = ext_score = ext_mask = 0

  with tf.name_scope('init_prefix'):
    # rename prefix->pfx for shorter variables
    pfx = tf.cast(prefix, tf.int32)
    pfx_len = tf.cast(prefix_len, tf.int32)
    del prefix, prefix_len

    # Before the first call to dec_callback() the prefix shall be packed into
    # the tgt_id buffer as follows:
    #
    # [ P P P P P P - - - - - - P* - - - ]   ^
    # [ P P P P P P P P P P - - P* - - - ]   | batch
    # [ P - - - - - - - - - - - P* - - - ]   V
    # |<---- prefix len ---->  |<-- beam -->
    #
    # The last meaningful token in the prefix (P*)
    # must be located at the same position in all batch rows.
    #
    # We then make one dec_callback() with full prefix (minus P*)
    # which will populate the initial dec_state
    # (for transformer -- self-attention key/value cache)
    #
    # The last block [batch, beam] then becomes the first tgt_id for the loop.

    pfx_max = int(pfx.shape[1])
    pfx_mul = pfx_max // beam_size
    assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)

    pfx_time = tf.range(pfx_max)
    # 1 for prefix positions strictly before the last token P*, else 0.
    pfx_pad = tf.cast(
        tf.less(tf.expand_dims(pfx_time, 0), tf.expand_dims(pfx_len - 1, 1)),
        tf.int32)

    pfx_id = pfx * pfx_pad
    pfx_last = einsum_i32(
        'BT,BT->B', pfx, tf.one_hot(pfx_len - 1, pfx_max, dtype=fprop_dtype))

    buf_time = tf.range(buf_size)
    # Causal mask: prefix position q may attend to buffer positions <= q.
    pfx_time_mask = tf.cast(
        tf.less_equal(tf.expand_dims(buf_time, 0), tf.expand_dims(pfx_time, 1)),
        fprop_dtype)
    pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                         pfx_time_mask)
    pfx_segment_id = pfx_pad
    pfx_pos = pfx_time * pfx_pad

    if debug:
      tpu_summary.tensor('pfx_id', pfx_id)
      tpu_summary.tensor('pfx_len', pfx_len)
      tpu_summary.tensor('pfx_pos', pfx_pos)
      tpu_summary.tensor('pfx_last', pfx_last)

    # Now call decoder with prefix minus P*:
    # 'dec_state' now shall contain the key/value cache for prefix tokens
    # (for transformer models), and 'logits' we can either discard or
    # roll into the initial hyp_score. Discard is simpler.
    with tf.name_scope('prefix_fprop'):
      # TODO(krikun): remove extra type checks
      assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
      assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
      assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
      assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
      assert (t.dtype == tf.int32), (t.dtype)
      logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                       pfx_mask, dec_state, t)
      del logits

    # Now construct the initial state for the rest of the beam search loop.
    # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
    # 'tgt_pos' is different for each batch row and is equal to prefix_len
    # 'tgt_segment_id' always 1 (no packing)
    # 'hyp_score' is 0 for beam=0 and negative for beam>=1
    tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        pfx_last, 1)
    tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        (pfx_len - 1), 1)
    hyp_score = tf.zeros(
        [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
            tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

    # TODO(krikun) Here we make initial 't' constant and determined by the
    # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
    # as t ~ max(pfx_len) / beam_size, which would allow more steps for beam
    # search; however, 'max' results in a very slow all-to-all on 16x16, and
    # a variable number of decoder steps may result in bad latency.
    t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

    # Initial tgt_mask is such that each token P* has attention on itself
    # (as usual) and on all prefix tokens before it, which are not padding.
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.cast(
        tf.expand_dims(
            tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
        fprop_dtype)
    tgt_mask += tf.one_hot(
        tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    if debug:
      tpu_summary.tensor('tgt_id', tgt_id)
      tpu_summary.tensor('tgt_pos', tgt_pos)
      tpu_summary.tensor('tgt_mask', tgt_mask)
      tpu_summary.tensor('t', t)

  with tf.name_scope('init_hist'):
    # h_tgt_id is used to recover topk_ids from nbest_mask
    h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
    h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

    # When non-trivial prefix is present we also write prefix ids to
    # h_tgt_id so that the full sequence including prefix can be recovered
    # by unmask() below. When prefix is empty, pfx_id shape is [batch, 0]
    # and the loop below becomes a no-op.
    # TODO(krikun): maybe a tf.while_loop is more appropriate here.
    for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
      h_tgt_id = h_tgt_id.write(i, x_i)
    for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
      h_tgt_pos = h_tgt_pos.write(i, x_i)

    hist = (h_tgt_id, h_tgt_pos)
    tf.logging.info('hist=%r', hist)

  nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
  tf.logging.info('nbest_hyps=%r', nbest_hyps)

  ext = (ext_id, ext_score, ext_mask)
  tf.logging.info('ext=%r', ext)

  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
  tf.logging.info('loop_vars=%r', loop_vars)

  def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
    (ext_id, ext_score, ext_mask) = ext
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
    h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
    # not using tf.ones() here because of XLA compilation error
    tgt_segment_id = tgt_id * 0 + 1
    logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask,
                                     dec_state, t)
    # take predicted EOS score for each hyp and compute normalized score
    eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

    def length_norm(t):
      t = tf.cast(t, fprop_dtype)
      alpha = length_norm_alpha
      tf.logging.info('length_norm.alpha=%r', alpha)
      return tf.math.pow((t + 5.) / 5., alpha)

    hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
    eos_score_norm = eos_score / length_norm(hyp_len)
    # update the n-best list
    nbest_hyps = update_nbest(nbest_hyps, (tgt_mask, hyp_score, eos_score_norm))

    if debug:
      tpu_summary.tensor('eos_score', eos_score)
      tpu_summary.tensor('hyp_len', hyp_len)

    # take top k tokens for each hyp
    k = beam_size
    with tf.name_scope('topk1'):
      top_score, top_id = top_k_fn(logits, k)
      top_score = tf.cast(top_score, fprop_dtype)

    top_score += tf.expand_dims(hyp_score, -1)
    # EOS extensions are handled by the n-best update above; exclude them.
    top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

    top_score = tf.reshape(top_score, [batch_size, beam_size * k])
    top_id = tf.reshape(top_id, [batch_size, beam_size * k])
    top_mask = tf.repeat(tgt_mask, beam_size, 1)

    if debug:
      tpu_summary.tensor('top_id', top_id)
      tpu_summary.tensor('top_score', top_score)
      # tpu_summary.tensor('top_mask', top_mask)

    with tf.name_scope('update_ext'):
      # combine top k tokens with extension buffer (if any)
      if ext_size:
        ext_id = tf.concat([ext_id, top_id], 1)
        ext_score = tf.concat([ext_score, top_score], 1)
        ext_mask = tf.concat([ext_mask, top_mask], 1)
      else:
        ext_id, ext_score, ext_mask = top_id, top_score, top_mask

      # sort by score
      ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
      i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
      ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
      ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

      # pick top beam_size extensions to evaluate at next iteration
      if ext_size:
        hyp_score = ext_score[:, :beam_size]
        ext_score = ext_score[:, beam_size:]
        tgt_id = ext_id[:, :beam_size]
        ext_id = ext_id[:, beam_size:]
        tgt_mask = ext_mask[:, :beam_size]
        ext_mask = ext_mask[:, beam_size:]
      else:
        hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
        ext_score = ext_id = ext_mask = 0

    # A hyp's position is the number of buffer slots it attends to.
    tgt_pos = tf.reduce_sum(tgt_mask, -1)
    tgt_pos = tf.cast(tgt_pos, tf.int32)

    t += 1
    with tf.name_scope('tgt_mask_extend'):
      tgt_mask += tf.one_hot(
          tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    ext = (ext_id, ext_score, ext_mask)
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    return loop_vars, dec_state

  def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    if beam_gap is None:
      (t, _, _, _, _, _, _, _) = loop_vars
      return t < max_steps
    else:
      # Use the *current* hyp_score from loop_vars. The enclosing-scope
      # hyp_score is the pre-loop initial value, which would make the
      # early-stopping test compare against stale scores.
      (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
      (_, nbest_score, _) = nbest_hyps
      # stop early if all current hyps are significantly worse than nbest
      diff = tf.reduce_min(
          tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
      return tf.math.logical_and(t < max_steps, diff < beam_gap)

  with tf.name_scope('flat_beam_search_loop'):
    (loop_vars, dec_state) = tf.while_loop(
        loop_cond,
        loop_step,
        loop_vars=(loop_vars, dec_state),
        back_prop=False,
        swap_memory=False,
        maximum_iterations=max_steps)

  # flatten all tensorarrays into tensors
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
  (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.stack()
  h_tgt_pos = h_tgt_pos.stack()
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)

  # recover topk_ids from nbest_mask and tgt_id history
  h = tf.transpose(h_tgt_id, [1, 0, 2])
  h = tf.reshape(h, [batch_size, buf_size])

  def unmask(h, m):
    # Compacts the buffer tokens selected by mask m into contiguous
    # positions 0..len-1 of an output_len-long sequence.
    with tf.name_scope('unmask'):
      if debug:
        tpu_summary.tensor('unmask_h', h)
        tpu_summary.tensor('unmask_m', m)
      t = tf.cumsum(m, -1) * m - 1
      mh = einsum_i32('bkt,bt->bkt', m, h)
      t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
      x = einsum_i32('bkt,bktT->bkT', mh, t2)
      return tf.cast(x, h.dtype)

  topk_ids = unmask(h, nbest_mask)
  topk_len = tf.reduce_sum(nbest_mask, -1)
  topk_len = tf.cast(topk_len, tf.int32)
  # add eos, because nbest_mask does not encode eos
  topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
  topk_len += 1
  topk_len = tf.minimum(topk_len, output_len)
  topk_score = nbest_score_norm

  nbest = (topk_ids, topk_len, topk_score)
  return loop_vars, dec_state, nbest