def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Writes `token` right after the last valid position of each row of `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: If True, `x` is first grown by one column along the length
      dimension. This is needed whenever some row already holds `x_len_max`
      valid tokens; otherwise an invalid sequence is emitted for that row.

  Returns:
    A tuple `(new_x, new_paddings)`. Both have the shape of `x` after the
    optional one-column extension.
  """
  num_rows = py_utils.GetShape(x)[0]
  valid_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  out_shape = py_utils.GetShape(x)

  # Zero out everything past the valid length so that the scatter-add below
  # lands on a clean slot.
  x = x * tf.sequence_mask(valid_len, out_shape[1], x.dtype)

  # One (row, valid_len) write position per example.
  write_at = tf.stack([tf.range(num_rows), valid_len], axis=1)
  token_column = tf.cast(tf.fill([num_rows], token), x.dtype)
  x = x + tf.scatter_nd(write_at, token_column, out_shape)

  # The appended token makes every sequence one position longer.
  new_paddings = 1 - tf.sequence_mask(valid_len + 1, out_shape[1],
                                      x_paddings.dtype)
  return x, new_paddings
def ComputeConvOutputPadding(paddings,
                             window,
                             stride,
                             padding_algorithm='SAME'):
  """Derives the paddings of a strided convolution/pooling output.

  An output position is padded (1) iff at least one of the input positions
  pooled into it is padded.

  Args:
    paddings: The paddings tensor. It is expected to be of shape [batch, time].
    window: The size of the windows.
    stride: The time-stride between adjacent windows.
    padding_algorithm: 'SAME' or 'VALID'.

  Returns:
    out_padding, the new padding tensor of size [batch, ceil(time / stride)].
    With `stride == 1` the input paddings are returned untouched.
  """
  if stride == 1:
    return paddings

  # Round the time length up to the next multiple of `stride`; the extra
  # frames are marked as padded (1.0).
  time_len = py_utils.GetShape(paddings)[1]
  rounded_len = (time_len + stride - 1) // stride * stride
  extended = tf.pad(
      paddings, [[0, 0], [0, rounded_len - time_len]], constant_values=1.0)

  # Max-pooling propagates a 1 from any input frame to its output frame.
  pooled = tf.nn.pool(
      tf.expand_dims(extended, -1),
      [window],
      'MAX',
      padding=padding_algorithm,
      strides=[stride],
  )
  return tf.squeeze(pooled, -1)
def AddMultiCurveSubplot(fig,
                         tensors,
                         paddings,
                         labels,
                         xlabels=None,
                         **kwargs):
  """Adds a multi curve subplot to Matplotlib figure.

  Plots one line for each non-None entry in `tensors` and assigns a plot label
  legend. Entries of `tensors` that are None are dropped together with their
  label.

  Args:
    fig: The Matplotlib figure.
    tensors: List of tensors of shape [batch, length]; entries may be None.
    paddings: Paddings for 'tensors' with shape [batch, length] with 0. in valid
      positions and 1. in invalid.
    labels: A list of tensor names (strings) of the same length as 'tensors'.
    xlabels: A string tensor of shape [batch] with an xlabel per batch.
    **kwargs: With optional, title, xlabel, ylabel, fontsize.

  Raises:
    ValueError: If `tensors` contains no non-None entry, i.e. there is nothing
      to plot.
  """
  present = [(t, l) for t, l in zip(tensors, labels) if t is not None]
  if not present:
    # Fail with a clear message instead of an IndexError on `data[0]` below.
    raise ValueError(
        'AddMultiCurveSubplot needs at least one non-None tensor to plot.')

  data = [py_utils.ApplyPadding(paddings, t) for t, _ in present]
  row_labels = [l for _, l in present]

  # Concatenating on the last axis and reshaping yields [batch, curve, length].
  shape = py_utils.GetShape(data[0], 2)
  data = tf.reshape(tf.concat(data, -1), [shape[0], len(data), shape[1]])

  args = [data, py_utils.LengthsFromPaddings(paddings)]
  if xlabels is not None:
    args.append(xlabels)
  fig.AddSubplot(
      args, plot_func=_AddMultiCurveRowPlots, row_labels=row_labels, **kwargs)
def _RelPositionBias(query, abs_pos_emb):
  """Relative position bias for the general (non-causal) case.

  Args:
    query: Tensor whose shape is unpacked as [B, T, N, H].
    abs_pos_emb: Tensor checked to have shape [2T - 1, N, H], ordered by
      relative distance -(T-1) ... T-1.

  Returns:
    A [B, N, T, T] tensor stitched from two relatively-shifted halves.
  """
  _, seq_len, num_heads, head_dim = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb,
                                  [2 * seq_len - 1, num_heads, head_dim])

  # Flip the distance order from [-(T-1) ... T-1] to [T-1 ... -(T-1)].
  flipped_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  logits = tf.einsum('BTNH,LNH->BNTL', query, flipped_emb)

  # First half (first T columns): reverse, shift, reverse back. [B, N, T, T]
  left = tf.reverse(logits[:, :, :, :seq_len], [2, 3])
  left = tf.reverse(RelShift(left), [2, 3])

  # Second half (last T columns): shift directly. [B, N, T, T]
  right = RelShift(logits[:, :, :, seq_len - 1:])

  # Take the lower triangle (diagonal included) from `left`, the rest from
  # `right`.
  lower_triangle = tf.linalg.band_part(tf.ones_like(right), -1, 0)
  return tf.where(lower_triangle > 0, left, right)
def FProp(self, theta, inputs, paddings, domain_ids=None):
  """Augments `inputs` by randomly masking the spectrum, per domain.

  With more than one entry in `p.domain_ids`, the augmentation network is run
  once per domain and each result is kept only where `domain_ids` equals that
  domain; positions that match no configured domain keep the original input.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: A tensor of shape [batch, time, freq, num_channels].
    paddings: A 0/1 tensor of shape [batch, time].
    domain_ids: input domain_ids of shape [batch, time].

  Returns:
    A pair of 2 tensors:

    - augmented_inputs: A tensor of shape [batch, time, freq, num_channels].
    - paddings: A 0/1 tensor of shape [batch, time].
  """
  p = self.params

  # A tensor seed, only needed when stateless random ops are in use.
  global_seed = (
      _global_seed_from_inputs(inputs)
      if p.use_input_dependent_random_seed else None)

  batch_size, series_length, _, _ = py_utils.GetShape(inputs)

  if len(p.domain_ids) <= 1:
    single_domain = self._AugmentationNetwork(
        series_length,
        inputs,
        paddings,
        global_seed=global_seed,
        domain_id_index=0)
    return single_domain, paddings

  augmented_sum = tf.zeros_like(inputs)
  passthrough = inputs
  for index, domain_id in enumerate(p.domain_ids):
    domain_augmented = self._AugmentationNetwork(
        series_length,
        inputs,
        paddings,
        global_seed=global_seed,
        domain_id_index=index)
    # [batch, 1] column holding this domain's id.
    target_domain = tf.cast(
        tf.expand_dims(tf.tile([domain_id], [batch_size]), -1), dtype=p.dtype)
    # 1 where the example/frame belongs to this domain, else 0.
    domain_mask = tf.cast(tf.equal(domain_ids, target_domain), dtype=p.dtype)
    domain_augmented = self.EinsumBxycBxBxyc(
        domain_augmented, domain_mask, name='einsum_domainmasking')
    # Progressively remove from the pass-through copy whatever a domain claims.
    passthrough = self.EinsumBxycBxBxyc(
        passthrough, 1.0 - domain_mask, name='einsum_domainmasking2')
    augmented_sum = domain_augmented + augmented_sum

  return passthrough + augmented_sum, paddings
def UnstackFeatures(self, src_inputs, src_paddings):
  """Undoes frame stacking on `src_inputs` and rebuilds matching paddings.

  The time axis is multiplied by `params.stack_height`, and each sequence
  length is scaled by the same factor.

  Args:
    src_inputs: A 4-D tensor; its dims are unpacked as [batch, time, _,
      channels].
    src_paddings: Paddings for `src_inputs`, reduced over axis 1 to get the
      per-example valid length.

  Returns:
    A pair (unstacked_inputs, unstacked_paddings); the paddings are int32.
  """
  stack_height = self.params.stack_height
  batch, stacked_len, _, channels = py_utils.GetShape(src_inputs)
  full_len = stacked_len * stack_height

  unstacked = tf.reshape(src_inputs, [batch, full_len, -1, channels])

  # Every valid stacked frame becomes `stack_height` valid frames.
  valid_frames = tf.reduce_sum(1 - src_paddings, axis=1)
  lengths = tf.cast(stack_height * valid_frames, tf.int32)
  valid_mask = tf.sequence_mask(lengths, maxlen=full_len)
  return unstacked, 1 - tf.cast(valid_mask, tf.int32)
def _TimeWarp(self,
              inputs,
              seq_lengths,
              global_seed,
              dtype=tf.float32,
              domain_id_index=0):
  """Warps `inputs` along time with a randomly drawn warp matrix.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    seq_lengths: The actual sequence lengths, of shape (batch_size,); passed
      to the warp-matrix generator as the range to choose from.
    global_seed: an integer seed tensor for stateless random ops.
    dtype: Data type.
    domain_id_index: Index selecting the per-domain warp parameters.

  Returns:
    Inputs with random time warping applied, or `inputs` unchanged when
    warping is disabled for this domain.
  """
  p = self.params
  batch_size, time_length, _, _ = py_utils.GetShape(inputs)

  # Per-domain warp configuration.
  max_frames = p.time_warp_max_frames[domain_id_index]
  max_ratio = p.time_warp_max_ratio[domain_id_index]
  bound = p.time_warp_bound[domain_id_index]
  assert bound in ('static', 'dynamic')

  # Nothing to do if the static frame bound is zero or the ratio is not
  # positive.
  static_and_empty = max_frames == 0 and bound == 'static'
  if static_and_empty or max_ratio <= 0.0:
    return inputs

  seq_lengths = tf.cast(seq_lengths, tf.int32)

  # With a dynamic bound there is no upper limit on the warp frames.
  frame_limit = None if bound == 'dynamic' else max_frames

  # Build the warping matrix along time and apply it.
  warp_matrix = self._GetWarpMatrix(
      batch_size,
      choose_range=seq_lengths,
      matrix_size=time_length,
      global_seed=global_seed,
      max_warp_frames=frame_limit,
      dtype=dtype,
      max_ratio=max_ratio)
  return self.EinsumBxycBzxBzyc(inputs, warp_matrix, name='einsum_forwarping')
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Length-aware concatenation of sequence `y` onto sequence `x`.

  For each row, `y` is written starting at the first padded position of `x`
  (as given by the paddings), not at column `x_len_max`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
    - Concatenation of `x` and `y` of shape
      [batch_size, x_len_max + y_len_max].
    - Paddings of the concatenation of shape
      [batch_size, x_len_max + y_len_max].
  """
  # Valid lengths (w/ eos) of both inputs.
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)
  total_len = x_len + y_len

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Make room for `y` by appending a <pad>-filled block to `x`.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  out_shape = py_utils.GetShape(x)

  # Turn every <pad> into 0 so that `y` can simply be added in.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(out_shape, 0))

  # Scatter positions for `y`: row b, columns x_len[b] + [0, y_len_max).
  row_ids = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max])
  col_ids = (
      tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
      tf.expand_dims(x_len, 1))
  indices = tf.stack([row_ids, col_ids], 2)
  xy = x + tf.scatter_nd(indices, y, out_shape)

  # Everything at or beyond the combined length goes back to `pad`.
  positions = tf.expand_dims(tf.range(out_shape[1]), 0)
  xy = tf.where(
      tf.less(positions, tf.expand_dims(total_len, 1)), xy,
      tf.fill(out_shape, pad))

  xy_paddings = 1 - tf.sequence_mask(total_len, out_shape[1],
                                     x_paddings.dtype)
  return xy, xy_paddings
def FProp(self, theta, inputs, paddings):
  """Pools `inputs` globally over the time and frequency axes.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor, expected to be of shape [batch, time], or
      None when there are no paddings.

  Returns:
    outputs, out_paddings pair.
    - outputs: has shape [batch, 1, 1, channel].
    - out_paddings: None or has shape [batch, 1].
  """
  p = self.params
  assert p.pooling_type in ['MAX', 'AVG'], p.pooling_type
  b, t, f = py_utils.GetShape(inputs, ndims=3)

  # mask is 1 on valid frames and broadcasts over frequency and channel.
  if paddings is None:
    mask = tf.ones([b, t, 1, 1], p.dtype)
  else:
    paddings = py_utils.HasShape(paddings, [b, t])
    mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]

  if p.pooling_type == 'AVG':
    masked_sum = tf.reduce_sum(inputs * mask, axis=[1, 2], keepdims=True)
    num_freq = tf.cast(tf.convert_to_tensor(f), p.dtype)
    # Number of contributing cells: valid frames times frequency bins.
    num_cells = num_freq * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
    # Clamp to 1 to avoid dividing by zero for fully padded examples.
    out_feature = masked_sum / tf.maximum(1.0, num_cells)
  elif p.pooling_type == 'MAX':
    # Padded cells are replaced by a very negative value so they never win.
    very_negative = (
        tf.ones_like(inputs) * p.dtype.max * tf.constant(-0.7, dtype=p.dtype))
    masked_inputs = tf.where_v2(mask > 0.0, inputs, very_negative)
    out_feature = tf.reduce_max(masked_inputs, axis=[1, 2], keepdims=True)

  if paddings is None:
    return out_feature, None

  # An example is padded in the output only if all of its frames are padded;
  # such examples get a zero feature.
  out_paddings = tf.reduce_min(paddings, axis=1, keepdims=True)
  out_feature *= 1.0 - out_paddings[..., tf.newaxis, tf.newaxis]
  return out_feature, out_paddings
def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
  """Self-attention logits with relative position embedding.

  Transformer-XL (https://arxiv.org/pdf/1901.02860.pdf, section 3.3) version
  of self attention. Padding is not handled here; the caller is expected to
  mask it.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    query: [B, T, N, H]
    key: [B, T, N, H]
    abs_pos_emb: [2T - 1, N, H]. The sinusoid positional embedding from
      https://arxiv.org/abs/1706.03762. abs_pos_emb[i] is the emb of relative
      distance i - (T-1).
    content_bias: [N, H] or None (treated as 0).
    positional_bias: [N, H] or None (treated as 0).
    is_causal: A Python bool or a scalar bool Tensor. True for causal self
      attention.

  Returns:
    The attention logits tensor. [B, N, T, T]
  """
  b, t, n, h = py_utils.GetShape(query)
  key = py_utils.HasShape(key, [b, t, n, h])

  content_bias = (0 if content_bias is None else py_utils.HasShape(
      content_bias, [n, h]))
  positional_bias = (0 if positional_bias is None else py_utils.HasShape(
      positional_bias, [n, h]))

  # Content term. [B, N, T, S=T]
  content_logits = tf.einsum('BTNH,BSNH->BNTS', query + content_bias, key)
  # Position term.
  position_logits = RelPositionBias(query + positional_bias, abs_pos_emb,
                                    is_causal)
  return content_logits + position_logits
def ComputeLoss(self, theta, predicted, input_batch):
  """Squared-error loss between `predicted` and `input_batch.tgt_ids`.

  Args:
    theta: Unused here.
    predicted: The predictions, combinable elementwise with
      `input_batch.tgt_ids`.
    input_batch: Batch providing `tgt_ids` and `src_ids`.

  Returns:
    A pair (metrics, per_example_tensors):
    - metrics: {'loss': (summed squared error, batch size)}.
    - per_example_tensors: dict with 'input', 'loss', 'diff' and the private
      variables 'm' and 'b' repeated once per example.
  """
  diff = predicted - input_batch.tgt_ids
  per_example_loss = diff * diff
  batch_dim = py_utils.GetShape(per_example_loss)[0]

  metrics = {'loss': (tf.reduce_sum(per_example_loss), batch_dim)}

  per_example_tensors = {
      'input': input_batch.src_ids,
      'loss': per_example_loss,
      'diff': diff,
  }
  # Repeat each private variable `batch_dim` times so it is per-example too.
  for var_name in ('m', 'b'):
    per_example_tensors[var_name] = tf.convert_to_tensor(
        [self._private_vars[var_name]] * batch_dim, dtype=tf.float32)
  return metrics, per_example_tensors
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Relative position bias for causal self attention.

  Args:
    query: Tensor whose shape is unpacked as [B, T, N, H].
    abs_pos_emb: Tensor checked to have shape [2T - 1, N, H], ordered by
      relative distance -(T-1) ... T-1.

  Returns:
    The [B, N, T, T] bias after relative shifting.
  """
  _, seq_len, num_heads, head_dim = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb,
                                  [2 * seq_len - 1, num_heads, head_dim])

  # Flip to [T-1 ... -(T-1)] and keep only the first T entries, i.e. the
  # distances [T-1, T-2, ... 0]. [T, N, H]
  causal_emb = tf.reverse(abs_pos_emb, [0])[:seq_len]

  # [B, N, T, L=T]
  logits = tf.einsum('BTNH,LNH->BNTL', query, causal_emb)

  # Shift in the reversed layout, then restore the original orientation.
  shifted = RelShift(tf.reverse(logits, [2, 3]))
  return tf.reverse(shifted, [2, 3])
def _Slice(tensor):
  """Extracts the entry of `tensor` at time index `state0.t`.

  The time axis is `self.params.axis`; it is removed from the result.
  """
  time_axis = self.params.axis
  dims = py_utils.GetShape(tensor)

  # Start offsets: `state0.t` on the time axis, 0 everywhere else.
  # e.g. for axis=1 this is [0, t, 0, 0, ...].
  begin = tf.one_hot(time_axis, tf.rank(tensor), on_value=state0.t)

  # Extent: the full shape, except a single step along the time axis.
  # e.g. for axis=1 this is [dims[0], 1, dims[2], dims[3], ...].
  size = tf.concat([
      dims[0:time_axis],
      tf.constant([1], dtype=tf.int32),
      dims[time_axis + 1:],
  ],
                   axis=0)

  # Cut out the single time step and drop the now size-1 time axis.
  return tf.squeeze(tf.slice(tensor, begin, size), axis=time_axis)
def RelShift(x):
  """Relative shift over the last two axes of a 4D tensor.

  The first two axes are batching dims. For an input of shape [?, ?, W, W]:
    output[b, n, i, j] = 0 if i > j else input[b, n, i, j-i]

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as the input, with its content shifted as
    described above.
  """
  batch, heads, width, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, width, width])

  # Append one zero column, then reinterpret the W x (W+1) buffer as
  # (W+1) x W; this pushes row i to the right by i positions.
  widened = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  skewed = tf.reshape(widened, [batch, heads, width + 1, width])

  # Drop the extra trailing row to get back to [?, ?, W, W].
  return skewed[:, :, :width, :]
def SequenceTrimLastToken(x, x_paddings):
  """Trims the last token off of sequence `x`, and set trimmed elements to 0.

  The sequence length is derived from the paddings as a rounded int32 value,
  the same way `SequenceAppendToken` and `SequenceConcat` do it, so that
  `tf.sequence_mask` always receives integer lengths even when `x_paddings` is
  a float tensor. Empty sequences stay empty.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.

  Returns:
    A tuple.
    - The new sequence, Tensor of shape [batch_size, x_len_max].
    - The new paddings, Tensor of shape [batch_size, x_len_max], with the dtype
      of `x_paddings`.
  """
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  x_len_max = py_utils.GetShape(x)[1]

  # Drop one token per row, never going below zero.
  x_trimmed_len = tf.maximum(x_len - 1, 0)

  # 1 on positions that are kept, 0 on positions that are trimmed or padded.
  keep_mask = tf.sequence_mask(x_trimmed_len, x_len_max, x_paddings.dtype)
  x_trimmed = x * tf.cast(keep_mask, x.dtype)
  return x_trimmed, 1 - keep_mask
def _FrequencyMask(self,
                   inputs,
                   global_seed,
                   dtype=tf.float32,
                   domain_id_index=0):
  """Masks random frequency bands of `inputs`.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    global_seed: an integer seed tensor for stateless random ops.
    dtype: Data type.
    domain_id_index: Index selecting the per-domain mask parameters.

  Returns:
    Inputs with random frequency masking applied, or `inputs` unchanged when
    the configured mask size or mask count is zero.
  """
  p = self.params

  # Per-domain mask configuration.
  max_bins = p.freq_mask_max_bins[domain_id_index]
  mask_count = p.freq_mask_count[domain_id_index]

  # Zero width or zero masks means there is nothing to mask.
  if max_bins == 0 or mask_count == 0:
    return inputs

  # Every example may place its masks anywhere in [0, num_freq).
  batch_size, _, num_freq, _ = py_utils.GetShape(inputs)
  choose_range = tf.cast(
      tf.broadcast_to(num_freq, (batch_size,)), dtype=tf.int32)

  # Generate the masks along frequency and apply them.
  freq_masks = self._GetMask(
      tf.shape(inputs)[0],
      choose_range=choose_range,
      mask_size=num_freq,
      global_seed=global_seed,
      max_length=max_bins,
      masks_per_frame=0.0,
      multiplicity=mask_count,
      dtype=dtype,
      max_ratio=1.0)
  return self.EinsumBxycByBxyc(inputs, freq_masks)
def PrepareSequenceForPlot(tensor, padding, name):
  """Reshapes a sequence feature into the layout used for plotting.

  All dimensions after time are flattened into one, and the result is
  transposed so that time is the last axis.

  Args:
    tensor: A n-D Tensor of shape [batch, time, ...].
    padding: A Tensor of shape [batch, time].
    name: A string as the name of the reshaped Tensor, which will be used as
      the subcaption for plotting.

  Returns:
    A tuple of:
      reshaped_tensor: A 3-D Tensor of shape [batch, dim, time].
      sequence_length: A 1-D Tensor of shape [batch].
  """
  batch_size, max_len = py_utils.GetShape(tensor, 2)
  # Collapse everything past the time axis into a single feature axis.
  flattened = tf.reshape(tensor, [batch_size, max_len, -1])
  # [batch, time, dim] -> [batch, dim, time]
  time_last = tf.transpose(flattened, [0, 2, 1], name=name)
  return (time_last, SequenceLength(padding))
def ConvertToBlocks(x, block_size, padding_val=0.0):
  """Turns a sequence to non overlapping blocks.

  `block_size` is validated before `x` is touched, so an invalid value fails
  immediately without any shape ops having been created.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    padding_val: float. value on the padded frames.

  Returns:
    A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
    where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].

  Raises:
    ValueError: If `block_size` is smaller than 1.
  """
  if block_size < 1:
    raise ValueError(
        'block_size must be at least 1, got {}'.format(block_size))

  shape = py_utils.GetShape(x)
  b, t = shape[:2]

  # Pad the time axis up to the next multiple of block_size.
  num_blocks = (t + block_size - 1) // block_size
  pad_to_length = num_blocks * block_size
  padded = py_utils.PadSequenceDimension(x, pad_to_length, padding_val)

  # Split time into [num_blocks, block_size]; trailing dims are untouched.
  return tf.reshape(padded, [b, num_blocks, block_size] + shape[2:])
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Builds the initial state for this step.

  The attention layer is queried once with an all-zero query and a zero
  attention state; the resulting context and state form the initial state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    prepared_inputs: A set of inputs pre-processed by using
      PrepareExternalInputs.
    batch_size: Number of elements in the batched input.

  Returns:
    state0, a state parameter to pass to FProp on its first invocation.
  """
  source_len = py_utils.GetShape(prepared_inputs.src, 3)[0]
  zero_atten_state = self.atten.ZeroAttentionState(source_len, batch_size)

  zero_query = tf.zeros([batch_size, self.params.atten.query_dim],
                        dtype=py_utils.FPropDtype(self.params))
  context, _, atten_state = self.atten.ComputeContextVectorWithSource(
      theta.atten,
      prepared_inputs.packed_src,
      zero_query,
      attention_state=zero_atten_state)
  return py_utils.NestedMap(atten_context=context, atten_state=atten_state)
def _CreateCanvasAndTargets(self, batch):
  # pyformat: disable
  """Builds the joint src+tgt canvas and its insertion targets.

  Args:
    batch: A `.NestedMap`.
      - src: A `.NestedMap`.
        - ids: The source ids, ends in <eos>.
        - paddings: The source paddings.
      - tgt: A `.NestedMap`.
        - ids: The target ids, ends in <eos>.
        - paddings: The target paddings.

  Returns:
    None when `self.do_eval` is set. Otherwise a `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim].
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices (i.e., use these indices to
        tf.gather_nd the log-probs).
      - target_weights: The target weights.
  """
  # pyformat: enable
  p = self.params

  if self.do_eval:
    return None

  # Sample a canvas (and its targets) for each side independently.
  src = self._SampleCanvasAndTargets(batch.src.ids, batch.src.paddings)
  tgt = self._SampleCanvasAndTargets(batch.tgt.ids, batch.tgt.paddings)

  # Offset the src ids (to unshare embeddings between src/tgt). Only the
  # canvas ids are offset, not the vocab ids, which gives unshared embeddings
  # but a shared softmax. This is due to GPU/TPU memory limitations;
  # empirically, unsharing everything is known to perform better.
  vocab_size = p.decoder.softmax.num_classes
  src.canvas = tf.where(
      tf.equal(src.canvas_paddings, 0), src.canvas + vocab_size, src.canvas)

  # The tgt slots must be shifted by the length of the src canvas.
  batch_size = py_utils.GetShape(batch.src.ids)[0]
  # `target_batch` is a [num_targets, batch_size] one-hot tensor: row i marks
  # the example that target i belongs to, so every row sums to 1.
  target_batch = tf.cast(
      tf.equal(
          tf.expand_dims(tf.range(batch_size), 0),
          tf.expand_dims(tgt.target_indices[:, 0], 1)), tf.int32)
  src_lens = tf.cast(tf.reduce_sum(1 - src.canvas_paddings, 1), tf.int32)
  # One offset per target: the src canvas length of the target's example.
  tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
  # Only the slot column moves; the batch and vocab columns are left alone.
  no_offset = tf.zeros_like(tgt_offset)
  tgt.target_indices += tf.concat(
      [no_offset, tgt_offset, tf.zeros_like(tgt_offset)], 1)

  # The canvas is the sequence-level concat of the src and tgt canvases.
  canvas, canvas_paddings = insertion.SequenceConcat(
      src.canvas, src.canvas_paddings, tgt.canvas, tgt.canvas_paddings)
  target_indices = tf.concat([src.target_indices, tgt.target_indices], 0)
  target_weights = tf.concat([src.target_weights, tgt.target_weights], 0)

  return py_utils.NestedMap(
      canvas=canvas,
      canvas_paddings=canvas_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def _TimeMask(self,
              inputs,
              seq_lengths,
              global_seed,
              noisify=False,
              gaussian_noise=False,
              dtype=tf.float32,
              domain_id_index=0):
  """Applies time masking with given degree to inputs.

  The inputs are returned unchanged when masking is disabled for the selected
  domain: a zero maximum mask length without dynamic sizing, a non-positive
  maximum ratio, or a zero mask count.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    seq_lengths: The actual sequence lengths, of shape (batch_size,); passed
      to the mask generator as the range to choose mask positions from.
    global_seed: an integer seed tensor for stateless random ops.
    noisify: Whether to noisify the masked out regions.
    gaussian_noise: Whether to use gaussian noise (stddev 1.0) when
      noisifying; otherwise the stddev is drawn at random.
    dtype: Data type.
    domain_id_index: Index selecting the per-domain mask parameters.

  Returns:
    Inputs with random time masking applied.
  """
  p = self.params

  # Get time masking parameters.
  time_mask_max_frames = p.time_mask_max_frames[domain_id_index]
  time_masks_per_frame = p.time_masks_per_frame[domain_id_index]
  use_dynamic_time_mask_max_frames = \
      p.use_dynamic_time_mask_max_frames[domain_id_index]
  multiplicity = p.time_mask_count[domain_id_index]
  max_ratio = p.time_mask_max_ratio[domain_id_index]

  # If maximum mask length is zero, do nothing.
  if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or
      max_ratio <= 0.0):
    return inputs
  # Likewise when no masks are requested at all.
  if multiplicity == 0:
    return inputs
  seq_lengths = tf.cast(seq_lengths, tf.int32)
  batch_size, time_length, _, _ = py_utils.GetShape(inputs)

  # When using dynamic time mask size, discard upper-bound on
  # maximum allowed frames for time mask.
  if use_dynamic_time_mask_max_frames:
    time_mask_max_frames = None
  # Create masks in time direction and apply.
  block_arrays = self._GetMask(
      batch_size,
      choose_range=seq_lengths,
      mask_size=time_length,
      global_seed=global_seed,
      max_length=time_mask_max_frames,
      masks_per_frame=time_masks_per_frame,
      multiplicity=multiplicity,
      dtype=dtype,
      max_ratio=max_ratio)

  # Non-empty random seed values are only used for testing or when using
  # stateless random ops. With input-dependent seeding, seed_6 and seed_7 are
  # distinct offsets of global_seed so the two draws below are not correlated;
  # otherwise both fall back to p.random_seed.
  if p.use_input_dependent_random_seed:
    seed_6 = global_seed + 6
    seed_7 = global_seed + 7
  else:
    seed_6 = p.random_seed
    seed_7 = p.random_seed

  # Multiply the inputs by the mask along the time axis.
  outputs = self.EinsumBxycBxBxyc(
      inputs, block_arrays, name='einsum_formasking')
  if noisify:
    # Sample noise with standard deviation with factor * 0.1 + 0.0001
    # TODO(ngyuzh): Make sure this won't affect EOS.
    if gaussian_noise:
      stddev = 1.0
    else:
      random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
      factor = random_uniform(
          shape=(), minval=1.0, maxval=2.0, dtype=dtype, seed=seed_6)
      stddev = factor * 0.1 + 0.0001
    random_normal = _random_normal_op(p.use_input_dependent_random_seed)
    # The noise has no channel axis; it is expanded over channels below.
    noise = random_normal(
        shape=[tf.shape(inputs)[0],
               tf.shape(inputs)[1],
               tf.shape(inputs)[2]],
        stddev=stddev,
        seed=seed_7)
    if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
      noise = tf.cast(noise, p.fprop_dtype)
    # Keep the noise only where the mask removed the signal
    # (1.0 - block_arrays), then add it to the masked inputs.
    outputs_mask = self.EinsumBxyBxBxy(
        noise, 1.0 - block_arrays, name='einsum_fornoisymasking')
    outputs = outputs + tf.expand_dims(outputs_mask, -1)

  return outputs
def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
  """Runs the LSTM forward over a whole sequence.

  The input projection is applied to the full sequence up front, and the
  recurrence then runs on the projected inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: A single tensor or a tuple of tensors with cardinality equal to
      rnn_cell.inputs_arity. For every input tensor, the first dimension is
      assumed to be time, second dimension batch, and third dimension depth.
    paddings: A tensor. First dim is time, second dim is batch, and third dim
      is expected to be 1.
    state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
      the cell's zero-state.
    segment_id: A tensor to support packed inputs. First dim is time, second
      dim is batch, and third dim is expected to be 1. Required when
      `p.packed_input` is set.

  Returns:
    A pair (act, final_state): the cell outputs over time and the final
    recurrent state.
  """
  p = self.params
  rcell = self.cell
  assert isinstance(rcell, (rnn_cell.RNNCell))

  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]

  # Split wm into its input and hidden parts once, outside the recurrence:
  # this gives a 20% speedup over the regular LSTM baseline, whereas slicing
  # inside the loop gives only < 3%.
  cell_theta = theta.cell.copy()
  split_at = p.cell.num_input_nodes
  cell_theta['wm_i'] = cell_theta.wm[:split_at, :]
  cell_theta['wm_h'] = cell_theta.wm[split_at:, :]
  tf.logging.vlog(1, 'cell_theta: %r', cell_theta)

  # The reset mask is derived from the segment ids for packed inputs, and is
  # all zeros otherwise.
  if p.packed_input:
    assert segment_id is not None
    reset_mask = py_utils.HasShape(
        rnn_layers.GeneratePackedInputResetMask(segment_id, is_reverse=False),
        tf.shape(paddings))
  else:
    reset_mask = tf.zeros_like(paddings)

  # A reversed layer simply consumes everything flipped along time.
  if p.reverse:
    inputs = [tf.reverse(x, [0]) for x in inputs]
    paddings = tf.reverse(paddings, [0])
    reset_mask = tf.reverse(reset_mask, [0])

  if not state0:
    state0 = rcell.zero_state(cell_theta, py_utils.GetShape(paddings)[1])

  # [T, B, H]
  projected = rcell.ProjectInputSequence(cell_theta,
                                         py_utils.NestedMap(act=inputs))
  recurrent_inputs = py_utils.NestedMap(
      proj_inputs=projected, padding=paddings, reset_mask=reset_mask)

  acc_state, final_state = recurrent.Recurrent(
      theta=cell_theta,
      state0=state0,
      inputs=recurrent_inputs,
      cell_fn=rcell.FPropWithProjectedInput,
      cell_type=rcell.layer_type,
      accumulator_layer=self,
      allow_implicit_capture=p.allow_implicit_capture)

  act = rcell.GetOutput(acc_state)
  # Flip the outputs back into the original time order.
  if p.reverse:
    act = tf.reverse(act, [0])
  return act, final_state
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  """Takes a tensor of strings and returns id/padding tensors.

  This generates `token_ids`, `target_ids`, and `paddings` in the format that
  is expected for tokenizers. This performs padding to a fixed length and
  appends the end-of-sentence token as appropriate.

  Each string is encoded individually inside a `tf.while_loop`, and the
  per-sentence results are collected in TensorArrays.

  Args:
    strs: a string Tensor.
    max_length: a python integer. The second dimension of the returned arrays.
      All sequences are padded or truncated to that length.
    append_eos: a python bool. See `BaseTokenizer` for explanation. When None,
      `p.append_eos` is used.
    languages: A vector of strings with the same length as `strs`. Not
      referenced in this function body.

  Returns:
    A tuple of 3 tensors:

    - token_ids: a tensor of sequences of WPM ids starting with SOS. Sequences
      always end with EOS unless the sequence exceeds the maximum length.
      Always padded with EOS.
    - target_ids: a tensor of sequences of WPM ids not starting with SOS
      but ending with EOS. Always padded with EOS.
    - paddings: a tensor of floats indicating, at each position, whether
      the corresponding position is padded.

    Unless `p.pad_to_max_length` is set, all three are cut down to the longest
    unpadded length in the batch.
  """
  p = self.params
  if append_eos is None:
    append_eos = p.append_eos

  batch_size = py_utils.GetShape(strs)[0]
  # One TensorArray slot per sentence for each of the three outputs.
  token_ids_ta = tf.TensorArray(tf.int32, batch_size)
  target_ids_ta = tf.TensorArray(tf.int32, batch_size)
  paddings_ta = tf.TensorArray(tf.float32, batch_size)

  def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
    """Tokenizes a single sentence."""
    ids, _ = self._wpm_encoder.Encode(strs[i])

    if append_eos:
      ids = tf.concat([ids, [self.eos_id]], axis=0)

    # This truncates after the eos is added, so some sentences might
    # not have </s> at the end.
    # token_ids: SOS followed by the ids, padded/trimmed with EOS.
    token_ids_ta = token_ids_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.concat([[self.sos_id], ids], axis=0), [max_length],
            self.eos_id))
    # target_ids: the ids alone, padded/trimmed with EOS.
    target_ids_ta = target_ids_ta.write(
        i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
    # paddings: 0 for every id position, 1 for the fill.
    paddings_ta = paddings_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

    return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

  # Loop over the batch; only the TensorArrays are needed afterwards.
  _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
      lambda i, *_: i < batch_size,
      _TokenizeOneSentence,
      loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta, target_ids_ta,
                 paddings_ta),
      parallel_iterations=30,
      back_prop=False)

  token_ids = token_ids_ta.stack()
  target_ids = target_ids_ta.stack()
  paddings = paddings_ta.stack()

  if not p.pad_to_max_length:
    # Trim all outputs to the longest unpadded sequence in the batch.
    maxlen = tf.cast(
        tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
        tf.int32)
    token_ids = token_ids[:, :maxlen]
    target_ids = target_ids[:, :maxlen]
    paddings = paddings[:, :maxlen]

  return token_ids, target_ids, paddings
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random subset
  of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid. When None, every position of `x` is treated
      as valid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.

    - canvas: The canvas (based off of the `rollin_policy`) of shape
      [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
      equal.
    - canvas_indices: The canvas indices (into `x`), sorted in ascending
      order per example.
    - canvas_paddings: The paddings of `canvas_indices`.
    - target_indices: The target indices of shape [num_targets, 3].
      `num_targets` is the number of total targets in the entire batch.
      [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
      captures the token. Each row [batch, slot, vocab] represents the
      indices of the target -- i.e., the batch, slot and vocab combination
      of the target. Typical usage of these indices is to tf.gather_nd the
      log-probs (from the softmax layer).
    - target_weights: The target weights.

  Raises:
    ValueError: If invalid params, i.e. the resolved rollin policy or the
      oracle policy is anything other than 'uniform'.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  # Without explicit paddings, treat the whole of `x` as valid.
  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  # A rollin policy of 'oracle' means "reuse the oracle policy".
  oracle_policy = p.oracle_policy
  rollin_policy = (oracle_policy
                   if p.rollin_policy == 'oracle' else p.rollin_policy)

  # Only the 'uniform' policy is implemented below.
  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' % rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' % oracle_policy)

  # Number of valid tokens per example, [batch_size].
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    # At least one token (the last one) is always observed: c_len >= 1 and
    # c_len <= x_len.
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    # The canvas may be empty here: c_len is capped at x_len and can be 0.
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  # Padded positions (index >= x_len) get a -1e9 logit so they sort last.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by add +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9

  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  # top_k over all `time_dim` entries yields a full ordering of positions by
  # perturbed logit; only the indices are used.
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  # Sorting restores the original token order within each example; the
  # `time_dim - 1` filler entries end up at the tail.
  c_indices = tf.sort(c_indices)
  # Gather x[b, c_indices[b, j]] through flattened (batch, position) pairs.
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
              [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(
      c_len, c_len_max, dtype=x_paddings.dtype)
  # Zero out the canvas entries that lie in the padded tail.
  c *= tf.cast(1 - c_paddings, tf.int32)

  # (batch, position) pairs for every canvas entry, [batch * c_len_max, 2].
  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  # Per-position count of canvas entries that point at it, same shape as `x`.
  # NOTE(review): padded canvas entries all point at `time_dim - 1`, so that
  # position is counted here as well; confirm this is intended for examples
  # whose last position is a valid, unsampled token.
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  # The exclusive cumsum counts the observed entries strictly before each
  # position.
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  # Shift right by one along time; the first position is treated as if it
  # were preceded by an observed token.
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  # Flatten the per-position masks to [batch_size * time_dim].
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  # That is: the token is observed but its predecessor is not, so the slot
  # before it already holds at least one real target token.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(canvas=c, canvas_indices=c_indices, canvas_paddings=c_paddings, target_indices=target_indices, target_weights=target_weights)
def BuildDataSource(self, data_source_from_file_pattern_fn):
  """Builds a weighted mixture of data sources from `p.file_patterns`.

  `p.file_patterns` is a list of file patterns, `p.weights` contains
  weights for each file pattern. If provided `p.bprop_variable_filters`
  includes a bprop_variable_filter for each file pattern; otherwise an empty
  string is used for every source.

  Args:
    data_source_from_file_pattern_fn: a function that takes file_pattern as an
      argument and returns an input batch.

  Returns:
    A NestedMap containing:
      data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor
      source_selected: a tensor of size [batch_size, number of data sources]
      selected_bprop: a tensor of size [number of data sources]
      bprop_variable_filters: containing a list of bprop_variable filters for
        each source

  Raises:
    ValueError: If `p.file_patterns` and `p.weights` differ in length, or if
      any element of `p.file_patterns` is not a string.
  """
  params = self.params
  num_patterns = len(params.file_patterns)
  num_weights = len(params.weights)

  if num_weights != num_patterns:
    raise ValueError(
        'Expected p.file_patterns and p.weights to be the same length. '
        'Found %d file_patterns, and %d weights' % (num_patterns, num_weights))
  for pattern in params.file_patterns:
    if not isinstance(pattern, six.string_types):
      raise ValueError(
          'Expected all elements of p.file_patterns to be strings')

  def _DeferredSource(pattern):
    """Returns a zero-arg callable that reads from `pattern` when invoked."""

    # The read must happen inside the returned callable so that a record is
    # drawn from a source only if that source is actually selected. Weighting
    # is left to MixByWeight, not to data_source_from_file_pattern_fn.
    def _Read():
      return data_source_from_file_pattern_fn(pattern)

    return _Read

  # TODO(rosenberg) replace this with functools.partial
  sources = [_DeferredSource(pattern) for pattern in params.file_patterns]

  # Fall back to one empty filter per source when none are configured.
  filters = params.bprop_variable_filters or [''] * len(sources)

  mixed_batch, selected_bprop = py_utils.MixByWeight(
      sources, params.weights, seed=params.random_seed)

  # TODO(neerajgaur): Remove _bprop_onehot and change code that uses it to
  # use source_selected from input_batch.
  first_tensor = tf.nest.flatten(mixed_batch)[0]
  batch_size = py_utils.GetShape(first_tensor)[0]
  # Repeat the per-source selection vector once for every example.
  source_selected = tf.tile(
      tf.expand_dims(selected_bprop, 0), [batch_size, 1])

  return py_utils.NestedMap(
      data=mixed_batch,
      bprop_variable_filters=filters,
      selected_bprop=selected_bprop,
      source_selected=source_selected)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` object containing:
      ids - The inputs tensor of shape [batch, time].
      paddings - The ids' paddings of shape [batch, time].
      segment_ids - Only read when `p.packed_input` is set.
      segment_pos - Only read when `p.packed_input` is set; used as the
        positions for the position embedding.

  Returns:
    A '.NestedMap' object containing:
      encoded - The encoded features of shape [time, batch, dim] or
        [batch, time, dim], depending p.output_data_format.
      padding - The encoded features' padding of shape [time, batch] or
        [batch, time].
      seq_lengths - int32 per-example count of non-padded positions, computed
        from the batch-major padding returned by the transformer stack.
      segment_id - The segmentation of packed inputs of shape [time, batch] or
        [batch, time] if it is supported by the model, or None otherwise.
      embedded_inputs - The embedded inputs tokens without positional
        encodings of shape [time, batch, dim] or [batch, time, dim].
  """
  p = self.params
  with tf.name_scope(p.name):
    # [batch, time]
    input_ids = input_batch.ids
    # [batch, time]
    paddings = input_batch.paddings
    # [batch, time]
    segment_ids = input_batch.segment_ids if p.packed_input else None

    batch = py_utils.GetShape(input_ids)[0]
    time = py_utils.GetShape(input_ids)[1]

    # Embedding layer.
    # [batch, time, dim]
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
    else:
      # Embedding weights are shared with the softmax layer.
      input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
    # Kept aside so the caller can get the embeddings without positions.
    orig_input_embs = input_embs

    # [1, time, dim]
    # NOTE(review): in the packed case the embeddings are looked up from
    # per-token positions, so the leading expand_dims presumably yields an
    # extra broadcast axis; the reshape below restores [batch, time, dim].
    if p.packed_input:
      positions = input_batch.segment_pos
      position_embs = tf.expand_dims(
          self.position_emb.FPropWithPosition(theta.position_emb, positions),
          0)
    else:
      position_embs = tf.expand_dims(
          self.position_emb.FProp(theta.position_emb, time), 0)

    # [batch, time, dim]
    # Cast to the dtype of the token embeddings rather than a fixed
    # bfloat16: TF does not auto-promote, so a hard-coded dtype breaks the
    # addition whenever the embeddings use any other dtype.
    input_embs += tf.cast(position_embs, input_embs.dtype)

    if p.input_dropout_tpl.fprop_dtype:
      input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
      paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
    # [batch, time, dim]
    transformer_input = input_embs

    # Explicitly set the input shape of Transformer layers, to avoid
    # unknown shape error occurred to tf.einsum on nonTPU devices.
    transformer_input = tf.reshape(transformer_input,
                                   [batch, time, p.model_dim])

    # Compute self-attention segment mask once.
    if p.packed_input:
      segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids, dtype=transformer_input.dtype)
    else:
      # No packing: an all-zero mask of shape [batch, 1, time, time].
      segment_mask = tf.zeros([batch, 1, time, time])

    encoded, padding = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, segment_mask)

    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)

    # Computed before any transpose, i.e. while `padding` is batch-major.
    seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32)

    if p.output_data_format == 'TBC':
      encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
      padding = tf.transpose(padding)  # [time, batch]
      segment_ids = tf.transpose(segment_ids) if p.packed_input else None
      orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        seq_lengths=seq_lengths,  # used by beam_search_helper.
        segment_id=segment_ids,
        embedded_inputs=orig_input_embs)
