def FProp(self, theta, source_vecs, source_paddings, target_vecs,
          target_paddings, source_segment_id, target_segment_id,
          transparent_acc, transparent_acc_helper, source_task_id=None,
          target_task_id=None):
  """Computes softmax logits for the encoder or decoder output vectors.

  Only the vectors are consumed; the remaining arguments exist to match the
  pipeline layer signature and are ignored.

  Args:
    theta: A `.NestedMap` of weights, forwarded to the parent `Logits`.
    source_vecs: Encoder output; used when `p.inputs_from_decoder` is False.
    source_paddings: Unused.
    target_vecs: Decoder output; used when `p.inputs_from_decoder` is True.
    target_paddings: Unused.
    source_segment_id: Unused.
    target_segment_id: Unused.
    transparent_acc: Unused.
    transparent_acc_helper: Unused.
    source_task_id: Unused.
    target_task_id: Unused.

  Returns:
    Logits of shape [d0, d1, p.num_classes], where d0 and d1 are the two
    leading dims of the selected input (its trailing dim is flattened as
    `p.input_dim`).
  """
  del source_task_id
  del target_task_id
  p = self.params
  features = target_vecs if p.inputs_from_decoder else source_vecs
  leading_shape = tf.shape(features)
  flat_features = tf.reshape(features, [-1, p.input_dim])
  flat_logits = super(GPipeTransformerSoftmaxLayer, self).Logits(
      theta, [flat_features])
  return tf.reshape(flat_logits,
                    [leading_shape[0], leading_shape[1], p.num_classes])
def _ProcessSingleInput(self, source_id, src, tgt):
  """Performs strings-to-ids on the given input pair via p.tokenizer_dict.

  Args:
    source_id: Identifier of the input source; passed to `_GetTaskIds` to look
      up the source and target task ids.
    src: Source string; reshaped to a length-1 vector before tokenizing.
    tgt: Target string; reshaped to a length-1 vector before tokenizing.

  Returns:
    A `.NestedMap` with `src` (ids, ids_indicator, task_ids) and `tgt` (ids,
    labels, ids_indicator, task_ids) sub-maps, plus the raw strings under
    `strs` when not running on TPU. Every tensor is passed through
    `tf.squeeze`.
  """
  _, src_labels, src_paddings = self.StringsToIds(
      tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
  tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
      tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)

  # The tokenizer implementation may pad with EOS, so force every padded
  # position to 0 for consistency.
  src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
  tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
  tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

  # ids_indicator is 1 exactly where the tokenizer produced a non-padded id.
  # Unlike weights it is never mutated, so it can be used e.g. to recover
  # the actual sequence length.
  src_indicator = 1 - src_paddings
  tgt_indicator = 1 - tgt_paddings

  src_task_id, tgt_task_id = self._GetTaskIds(source_id)

  features = py_utils.NestedMap()
  # Task ids are zero on padded positions.
  features.src = py_utils.NestedMap(
      ids=src_labels,
      ids_indicator=src_indicator,
      task_ids=tf.cast(src_indicator, dtype=tf.int32) * src_task_id)
  features.tgt = py_utils.NestedMap(
      ids=tgt_ids,
      labels=tgt_labels,
      ids_indicator=tgt_indicator,
      task_ids=tf.cast(tgt_indicator, dtype=tf.int32) * tgt_task_id)

  if not py_utils.use_tpu():
    features.src.strs = src
    features.tgt.strs = tgt
  return features.Transform(tf.squeeze)
def _InputBatch(self):
  """Builds one batch from data/label tensors restored from `p.ckpt`.

  The data and label tensors (names `p.data` / `p.label`) are read once via
  `ops.cached_call` and kept in memory; each call draws `sample_ids` from
  `ops.random_permutation_sequence` and gathers those rows.

  Returns:
    A `.NestedMap` with:
      raw: gathered data, padded/trimmed to [batch] + p.data_shape.
      data: `self._Preprocess(raw)`.
      label: gathered labels, padded/trimmed to [batch].
      weight: ones for each sampled example, padded/trimmed to [batch].
      sample_ids: the drawn ids (only when not running on TPU).
  """
  p = self.params

  @tf.function
  def ReadData():
    data_t, label_t = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                        [p.data_dtype, p.label_dtype])
    # Always convert to float32.
    return tf.cast(data_t, tf.float32), tf.cast(label_t, tf.float32)

  # Loads data and label into memory and keeps them around.
  data, label = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  batch_size = self.InfeedBatchSize()
  example_shape = list(p.data_shape)
  data = tf.reshape(data, [-1] + example_shape)
  label = py_utils.HasShape(tf.reshape(label, [-1]), [tf.shape(data)[0]])

  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=batch_size,
      repeat=p.repeat,
      seed=p.random_seed or 0)
  num_sampled = tf.shape(sample_ids)[0]

  raw = py_utils.PadOrTrimTo(
      tf.gather(data, sample_ids), [batch_size] + example_shape)
  batch = py_utils.NestedMap(
      raw=raw,
      data=self._Preprocess(raw),
      label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [batch_size]),
      weight=py_utils.PadOrTrimTo(
          tf.ones([num_sampled], dtype=tf.float32), [batch_size]))
  if not py_utils.use_tpu():
    batch['sample_ids'] = sample_ids
  return batch
def TileForBeamAndFlatten(tensor):
  """Repeats `tensor` once per hyp and flattens the result.

  `num_hyps_per_beam` and `step_ids` come from the enclosing scope. The
  tiling is hyp-major: output[h * src_batch + b] == tensor[b].

  Args:
    tensor: A tensor holding src_batch elements (flattened to [1, src_batch]).

  Returns:
    A 1-D tensor of length tf.shape(step_ids)[0], which is expected to equal
    num_hyps_per_beam * src_batch.
  """
  # [1, src_batch] -> [num_hyps_per_beam, src_batch].
  tiled = tf.tile(tf.reshape(tensor, [1, -1]), [num_hyps_per_beam, 1])
  tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
  return tf.reshape(tiled, [tgt_batch])
def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
  """Extends the greedy search hyps by one step.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed
      to the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    hyp_ids: An int tensor whose row `cur_step` is overwritten in place with
      this step's ids, i.e. of shape [max_steps, num_hyps] (time-major).
    hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
      counted.
    done_hyps: Whether or not a hyp has finished.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. See class header comments for more details.

  Returns:
    A tuple of following elements for the next greedy search step,
    (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
  """
  p = self.params

  # Every hyp that has not finished yet grows by one token.
  not_done = 1 - tf.cast(done_hyps, tf.int32)
  hyp_lens = hyp_lens + not_done

  bs_results, new_other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, 1)  # num_hyps_per_beam

  # Greedy choice: the highest-scoring token for every hyp.
  best_ids = tf.cast(tf.math.argmax(bs_results.log_probs, 1), tf.int32)
  new_step_ids = tf.reshape(best_ids, tf.shape(step_ids))

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  # Stash this step's ids into row `cur_step` of hyp_ids.
  flat_new_ids = tf.reshape(new_step_ids, [-1])
  hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step, flat_new_ids)

  # A hyp is finished once it emits the end-of-sequence token.
  reached_eos = tf.equal(flat_new_ids, p.target_eos_id)
  done_hyps = tf.math.logical_or(done_hyps, reached_eos)

  return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
          final_other_states)
def AddMultiCurveSubplot(fig,
                         tensors,
                         paddings,
                         labels,
                         xlabels=None,
                         **kwargs):
  """Adds a multi curve subplot to Matplotlib figure.

  Plots one line for each entry in tensors and assigns a plot label legend.
  Entries of `tensors` that are None are skipped together with their label.

  Args:
    fig: The Matplotlib figure.
    tensors: List of tensors of shape [batch, length]
    paddings: Paddings for 'tensors' with shape [batch, length] with 0. in
      valid positions and 1. in invalid.
    labels: A list of tensor names (strings) of the same length as 'tensors'.
    xlabels: A string tensor of shape [batch] with an xlabel per batch.
    **kwargs: With optional, title, xlabel, ylabel, fontsize.

  Raises:
    ValueError: If every entry of `tensors` is None (nothing to plot).
  """
  kept = [(t, label) for t, label in zip(tensors, labels) if t is not None]
  if not kept:
    raise ValueError(
        'AddMultiCurveSubplot: all entries of `tensors` are None; '
        'nothing to plot.')
  data = [py_utils.ApplyPadding(paddings, t) for t, _ in kept]
  row_labels = [label for _, label in kept]

  batch, length = py_utils.GetShape(data[0], 2)
  # Concatenating the [batch, length] curves on the last axis and reshaping
  # gives [batch, num_curves, length] where [b, i] is curve i of example b.
  data = tf.reshape(tf.concat(data, -1), [batch, len(data), length])

  args = [data, py_utils.LengthsFromPaddings(paddings)]
  if xlabels is not None:
    args.append(xlabels)
  fig.AddSubplot(
      args, plot_func=_AddMultiCurveRowPlots, row_labels=row_labels, **kwargs)
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step.

  Runs `pre_step_callback` to obtain log-probs, samples one id per example,
  and runs `post_step_callback` on the sampled ids. `p`, `pre_step_callback`
  and `post_step_callback` come from the enclosing scope.

  Args:
    recurrent_theta: A `.NestedMap` with `theta`, `encoder_outputs` and
      `random_seed`.
    state0: A `.NestedMap` with `ids` ([batch]), `bs_state` and `timestep`.
    inputs: Unused.

  Returns:
    (state1, empty `.NestedMap`), where state1 holds `timestep`, `logits`,
    the sampled `ids` and the updated `bs_state`.
  """
  del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states, one hyp per example.
    bs_result, bs_state1 = pre_step_callback(
        recurrent_theta.theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=1)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs

    # Stateless sampling keyed on (random_seed, timestep) keeps it
    # reproducible. Result is reshaped to [batch].
    sample_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
    sampled = tf.random.stateless_categorical(
        state1.logits / p.temperature,
        num_samples=1,
        seed=sample_seed,
        dtype=state0.ids.dtype,
        name='sample_next_id')
    state1.ids = tf.reshape(sampled, [batch])

    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      # On the last chunk, an end-of-chunk id is rewritten to EOS.
      final_eoc = tf.math.logical_and(
          bs_result.is_last_chunk, tf.equal(state1.ids, p.target_eoc_id))
      eos_fill = tf.fill(tf.shape(state1.ids), p.target_eos_id)
      state1.ids = tf.where(final_eoc, eos_fill, state1.ids)

    state1.bs_state = post_step_callback(recurrent_theta.theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)
    return state1, py_utils.NestedMap()
def FProp(self, theta):
  """Computes the per-source merge weights.

  The weights are `theta.sum_weight * p.global_weight_scale + 1/num_sources`,
  passed through dropout and, if `p.weighted_merger_softmax` is set,
  normalized with a softmax that reserves `p.minimal_prob` for every source.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.

  Returns:
    A tensor of weights with dropout applied with shape [num_sources]. With
    softmax enabled the weights sum to 1 and each is >= `p.minimal_prob`.

  Raises:
    ValueError: If softmax is enabled and `p.minimal_prob * p.num_sources` is
      not in [0, 1).
  """
  p = self.params
  # The constant factor is just meant to support the non-normalized scenario.
  # If softmax is applied, this factor will cancel out.
  w = theta.sum_weight * p.global_weight_scale + (1 / p.num_sources)
  w = tf.reshape(w, [p.num_sources])
  w = self.weighted_merger_dropout.FProp(theta.weighted_merger_dropout, w)
  if p.weighted_merger_softmax:
    residual_weights = p.minimal_prob * p.num_sources
    # Explicit check instead of `assert`, which is stripped under -O.
    if not 0.0 <= residual_weights < 1.0:
      raise ValueError(
          'minimal_prob * num_sources must be in [0, 1), got %s' %
          residual_weights)
    # Softmax mass (1 - residual) plus minimal_prob per source sums to 1.
    w = tf.nn.softmax(
        w, axis=0) * (1.0 - residual_weights) + p.minimal_prob
  return w
def restore(self, restored_tensors, restored_shapes):
  """Assigns the first restored tensor to `self.op` as bfloat16.

  Args:
    restored_tensors: Sequence of restored tensors; only element 0 is used.
    restored_shapes: None, or a sequence whose element 0 is the shape the
      restored tensor is reshaped to.

  Returns:
    The op returned by `tf.assign`.
  """
  value = restored_tensors[0]
  if restored_shapes is not None:
    value = tf.reshape(value, restored_shapes[0])
  # Validate shapes only when no explicit shape was given and the target's
  # static shape is fully known.
  should_validate = (
      restored_shapes is None and self.op.get_shape().is_fully_defined())
  return tf.assign(
      self.op, tf.cast(value, tf.bfloat16), validate_shape=should_validate)
def _GetWeight(self, theta):
  """Returns the convolution filter, weight-normalized when enabled.

  With `p.weight_norm`, `theta.w` is L2-normalized over dims [0, 1, 2] (one
  norm per output channel) and scaled by `theta.g + 1.0` per output channel.
  Otherwise `theta.w` is returned unchanged.
  """
  p = self.params
  if not p.weight_norm:
    return theta.w
  # Normalize along the last dim (standard conv).
  normalized = tf.nn.l2_normalize(theta.w, [0, 1, 2])
  gain = tf.reshape(theta.g + 1.0, [1, 1, 1, p.filter_shape[-1]])
  return normalized * gain
def _GetWeight(self, theta):
  """Returns the filter, weight-normalized over its spatial dims if enabled.

  With `p.weight_norm`, `theta.w` is L2-normalized over dims [0, 1] and
  scaled by `theta.g + 1.0` reshaped to
  [1, 1, filter_shape[2], filter_shape[3]]. Otherwise `theta.w` is returned
  unchanged.
  """
  p = self.params
  if not p.weight_norm:
    return theta.w
  # Normalize along the last two dims.
  normalized = tf.nn.l2_normalize(theta.w, [0, 1])
  gain = tf.reshape(theta.g + 1.0,
                    [1, 1, p.filter_shape[2], p.filter_shape[3]])
  return normalized * gain
def _InputBatch(self):
  """Returns a batch of consecutive counter values shaped like `self.shape`.

  Each call advances the 'CountingInputGenerator' counter by
  prod(self.shape). src_ids holds the float32 values start .. start+len-1
  reshaped to `self.shape`. tgt_ids is their sum over axis 0.
  """
  num_elements = tf.reduce_prod(self.shape)
  counter = summary_utils.StatsCounter('CountingInputGenerator')
  # Subtracting num_elements from IncBy's result gives the first value of
  # this batch (assumes IncBy returns the post-increment total).
  start = tf.cast(counter.IncBy(num_elements), dtype=tf.int32) - num_elements
  start = tf.stop_gradient(start)
  values = tf.cast(start + tf.range(num_elements), dtype=tf.float32)
  src_ids = tf.reshape(values, self.shape)
  return py_utils.NestedMap(
      src_ids=src_ids, tgt_ids=tf.reduce_sum(src_ids, axis=0))
def MakeCausalPadding(seq_len, block_size, left_context, right_context):
  """Makes the causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in
    {0, 1}, where context_size = block_size + (left_context - 1) +
    right_context. Element b, i, j is zero if in the b-th block, the i-th
    frame can access the j-th frame in the context.
  """
  seq_len = py_utils.with_dependencies([
      py_utils.assert_greater_equal(
          seq_len, 1, message='seq_len must be at least 1')
  ], seq_len)

  num_blocks = (seq_len + block_size - 1) // block_size
  padded_len = num_blocks * block_size
  context_size = block_size + (left_context - 1) + right_context

  # [num_blocks, block_size]: position of each query frame in the sequence.
  src_positions = tf.reshape(tf.range(padded_len), [num_blocks, block_size])

  # [num_blocks, context_size]: position of each context frame. A block's
  # context starts (left_context - 1) frames before the block's first frame.
  block_starts = tf.range(0, padded_len, block_size)
  context_offsets = tf.range(context_size) - (left_context - 1)
  tgt_positions = (
      block_starts[:, tf.newaxis] + context_offsets[tf.newaxis, :])

  # [num_blocks, block_size, context_size]: source minus target position.
  delta = src_positions[:, :, tf.newaxis] - tgt_positions[:, tf.newaxis, :]

  # A frame may attend at most right_context frames ahead and fewer than
  # left_context frames behind (itself included)...
  in_window = tf.math.logical_and(-right_context <= delta,
                                  delta < left_context)
  # ...provided both frames lie inside the unpadded sequence.
  src_in_seq = src_positions < seq_len
  tgt_in_seq = tf.math.logical_and(0 <= tgt_positions,
                                   tgt_positions < seq_len)
  allowed = in_window & tf.math.logical_and(src_in_seq[:, :, tf.newaxis],
                                            tgt_in_seq[:, tf.newaxis, :])
  return 1.0 - tf.cast(allowed, dtype=tf.float32)
def UnstackFeatures(self, src_inputs, src_paddings):
  """Unstacks src_input and src_paddings based off stack height.

  Each stacked frame is split back into `p.stack_height` frames along time.

  Args:
    src_inputs: Features of rank 4, [batch, old_series_length, ?, channels].
    src_paddings: Paddings of shape [batch, old_series_length]; 1 marks a
      padded frame.

  Returns:
    A tuple (src_inputs, src_paddings):
      src_inputs reshaped to
      [batch, old_series_length * stack_height, -1, channels].
      src_paddings of shape [batch, old_series_length * stack_height],
      in the same dtype as the input paddings.
  """
  sh = self.params.stack_height
  bs, old_series_length, _, channels = py_utils.GetShape(src_inputs)
  unstacked_series_length = old_series_length * sh
  src_inputs = tf.reshape(src_inputs,
                          [bs, unstacked_series_length, -1, channels])
  content = 1 - src_paddings
  # Every valid stacked frame expands to `sh` valid unstacked frames.
  lengths = tf.cast(sh * tf.reduce_sum(content, axis=1), tf.int32)
  mask = tf.sequence_mask(lengths, maxlen=unstacked_series_length)
  # Keep the caller's padding dtype so the unstacked path matches the
  # non-unstacked one (previously always int32).
  src_paddings = 1 - tf.cast(mask, src_paddings.dtype)
  return src_inputs, src_paddings
def _AugmentationNetwork(self,
                         series_length,
                         inputs,
                         paddings,
                         global_seed,
                         domain_id_index=0):
  """Returns augmented features.

  Applies time warping, time masking and frequency masking in that order,
  optionally on unstacked frames (`p.unstack`).

  Args:
    series_length: Total length of time series.
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    paddings: Batch of padding vectors of shape (batch_size, time_length).
    global_seed: an integer seed tensor for stateless random ops.
    domain_id_index: domain id index.

  Returns:
    Batch of output features of shape (batch_size, time_length, num_freq,
    channels) obtained by applying random augmentations to inputs.
  """
  p = self.params
  dtype = p.dtype

  # Unstack the features.
  if p.unstack:
    inputs, paddings = self.UnstackFeatures(inputs, paddings)

  lengths = tf.reduce_sum(1 - paddings, 1)
  shared_kwargs = dict(
      global_seed=global_seed, dtype=dtype, domain_id_index=domain_id_index)
  inputs = self._TimeWarp(inputs, lengths, **shared_kwargs)
  inputs = self._TimeMask(
      inputs,
      lengths,
      noisify=p.use_noise,
      gaussian_noise=p.gaussian_noise,
      **shared_kwargs)
  inputs = self._FrequencyMask(inputs, **shared_kwargs)

  # Restack the features after applying specaugment.
  if p.unstack:
    batch_size = tf.shape(inputs)[0]
    channels = tf.shape(inputs)[3]
    inputs = tf.reshape(inputs, [batch_size, series_length, -1, channels])
  return inputs
def _ReshapeRetVal(name, t_shape):
  """Restore shape for tensors in microbatches.

  `output_state[name]` is expected to be [num_micro_batches] + t_shape. Its
  leading micro-batch axis is folded into axis `p.batch_dim`, so the result
  has shape t_shape with t_shape[p.batch_dim] multiplied by
  p.num_micro_batches. `output_state` and `p` come from the enclosing scope.

  Returns:
    The reshaped tensor, or None when `t_shape` is None.
  """
  if t_shape is None:
    return None
  merged = output_state[name]
  if p.batch_dim != 0:
    # Move the micro-batch axis (0) to sit just before the batch axis.
    head = list(range(1, p.batch_dim + 1))
    tail = list(range(p.batch_dim + 1, t_shape.rank + 1))
    merged = tf.transpose(merged, perm=head + [0] + tail)
  full_shape = t_shape.ToTensorShape().as_list()
  full_shape[p.batch_dim] *= p.num_micro_batches
  return tf.reshape(merged, full_shape)
def _GetExpertDist(self, theta, inputs, *args):
  """Computes a sigmoid expert distribution from the averaged input.

  The input is flattened to [-1, cond_dim] and averaged over rows (over
  nonzero entries only when `p.nonzeros_mean`). The mean vector is then
  projected by `theta.w` and passed through a sigmoid.

  Args:
    theta: A `.NestedMap` containing `w`.
    inputs: Input tensor whose last dim is `p.cond_dim`.
    *args: Unused.

  Returns:
    A 1-D tensor with one sigmoid value per column of `theta.w`.
  """
  del args  # Unused.
  p = self.params
  # TODO(huangyp): support the more general case when batch size is not 1.
  # Input shape can be either [batch, length, dim] or [length, batch, dim];
  # flattening makes both layouts look the same.
  flat = tf.reshape(inputs, [-1, p.cond_dim])
  if not p.nonzeros_mean:
    mean_emb = tf.reduce_mean(flat, 0)
  else:
    # Average each feature over its nonzero entries only.
    totals = tf.reduce_sum(flat, 0)
    nonzero_counts = tf.cast(
        tf.math.count_nonzero(flat, 0), dtype=tf.float32)
    mean_emb = totals / (nonzero_counts + 1e-10)
  return tf.nn.sigmoid(tf.einsum('i,ij->j', mean_emb, theta.w))
def SequenceLength(padding):
  """Computes the length of a sequence based on binary padding.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen].

  Returns:
    seq_lens, an int32 tensor of shape [batch] holding the number of
    non-padded positions of each batch element (rounded to nearest integer).
  """
  num_valid = tf.reduce_sum(1 - padding, axis=1)
  seq_lens = tf.cast(tf.round(num_valid), tf.int32)
  # Get rid of any extra dimensions.
  return tf.reshape(seq_lens, [tf.shape(padding)[0]], name='seq_lens')
def RelShift(x):
  """Performs relative shift on 4D tensor (first 2 axis are batching dims).

  Given input of shape [?, ?, W, W], this does "relative shifting" for the
  last two dims, shifting row i right by i positions:

    output[b, n, i, j] = input[b, n, i, j - i]   if j >= i
    output[b, n, i, j] = 0                       if j == i - 1

  For j < i - 1 the output is NOT zero: it holds leftover values from the
  previous row, output[b, n, i, j] = input[b, n, i - 1, W + 1 + j - i].
  Callers that need those positions to be zero must mask them.

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as
    described above).
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, w, w])
  # Append one zero column: [b, n, w, w + 1].
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  # Reading the w * (w + 1) values back as (w + 1) rows of w shifts row i
  # right by i positions.
  x = tf.reshape(x, [b, n, w + 1, w])
  # Drop the extra trailing row.
  x = x[:, :, :w, :]
  return x