def _BeamSearchDecodeIds(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap computed by encoder.
    num_hyps_per_beam: Number of hyps per beam.
    init_beam_search_state: The InitBeamSearchState callback. Please refer to
      the class header comments for more details. It is called
      unconditionally, so despite the None default a callable is required.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A flat tuple of tensors, in this order:

    - Three shape tensors: the shapes of `source_seq_lengths`, `hyps` and
      `atten_probs` (before flattening).
    - The following tensors, each reshaped to rank 1:

      hyps: A tensor of shape [time, b * k] with ids of the token selected.
      prev_hyps: A tensor of shape [time, b * k] with index to the previous
        hyps which was selected.
      done_hyps: A boolean tensor of shape [time, b * k] where value
        indicates if hyps was terminated.
      scores: A tensor of shape [time, b * k] with scores of the token
        selected.
      atten_probs: A tensor of shape [time, b * k, seq_len] which contain the
        attention probabilities over the source words against word in the
        previous hyps.
      eos_scores: A tensor of shape [time, b * k] with scores of the eos token
        selected.
      eos_atten_probs: A tensor of shape [time, b * k, seq_len] which contain
        the attention probabilities over the source words against word in the
        previous hyps.
      source_seq_lengths: A tensor of shape [time] containing the source
        seq_lengths.

    - flat_final_other_states: A array of tensors that are part of other
      states.
  """
  p = self.params
  # Honor the documented default: the TensorArray sizes and loop bounds below
  # all need a concrete step count.
  if max_steps is None:
    max_steps = p.target_seq_len

  source_paddings = encoder_outputs.padding

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  # We cache the NestedMap as member variable so that we can use it to
  # pack the final outputs. Tpu rewrite methods forces us to strictly pass
  # in Tensors, and output Tensors
  self._other_states = other_states

  # Every hypothesis starts from the target SOS id.
  step_ids = tf.fill([num_hyps, 1],
                     tf.constant(p.target_sos_id, dtype=tf.int32))
  min_score = -1e36
  fprop_dtype = py_utils.FPropDtype(p)
  best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
  histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
  # Per-step outputs, one TensorArray slot per decoding step.
  in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  # States for beam search that are inputs into Beam search step.
  accum_bs_states = [best_scores, cumulative_scores, histories]
  # States that are not accumulators.
  non_accum_bs_states = [
      in_scores,
      in_hyps,
      in_prev_hyps,
      in_done_hyps,
      in_atten_probs,
      in_eos_scores,
      in_eos_atten_probs,
  ]
  # Order matters: the final states are unpacked by position (3..9) below.
  core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

  flat_other_states = other_states.Flatten()

  # If there is an optimized implementation for short sequence, LoopBodyShort
  # will run first for short_seq_limit steps (after which the
  # LoopBodyShort does not have performance benefit). Then LoopBodyLong (the
  # default implementation) is used to continue the rest of the steps. For
  # decoders which do not have the short sequence specific implementation,
  # only the LoopBodyLong (the default implementation) will run.

  if p.short_seq_limit > 0:

    def LoopContinueShort(cur_step, all_done, unused_step_ids,
                          unused_core_bs_states, unused_other_states_list):
      """Use short_seq optimization when cur_step is smaller than limit."""
      return tf.math.logical_and(cur_step < p.short_seq_limit,
                                 tf.math.logical_not(all_done))

    def LoopBodyShort(cur_step, unused_all_done, step_ids, core_bs_states,
                      other_states_list):
      """Loop body of short_seq optimization.

      Instead of doing computation for the entire padded sequence, while loop
      with early exit is used within each _BeamSearchStep to do computation
      for only the actual sequence (seq_length <= cur_step).
      use_short_seq_opt is used as the flag to pass this information down to
      the decoder implementation.

      Args:
        cur_step: A scalar int tensor, the current time step, 0-based.
        unused_all_done: A tf.bool, indicating whether the decoding finishes.
        step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids to the
          current search step.
        core_bs_states: A tuple of core beam search states.
        other_states_list: A flattened NestedMap of other beam search states.

      Returns:
        The updated input tuple, with the same shape.
      """
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta,
           encoder_outputs,
           cur_step,
           step_ids,
           core_bs_states,
           other_states.Pack(other_states_list),
           num_hyps_per_beam,
           pre_beam_search_step_callback,
           post_beam_search_step_callback,
           use_short_seq_opt=True)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    (cur_step, all_done, step_ids, core_bs_states,
     flat_other_states) = tf.while_loop(
         LoopContinueShort,
         LoopBodyShort,
         loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                    flat_other_states),
         parallel_iterations=10,
         back_prop=False,
         swap_memory=False,
         shape_invariants=(
             tf.TensorShape(cur_step.get_shape()),
             tf.TensorShape(all_done.get_shape()),
             tf.TensorShape(step_ids.get_shape()),
             tuple(
                 list(_GetShapes(accum_bs_states)) +
                 list(_GetShapes(non_accum_bs_states, none_shapes=True))),
             _GetShapes(flat_other_states, none_shapes=True)),
         maximum_iterations=max_steps)

  def LoopContinueLong(cur_step, all_done, unused_step_ids,
                       unused_core_bs_states, unused_other_states_list):
    """Continue default implementation until decoding finishes."""
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                   other_states_list):
    """Loop body of default long_seq implementation."""
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta,
         encoder_outputs,
         cur_step,
         step_ids,
         core_bs_states,
         other_states.Pack(other_states_list),
         num_hyps_per_beam,
         pre_beam_search_step_callback,
         post_beam_search_step_callback,
         use_short_seq_opt=False)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  # Resumes from wherever the short-sequence loop (if any) stopped.
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinueLong,
      LoopBodyLong,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(
          tf.TensorShape(cur_step.get_shape()),
          tf.TensorShape(all_done.get_shape()),
          tf.TensorShape(step_ids.get_shape()),
          tuple(
              list(_GetShapes(accum_bs_states)) +
              list(_GetShapes(non_accum_bs_states, none_shapes=True))),
          _GetShapes(flat_other_states, none_shapes=False)),
      maximum_iterations=max_steps)

  # When the paddings come as a NestedMap, use its first flattened tensor.
  if isinstance(source_paddings, py_utils.NestedMap):
    source_paddings = source_paddings.Flatten()[0]
  # The paddings are transposed before summing over axis 1, i.e. they are
  # summed over their original leading axis.
  source_seq_lengths = tf.cast(
      tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
      dtype=tf.int32)

  # Concatenate all outputs on axis=0.
  # Positions 0-2 are the accumulator states; 3-9 are the TensorArrays in the
  # order they were listed in `non_accum_bs_states`.
  scores = final_bs_states[3].stack()
  hyps = final_bs_states[4].stack()
  prev_hyps = final_bs_states[5].stack()
  done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
  atten_probs = final_bs_states[7].stack()
  eos_scores = final_bs_states[8].stack()
  eos_atten_probs = final_bs_states[9].stack()
  rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
          eos_atten_probs, source_seq_lengths)

  # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
  # b/111131551 is resolved.
  # Canonical shapes for tensors of various ranks.
  r_shapes = [
      py_utils.GetShape(source_seq_lengths),
      py_utils.GetShape(hyps),
      py_utils.GetShape(atten_probs)
  ]
  # Reshape all tensors to [-1] to avoid cost of copy due to padding.
  rets_r1 = [tf.reshape(r, [-1]) for r in rets]

  return tuple(r_shapes) + tuple(rets_r1) + tuple(flat_final_other_states)