def PrepareSequenceForPlot(tensor, padding, name):
  """Prepares a sequence feature for plotting.

  The sequence feature is transposed and channels are flattened.

  Args:
    tensor: A n-D Tensor of shape [batch, time, ...].
    padding: A Tensor of shape [batch, time].
    name: A string as the name of the reshaped Tensor, which will be used as
      the subcaption for plotting.

  Returns:
    A tuple of:
      reshaped_tensor: A 3-D Tensor of shape [batch, dim, time].
      sequence_length: A 1-D Tensor of shape [batch].
  """
  batch_size, max_len = py_utils.GetShape(tensor, 2)
  # Collapse every dim after time into one: [batch, time, dim].
  flat = tf.reshape(tensor, [batch_size, max_len, -1])
  plot_tensor = tf.transpose(flat, [0, 2, 1], name=name)
  return plot_tensor, SequenceLength(padding)
def _ProcessMASSInput(self, source_id, src):
  """Perform MASS input processing.

  Tokenizes the monolingual `src`, masks it with `self.mass_layer`, and
  builds source and target features from the masked result.

  Args:
    source_id: Identifier of the input source; passed to `_GetTaskIds`.
    src: A single source string (reshaped to a length-1 vector).

  Returns:
    A `.NestedMap` with `src` and `tgt` sub-maps, every tensor squeezed.
    Task ids on both sides are zero on padded positions.
  """
  # TODO(yuancao): By doing so we assume that right now for monolingual
  # eval/dev sets (xx->xx) are in double-column format (since it bypasses
  # the Mass op). Ideally we should add a dedicated eval/dev processing
  # procedure for unsupervised MT cases, so that single-column eval/devs sets
  # are also supported. This should not be handled by any specific ops like
  # Mass, but inside the TextPackedInput class.
  assert not self.do_eval, 'MASS input can only be used for training.'

  _, labels, paddings = self.StringsToIds(
      tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
  weights = 1 - paddings
  actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
  src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

  mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)
  # 1 on real tokens and 0 on padding; zeroes task ids at padded positions.
  int_weights = tf.cast(weights, dtype=tf.int32)

  features = py_utils.NestedMap()
  features.src = py_utils.NestedMap()
  features.src.ids = mass_out.src.ids
  features.src.paddings = paddings
  features.src.weights = weights
  features.src.task_ids = int_weights * src_lang_ids
  features.src.ids_indicator = weights

  features.tgt = py_utils.NestedMap()
  features.tgt.ids = mass_out.tgt.ids
  features.tgt.labels = mass_out.tgt.labels
  features.tgt.paddings = paddings
  features.tgt.weights = mass_out.tgt.weights
  # Pad target task ids with zeros as well, consistent with the source side
  # (tgt shares the source paddings here).
  features.tgt.task_ids = int_weights * tgt_lang_ids
  features.tgt.ids_indicator = weights

  if not py_utils.use_tpu():
    features.src.strs = src
    features.tgt.strs = src
  return features.Transform(tf.squeeze)
def ConvertToBlocks(x, block_size, padding_val=0.0):
  """Turns a sequence to non overlapping blocks.

  Args:
    x: a tensor of [batch, time, ...].
    block_size: int. Number of time frames in a block.
    padding_val: float. value on the padded frames.

  Returns:
    A tensor of [batch, num_blocks, block_size, ...], with necessary paddings,
    where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].

  Raises:
    ValueError: If block_size is less than 1.
  """
  shape = py_utils.GetShape(x)
  batch, time = shape[:2]
  if block_size < 1:
    raise ValueError('block_size must be at least 1, got {}'.format(block_size))
  # Round the time dim up to a multiple of block_size.
  num_blocks = (time + block_size - 1) // block_size
  padded = py_utils.PadSequenceDimension(x, num_blocks * block_size,
                                         padding_val)
  return tf.reshape(padded, [batch, num_blocks, block_size] + shape[2:])
def history_sort(values):
  """Gathers the flattened `values` at `sorted_history_indices_flat`.

  `fast_gather`, `sorted_history_indices_flat`, `num_beams`, `k` and
  `candidates_per_hyp` come from the enclosing scope.

  Returns:
    A tensor of shape [num_beams, k * candidates_per_hyp].
  """
  num_entries = num_beams * k * candidates_per_hyp
  column = tf.reshape(values, [-1, 1])
  gathered = fast_gather(column, sorted_history_indices_flat, num_entries)
  return tf.reshape(gathered, [num_beams, k * candidates_per_hyp])
def FProp(self,
          theta,
          source_input,
          source_paddings,
          target_input=None,
          target_paddings=None,
          source_segment_id=None,
          target_segment_id=None,
          labels=None,
          label_weights=None,
          source_pos_id=None,
          target_pos_id=None,
          source_task_id=None,
          target_task_id=None):
  """Transforms source sequence of Tensors with Transformers layers.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_input: A sequence of ints indicating source input ids of [time,
      batch] shape or [batch, time] if batch_dim is 0.
    source_paddings: A sequence of 0s and 1s indicating input paddings of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_input: A sequence of ints indicating target input ids of [time,
      batch] shape or [batch, time] if batch_dim is 0.
    target_paddings: [target_time, target_batch] or [target_batch,
      target_time] if batch_dim is 0.
    source_segment_id: A sequence of ints indicating source segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_segment_id: A sequence of ints indicating target segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    labels: A sequence of ints indicating label ids of [time, batch] shape,
      or [batch, time] if batch_dim is 0. Required when `p.softmax_tpl` is
      set.
    label_weights: A sequence of floats indicates label weights of [time,
      batch] shape, or [batch, time] if batch_dim is 0. Required when
      `p.softmax_tpl` is set.
    source_pos_id: A sequence of ints indicating source position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    target_pos_id: A sequence of ints indicating target position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    source_task_id: A sequence of ints indicating source task ids of [time,
      batch] shape, or [batch, time] if batch_dim is 0.
    target_task_id: A sequence of ints indicating target task ids of [time,
      batch] shape, or [batch, time] if batch_dim is 0.

  Returns:
    If `p.softmax_tpl` is not set: transformer_output with shape
    [time, batch, dim] or [batch, time, dim] if batch_dim is 0.
    Otherwise a tuple (per_example_xent, logits): logits is the stack output
    (last dim `p.softmax_tpl.num_classes`) and per_example_xent is the
    cross entropy reshaped to the two leading dims of logits.
  """
  p = self.params
  if p.num_decoder_layers > 0:
    assert target_input is not None
    assert target_paddings is not None
  if p.packed_input:
    assert source_segment_id is not None, (
        'Need to specify src_segment_id if packed input is supported.')
    assert source_pos_id is not None, (
        'Need to specify src_pos_id for packed input and embeddings.')

  logits = super(GPipeTransformerStack,
                 self).FProp(theta, source_input, source_paddings,
                             target_input, target_paddings,
                             source_segment_id, target_segment_id,
                             source_pos_id, target_pos_id, source_task_id,
                             target_task_id)
  if not p.softmax_tpl:
    return logits

  num_classes = p.softmax_tpl.num_classes
  flat_label_weights = tf.reshape(label_weights, [-1])
  target_probs = None
  if p.label_smoothing:
    if p.batch_dim:  # Time-major
      # Transpose to batch-major for the smoother and back afterwards.
      target_probs = tf.transpose(
          self.smoother.FProp(
              theta.smoother,
              tf.transpose(target_paddings),
              tf.transpose(labels),
              target_ids=None), [1, 0, 2])
    else:
      target_probs = self.smoother.FProp(
          theta.smoother, target_paddings, labels, target_ids=None)
    target_probs = tf.reshape(target_probs, [-1, num_classes])

  reshaped_logits = tf.reshape(logits, [-1, num_classes])
  tgt_labels = tf.reshape(labels, [-1])

  # The softmax lives in the last pipeline cell.
  last_cell = 'cell_{}'.format(len(p.splits) - 1)
  softmax = self.children[last_cell].softmax
  softmax_theta = theta[last_cell].softmax
  per_example_xent, _ = softmax.XentLossFromLogits(
      softmax_theta,
      reshaped_logits,
      class_weights=flat_label_weights,
      class_ids=tgt_labels,
      class_probabilities=target_probs)
  xent_shape = tf.shape(logits)[:2]
  per_example_xent = tf.reshape(per_example_xent, xent_shape)
  return per_example_xent, logits
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random subset
  of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by add +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch). tf.random.uniform samples from [0, 1); an exact 0 would give
  # z = -inf, which no logit offset (not even the +1e9 above) can overcome,
  # so clamp the sample away from 0. Values >= 1e-20 are left unchanged.
  uniform = tf.random.uniform([batch_size, time_dim], seed=p.random_seed)
  uniform = tf.maximum(uniform, 1e-20)
  z = -tf.math.log(-tf.math.log(uniform))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(
                  tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
              [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))

  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(x_token_is_observed & ~prev_x_token_is_observed,
                            tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def to_flat_indices(column_indices_per_row):
  """Converts per-row column indices into indices of the flattened matrix.

  Row r is offset by r * num_hyps_per_beam * candidates_per_hyp before the
  result is flattened. `num_beams`, `num_hyps_per_beam` and
  `candidates_per_hyp` come from the enclosing scope.

  Args:
    column_indices_per_row: A rank-2 int tensor with `num_beams` rows.

  Returns:
    A 1-D tensor of flat indices.
  """
  column_indices_per_row.shape.assert_has_rank(2)
  row_width = num_hyps_per_beam * candidates_per_hyp
  row_offsets = row_width * tf.reshape(tf.range(num_beams), [num_beams, 1])
  return tf.reshape(column_indices_per_row + row_offsets, [-1])
def history_unsort(values):
  """Gathers the flattened `values` at `inverse_indices_flat`.

  `fast_gather`, `inverse_indices_flat`, `num_beams`, `k`,
  `candidates_per_hyp` and `orig_scores_shape` come from the enclosing scope.

  Returns:
    The gathered values reshaped to `orig_scores_shape`.
  """
  num_entries = num_beams * k * candidates_per_hyp
  column = tf.reshape(values, [-1, 1])
  restored = fast_gather(column, inverse_indices_flat, num_entries)
  return tf.reshape(restored, orig_scores_shape)
def beam_search_step(in_scores,
                     in_atten_probs,
                     in_best_scores,
                     in_cumulative_scores,
                     in_histories,
                     cur_step,
                     eos_id,
                     num_beams,
                     beam_size,
                     num_hyps_per_beam,
                     valid_eos_max_logit_delta=5.0,
                     local_eos_threshold=-100.0,
                     merge_paths=False,
                     is_last_chunk=None,
                     eoc_id=0):
  """A single step of beam search.

  Let "b" be the number of beams, "k" be the number hyps in each beam. This
  function supports values with dtypes tf.float32 or tf.bfloat16.

  Per-hyp tensors of leading size b * k (inputs and outputs) use the "mod"
  layout: row i holds the j-th hyp of the n-th beam, where n = i % b and
  j = i // b. Internally they are reordered to the "div" layout (row n * k + j)
  so that each beam's hyps are contiguous, and reordered back before return.

  The following data structures are allocated before the first decoding step
  and are passed along from cur step to the next step:

  Args:
    in_scores: A tensor of shape [b * k, vocab_size], where [i, ...] is the
      token score of the j-th hyp of the n-th beam, with n = i % b and
      j = i // b.
    in_atten_probs: A tensor of shape [b*k, s_len], where in_atten_probs[i, ...]
      is the attention probabilities over the source words of the j-th hyps of
      n-th beam (where j, and n are derived as above).
    in_best_scores: A vector of size [b], best scores of terminated hyps so far
      in each of the beams.
    in_cumulative_scores: A vector of size [b * k]. The cumulative score of each
      active hyp before the current step.
    in_histories: An int32 vector of size [b * k] containing hashes of the
      histories of each active hyp. If 'merge_paths' is enabled, the histories
      are used to identify hypotheses that are identical modulo epsilons (e.g.
      "a <eps> b" and "a b <eps>") and merge them. See 'update_histories'
      docstring for details.
    cur_step: Current step id.
    eos_id: Token id of the special end of sequence token.
    num_beams: Number of beams.
    beam_size: Pruning margin. Decoding is reported as done (out_all_done) once
      no active hyp's cumulative score is greater than its beam's best
      terminated score minus beam_size.
    num_hyps_per_beam: Number of hyps in a beam.
    valid_eos_max_logit_delta: We allow </s> to terminate a hyp only if its
      logit is no more than 'valid_eos_max_logit_delta' away from the logit of
      the best candidate.
    local_eos_threshold: We allow </s> to terminate a hyp if the local score for
      </s> is greater than local_eos_threshold.
    merge_paths: If true, hyps which are identical when epsilons are removed
      will be combined into a single hyp. The probability for that combined hyp
      will be the sum of the probabilities of the component hyps. This can only
      be applied for epsilon-emitting models (RNN-T and NT).
    is_last_chunk: A tensor of shape [b * k, 1]. Used by neural transducer,
      determines whether the current hypothesis reaches the last chunk and
      should treat the next end-of-chunk symbol as end-of-sentence.
    eoc_id: int, the id of the end of chunk (a.k.a epsilon) token used by neural
      transducer models. Only relevant if 'merge_paths' is True or
      'is_last_chunk' is provided.

  Returns:
    out_best_scores: A tensor of shape [b] of updated best scores for each of
      the beams.
    out_cumulative_scores: A tensor of shape [b * k]. The cumulative score of
      the new hyps after the current decoding step.
    out_scores: A tensor of shape [b * k] with scores of the token selected.
    out_eos_scores: A tensor of shape [b * k] with token scores for the EOS, in
      case the hyp was terminated, otherwise 0.0.
    out_hyps: A tensor of shape [b * k] with ids of the token selected.
    out_prev_hyps: A tensor of shape [b * k] with index to the previous hyps
      which was selected.
    out_done_hyps: A boolean tensor of shape [b * k] where value indicates
      if hyps was terminated.
    out_atten_probs: A tensor of shape [b * k, seq_len] which contain the
      attention probabilities over the source words against word in the
      previous hyps.
    out_eos_atten_probs: A tensor of shape [b * k, seq_len] which contains the
      attention probabilities over the source against word in the current hyp
      which was terminated.
    out_all_done: A scalar bool, True when no active hyp's cumulative score
      exceeds its own beam's (out_best_scores - beam_size), i.e. decoding
      should terminate for all beams.
    out_histories: A tensor of shape [b * k] containing new history hashes for
      the active hypotheses. See 'update_histories' docstring for details.
  Raises:
    ValueError: if inputs are invalid.
  """
  num_hyps_per_beam = int(num_hyps_per_beam)

  if num_hyps_per_beam <= 0:
    raise ValueError("num_hyps_per_beam = {} and must be > 0.".format(
        num_hyps_per_beam))

  in_scores = tf.convert_to_tensor(in_scores)
  in_scores.shape.assert_has_rank(2)
  num_classes = in_scores.get_shape()[1]

  in_atten_probs = tf.convert_to_tensor(in_atten_probs)
  in_atten_probs.shape.assert_has_rank(2)

  in_best_scores = tf.convert_to_tensor(in_best_scores)
  in_best_scores.shape.assert_has_rank(1)

  in_cumulative_scores = tf.convert_to_tensor(in_cumulative_scores)
  in_cumulative_scores.shape.assert_has_rank(1)

  in_histories = tf.convert_to_tensor(in_histories)
  in_histories.shape.assert_has_rank(1)

  with tf.name_scope("beam_search_step"):
    # For k = num_hyps_per_beam
    # First step of beam search is to find the top tokens based on its score.
    # Normally we select k+1, where the extra +1 is to make sure we have k
    # non-eos tokens to select if EOS token is in the top-k. If path merging is
    # on, we actually need to select k+2; this ensures there are k+1 tokens left
    # after the merge, at least k of which are not EOS.
    # TODO(b/118644069): Avoid casts when there is a XLA op available that takes
    # in bfloat16.
    num_candidates_per_input_hyp = (
        num_hyps_per_beam + 2 if merge_paths else num_hyps_per_beam + 1)
    # [b * k, num_candidates_per_input_hyp]
    local_score_values, local_indices = xla_ops.top_k_with_unique(
        tf.cast(in_scores, tf.float32), k=num_candidates_per_input_hyp)
    local_score_values = tf.cast(local_score_values, in_scores.dtype)

    # Compute the global score which is sum of the local score, and the
    # cumulative scores for each of the hyps.
    # [b * k, num_candidates_per_input_hyp]
    global_score_values = local_score_values + tf.expand_dims(
        in_cumulative_scores, 1)

    values_dtype = local_score_values.dtype
    is_first_step = tf.cast(tf.equal(cur_step, 0), values_dtype)

    # Preprocessing to reorder the tensor from `mod` sharding to `div` so that
    # we can use matrix/vector operations to complete the beam search.
    # [b * k, num_candidates_per_input_hyp]
    global_score_values = reorder_tensor("mod_to_div", global_score_values,
                                         num_beams, num_hyps_per_beam)
    local_score_values = reorder_tensor("mod_to_div", local_score_values,
                                        num_beams, num_hyps_per_beam)
    local_indices = reorder_tensor(
        "mod_to_div",
        local_indices,
        num_beams,
        num_hyps_per_beam,
        max_value=num_classes - 1)
    # [b * k, 1]
    histories = reorder_tensor("mod_to_div", tf.expand_dims(in_histories, 1),
                               num_beams, num_hyps_per_beam)
    if is_last_chunk is None:
      is_last_chunk = tf.zeros([num_beams * num_hyps_per_beam, 1], tf.bool)
    else:
      is_last_chunk = tf.cast(
          reorder_tensor(
              "mod_to_div",
              tf.reshape(is_last_chunk, [num_beams * num_hyps_per_beam, 1]),
              num_beams, num_hyps_per_beam), tf.bool)

    # For the first step mask everything but the first row.
    # [num_hyps_per_beam]
    per_example_mask = tf.concat([
        tf.constant([1.0], dtype=values_dtype),
        tf.zeros([num_hyps_per_beam - 1], dtype=values_dtype)
    ], 0)
    # [num_hyps_per_beam, num_beams] => [b*k, 1]
    mask = tf.reshape(
        tf.tile(per_example_mask, tf.expand_dims(num_beams, 0)),
        [-1, 1]) * is_first_step + (1.0 - is_first_step)
    local_score_values *= mask
    global_score_values *= mask

    # We add a large negative value for the unmasked values.
    per_example_additive_mask = tf.concat([
        tf.constant([0.0], dtype=values_dtype),
        tf.constant(
            BEST_SCORES_INIT, shape=[num_hyps_per_beam - 1], dtype=values_dtype)
    ], 0)
    additive_mask = tf.reshape(
        tf.tile(per_example_additive_mask, tf.expand_dims(num_beams, 0)),
        [-1, 1]) * is_first_step
    local_score_values += additive_mask
    global_score_values += additive_mask

    if merge_paths:
      with tf.name_scope("merge_paths"):
        # Compute new history hashes for each hypothesis + new token.
        # [b * k, num_candidates_per_input_hyp]
        histories = update_histories(
            histories, local_indices, mask, epsilon_id=eoc_id)
        global_score_values, histories = merge_hyps(global_score_values,
                                                    histories, mask, num_beams,
                                                    num_hyps_per_beam)

    # As we keep num_candidates_per_input_hyp, we have a total of
    # num_candidates_per_input_hyp * k hyps active per example.
    num_candidate_hyps = num_candidates_per_input_hyp * num_hyps_per_beam
    batch_shape = [-1, num_candidate_hyps]

    # Reshape score values so that each row corresponds to a particular example.
    # [num_beams, num_candidate_hyps]
    global_score_values_batch = tf.reshape(global_score_values, batch_shape)

    # First for each beam: Find the top 2 * num_hyps_per_beam candidates.
    # The factor of 2 is to be able to process non EOS token ids in the case
    # where top scoring token for each hyps is EOS token.
    # [num_beams, 2 * num_hyps_per_beam]
    _, candidates_indices_in_top_k = xla_ops.top_k_with_unique(
        tf.cast(global_score_values_batch, tf.float32), k=2 * num_hyps_per_beam)
    # Find the previous hyps of the candidate. Integer-dividing by
    # num_candidates_per_input_hyp (k+1, or k+2 with path merging) identifies
    # which input hyp each candidate token came from.
    hyps_id = candidates_indices_in_top_k // num_candidates_per_input_hyp

    # Add in offset so that we can get the candidate index in the [b * k] space.
    offset = tf.expand_dims(tf.range(num_beams) * num_candidate_hyps, 1)
    flat_candidates_indices_in_top_k = tf.reshape(
        candidates_indices_in_top_k + offset, [-1])

    flat_local_indices = tf.reshape(local_indices, [1, -1])
    flat_token_scores = tf.reshape(local_score_values, [-1, 1])
    flat_global_scores = tf.reshape(global_score_values, [-1, 1])

    # Gather the token scores for each of 2*k candidates. We use tf.one_hot()
    # followed by a tf.matmul() to speedup gather on TPUs.
    total_num_candidates = num_beams * num_candidate_hyps
    token_scores_for_beam = tf.reshape(
        fast_gather(flat_token_scores, flat_candidates_indices_in_top_k,
                    total_num_candidates), [num_beams, 2 * num_hyps_per_beam])
    token_scores_for_beam_shape = tf.shape(token_scores_for_beam)

    global_scores_for_beam = tf.reshape(
        fast_gather(flat_global_scores, flat_candidates_indices_in_top_k,
                    total_num_candidates), token_scores_for_beam_shape)

    # Local index values are between [0, vocab_size-1], hence we use the
    # slower version of gather.
    token_ids_for_beam = tf.reshape(
        fast_gather(
            flat_local_indices,
            flat_candidates_indices_in_top_k,
            total_num_candidates,
            max_value=num_classes - 1,
            axis=1), token_scores_for_beam_shape)

    # We have access to 2*num_hyps_per_beam hyps per beam.
    # We shrink back to num_hyps_per_beam that does not include EOS, and move
    # EOS that occurs in top-num_hyps_per_beam to the EOS done matrix.

    # To determine the threshold at which eos is allowed to terminate a hyp,
    # we need to know the maximum global score for that hyp with any additional
    # token. If path merging is *not* enabled, the global_score_values are
    # by construction in sorted order, so we can just look at its 0th column. If
    # path merging is enabled, the global scores of deleted (merged) hyps break
    # the sorted order, which means we have to do a full reduce_max.
    if merge_paths:
      max_global_score_per_input_hyp = tf.reduce_max(
          global_score_values, axis=1, keepdims=True)
    else:
      max_global_score_per_input_hyp = global_score_values[:, 0:1]
    # [num_beams * num_hyps_per_beam, 1]
    global_eos_threshold = (
        max_global_score_per_input_hyp - valid_eos_max_logit_delta)
    local_eos_threshold_tensor = local_eos_threshold * tf.ones_like(
        global_eos_threshold)

    # Find EOS in top num_hyps_per_beam token ids. We also treat EOC as EOS if
    # the model has indicated this is the last chunk.
    local_index_is_eos = tf.equal(local_indices, eos_id)
    local_index_is_last_chunk_eoc = tf.math.logical_and(
        tf.equal(local_indices, eoc_id), is_last_chunk)
    eos_mask = tf.math.logical_and(
        tf.math.logical_and(
            tf.math.logical_and(
                tf.greater(
                    local_score_values,
                    tf.tile(local_eos_threshold_tensor,
                            [1, num_candidates_per_input_hyp])),
                tf.greater(
                    global_score_values,
                    tf.tile(global_eos_threshold,
                            [1, num_candidates_per_input_hyp]))),
            tf.math.logical_or(local_index_is_eos,
                               local_index_is_last_chunk_eoc)),
        tf.cast(mask, tf.bool))
    end_hyps_bool_mask = tf.reshape(tf.reduce_any(eos_mask, 1), [-1, 1])

    end_hyps_bool_mask = reorder_tensor("div_to_mod", end_hyps_bool_mask,
                                        num_beams, num_hyps_per_beam)

    eos_atten_probs = in_atten_probs * tf.cast(end_hyps_bool_mask,
                                               in_atten_probs.dtype)
    eos_atten_probs = tf.reshape(eos_atten_probs,
                                 [num_beams * num_hyps_per_beam, -1])
    # A boolean tensor of shape [b * k] where value indicates if hyps was
    # terminated.
    out_done_hyps = tf.reshape(end_hyps_bool_mask, [-1])

    # Scores for EOS token.
    eos_float_mask = tf.cast(eos_mask, values_dtype)
    eos_local_scores = eos_float_mask * local_score_values
    eos_additive_float_mask = (1.0 - eos_float_mask) * BEST_SCORES_INIT
    eos_local_scores += eos_additive_float_mask
    out_eos_scores = tf.reshape(tf.reduce_max(eos_local_scores, 1), [-1, 1])
    out_eos_scores = tf.reshape(
        reorder_tensor("div_to_mod", out_eos_scores, num_beams,
                       num_hyps_per_beam), [-1])
    # A tensor of shape [b] of updated best scores for each of the beams.
    eos_global_scores = eos_float_mask * global_score_values
    eos_global_scores += eos_additive_float_mask
    best_scores = tf.reduce_max(
        tf.reshape(eos_global_scores, [num_beams, -1]), 1)

    # Following operations are to finds the top num_hyps_per_beam that are
    # active.

    # Active ones are the ones that do not correspond to EOS termination.
    # We keep num_hyps_per_beam * 2 in case every hyps is terminated by EOS id.
    # Top K with eos removed.
    non_eos_mask = tf.not_equal(token_ids_for_beam, eos_id)
    # Note: from here on num_candidate_hyps is the total over all beams.
    num_candidate_hyps = num_hyps_per_beam * 2 * num_beams
    index = tf.where(
        non_eos_mask,
        tf.reshape(
            tf.range(num_candidate_hyps, dtype=tf.int32),
            token_scores_for_beam_shape),
        num_candidate_hyps *
        tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape))

    # Unrolled TopK.
    sorted_indices = []
    # Finds the first num_hyps_per_beam unmasked indexes and stores them in
    # concated_index (shape: [num_beams, num_candidate_hyps])
    # This is done by iteratively record the min index in each row, and reset
    # it to the max, so that next iteration reduce_min returns the 2nd minimum
    # index.
    for _ in range(num_hyps_per_beam):
      min_index = tf.reshape(tf.reduce_min(index, [1]), [num_beams, 1])
      sorted_indices.append(min_index)
      # Replace position with num_candidate_hyps value.
      index = tf.where(
          tf.equal(index, min_index),
          num_candidate_hyps *
          tf.ones(dtype=tf.int32, shape=token_scores_for_beam_shape), index)

    # Post processing ops to output expected tensors.
    concated_sorted_indices = tf.concat(sorted_indices, 1)
    flat_sorted_indices = tf.reshape(concated_sorted_indices, [-1])

    # A tensor of shape [b * k] with scores of the token selected.
    out_scores = tf.reshape(
        fast_gather(
            tf.reshape(token_scores_for_beam, [-1, 1]), flat_sorted_indices,
            num_candidate_hyps), [-1, 1])
    out_scores = tf.reshape(
        reorder_tensor("div_to_mod", out_scores, num_beams, num_hyps_per_beam),
        [-1])

    # Gather the updated histories of selected hypotheses if path merging is
    # enabled. Otherwise, the histories are unused, so just output in_histories.
    if merge_paths:
      flat_histories = tf.reshape(histories, [-1, 1])
      # [num_beams, 2 * num_hyps_per_beam]
      histories_for_beam = tf.reshape(
          fast_gather(flat_histories, flat_candidates_indices_in_top_k,
                      total_num_candidates), token_scores_for_beam_shape)
      out_histories = tf.reshape(
          fast_gather(
              tf.reshape(histories_for_beam, [-1, 1]), flat_sorted_indices,
              num_candidate_hyps), [-1, 1])
      out_histories = tf.reshape(
          reorder_tensor("div_to_mod", out_histories, num_beams,
                         num_hyps_per_beam), [-1])
    else:
      out_histories = in_histories

    # Index of each selected hyp's parent in the "mod" layout:
    # hyp_in_beam * num_beams + beam.
    prev_hyps_ids = tf.reshape(
        tf.reshape(
            fast_gather(
                tf.reshape(hyps_id, [1, -1]),
                flat_sorted_indices,
                num_candidate_hyps,
                max_value=num_hyps_per_beam,
                axis=1), [num_beams, -1]) * num_beams +
        tf.expand_dims(tf.range(num_beams), 1), [-1, 1])

    # NOTE(review): prev_hyps_ids can be as large as b * k - 1, yet the reorder
    # below passes max_value=num_hyps_per_beam; confirm reorder_tensor does not
    # rely on max_value for exactness.
    prev_hyps_ids = reorder_tensor(
        "div_to_mod",
        prev_hyps_ids,
        num_beams,
        num_hyps_per_beam,
        max_value=num_hyps_per_beam)
    # A tensor of shape [b * k] with index to the previous hyps which was
    # selected.
    out_prev_hyps = tf.reshape(prev_hyps_ids, [-1])

    # A tensor of shape [b * k, seq_len] which contain the attention
    # probabilities over the source words against word in the previous hyps.
    out_atten_probs = tf.reshape(
        fast_gather(in_atten_probs, out_prev_hyps,
                    num_beams * num_hyps_per_beam),
        [num_beams * num_hyps_per_beam, -1])

    sorted_top_k_ids = fast_gather(
        tf.reshape(token_ids_for_beam, [1, -1]),
        flat_sorted_indices,
        num_candidate_hyps,
        max_value=num_classes - 1,
        axis=1)
    sorted_top_k_ids = reorder_tensor(
        "div_to_mod",
        sorted_top_k_ids,
        num_beams,
        num_hyps_per_beam,
        max_value=num_classes - 1,
        axis=1)

    # A tensor of shape [b * k] with ids of the token selected.
    out_hyps = tf.reshape(sorted_top_k_ids, [-1])

    # A tensor of shape [b * k]. The cumulative score of the selected hyps after
    # the current decoding step.
    out_cumulative_scores = tf.reshape(
        fast_gather(
            tf.reshape(global_scores_for_beam, [-1, 1]), flat_sorted_indices,
            num_candidate_hyps), [-1, 1])

    out_cumulative_scores = tf.reshape(
        reorder_tensor("div_to_mod", out_cumulative_scores, num_beams,
                       num_hyps_per_beam), [-1])
    out_best_scores = tf.maximum(best_scores, in_best_scores)

    # A scalar, whether decoding should terminate for all beams: True once no
    # active hyp's cumulative score exceeds (its beam's best score - beam_size).
    # out_cumulative_scores is in "mod" layout (row i belongs to beam i % b), so
    # the per-beam thresholds are tiled as whole [b] vectors to line up with it.
    # Repeating each beam's threshold k times would produce the "div" layout
    # and compare hyps against another beam's threshold.
    # [b * k]
    per_hyp_threshold = tf.tile(out_best_scores - beam_size,
                                [num_hyps_per_beam])
    out_all_done = tf.reshape(
        tf.math.logical_not(
            tf.reduce_any(tf.greater(out_cumulative_scores,
                                     per_hyp_threshold))), [])

    return (out_best_scores, out_cumulative_scores, out_scores, out_eos_scores,
            out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs,
            eos_atten_probs, out_all_done, out_histories)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` object containing:
      ids - The inputs tensor of shape [batch, time].
      paddings - The ids' paddings of shape [batch, time].
      segment_ids, segment_pos - Read only when p.packed_input is set.

  Returns:
    A '.NestedMap' object containing:
      encoded - The encoded features of shape [time, batch, dim] or [batch,
        time, dim], depending p.output_data_format.
      padding - The encoded features' padding of shape [time, batch] or
        [batch, time].
      seq_lengths - int32 tensor of shape [batch]: the number of non-padded
        positions of each example.
      segment_id - The segmentation of packed inputs of shape [time, batch] or
        [batch, time] if it is supported by the model, or None otherwise.
      embedded_inputs - The embedded inputs tokens without positional encodings
        of shape [time, batch, dim] or [batch, time, dim].
  """

  p = self.params
  with tf.name_scope(p.name):
    # [batch, time]
    input_ids = input_batch.ids
    # [batch, time]
    paddings = input_batch.paddings

    # [batch, time]
    segment_ids = input_batch.segment_ids if p.packed_input else None

    batch = py_utils.GetShape(input_ids)[0]
    time = py_utils.GetShape(input_ids)[1]

    # Embedding layer.
    # [batch, time, dim]
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
    orig_input_embs = input_embs

    # [1, time, dim]
    # NOTE(review): in the packed branch the expand_dims may add an extra
    # leading axis; the explicit reshape below restores [batch, time, dim].
    if p.packed_input:
      positions = input_batch.segment_pos
      position_embs = tf.expand_dims(
          self.position_emb.FPropWithPosition(theta.position_emb, positions), 0)
    else:
      position_embs = tf.expand_dims(
          self.position_emb.FProp(theta.position_emb, time), 0)

    # [batch, time, dim]
    # Match the token embeddings' dtype: TF does not auto-promote, so a fixed
    # cast (e.g. to bfloat16) would fail whenever the embeddings are float32.
    input_embs += tf.cast(position_embs, input_embs.dtype)

    if p.input_dropout_tpl.fprop_dtype:
      input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
      paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
    # [batch, time, dim]
    transformer_input = input_embs
    # Explicitly set the input shape of Transformer layers, to avoid
    # unknown shape error occurred to tf.einsum on nonTPU devices.
    transformer_input = tf.reshape(transformer_input,
                                   [batch, time, p.model_dim])

    # Compute self-attention segment mask once. Both branches produce a mask in
    # the transformer input's dtype.
    if p.packed_input:
      segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids, dtype=transformer_input.dtype)
    else:
      segment_mask = tf.zeros([batch, 1, time, time],
                              dtype=transformer_input.dtype)

    encoded, padding = self.transformer_stack.FProp(theta.transformer_stack,
                                                    transformer_input, paddings,
                                                    segment_mask)

    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)

    # Count in float32: bfloat16 cannot represent every integer above 256, so
    # summing a low-precision padding tensor could give wrong lengths.
    seq_lengths = tf.cast(
        tf.reduce_sum(1. - tf.cast(padding, tf.float32), axis=1), tf.int32)

    if p.output_data_format == 'TBC':
      encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
      padding = tf.transpose(padding)  # [time, batch]
      segment_ids = tf.transpose(segment_ids) if p.packed_input else None
      orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        seq_lengths=seq_lengths,  # used by beam_search_helper.
        segment_id=segment_ids,
        embedded_inputs=orig_input_embs)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:
      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
          ids of shape [batch, time].
      - segment_ids, segment_pos: Read only when p.packed_input is set.

  Returns:
    A NestedMap containing
      - encoded: The encoded features, either a tensor of shape [time, batch,
          depth], or a list of tensors if is_transparent is set in
          transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
          positional encodings.
  """

  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)
    # Per-token task ids are only read when task embeddings are enabled. They
    # must be truncated together with the ids below, otherwise the task
    # embeddings no longer match the [batch, time] shape of the tokens.
    task_ids = input_batch.task_ids if p.task_emb else None

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      # Drop trailing time steps that are padding in every example.
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)), tf.int32)
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      if task_ids is not None:
        task_ids = task_ids[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # Switch paddings (and segment ids) to time-major [time, batch].
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    dtype = paddings.dtype
    source_mask = tf.where(
        tf.equal(input_ids, p.source_mask_id),
        tf.ones_like(input_ids, dtype=dtype),
        tf.zeros_like(input_ids, dtype=dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0, 1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)