def _InputBatch(self):
    """Builds one infeed batch of samples read from a cached checkpoint.

    The tensors named `p.data` and `p.label` are restored from `p.ckpt`, cast
    to float32 and cached. Samples are then picked with a random permutation
    sequence and padded or trimmed to the infeed batch size.

    Returns:
      A `.NestedMap` with fields `raw`, `data` (the result of
      `self._Preprocess(raw)`), `label` and `weight`. When not running on
      TPU, `sample_ids` is included as well.
    """
    p = self.params

    @tf.function
    def ReadData():
        restored_data, restored_label = io_ops.restore_v2(
            p.ckpt, [p.data, p.label], [''] * 2,
            [p.data_dtype, p.label_dtype])
        # Always convert to float32.
        return (tf.cast(restored_data, tf.float32),
                tf.cast(restored_label, tf.float32))

    # Loads data and label into memory and keep it around.
    data, label = ops.cached_call(
        f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

    batch_size = self.InfeedBatchSize()
    sample_shape = list(p.data_shape)
    data = tf.reshape(data, [-1] + sample_shape)
    label = tf.reshape(label, [-1])
    # Checks that there is exactly one label per data sample.
    label = py_utils.HasShape(label, [tf.shape(data)[0]])

    # A falsy random_seed (None or 0) is passed to the op as 0.
    seed = p.random_seed if p.random_seed else 0
    sample_ids = ops.random_permutation_sequence(
        num=p.num_samples, batch=batch_size, repeat=p.repeat, seed=seed)
    num_ids = tf.shape(sample_ids)[0]

    raw = py_utils.PadOrTrimTo(
        tf.gather(data, sample_ids), [batch_size] + sample_shape)
    ret = py_utils.NestedMap(
        raw=raw,
        data=self._Preprocess(raw),
        label=py_utils.PadOrTrimTo(
            tf.gather(label, sample_ids), [batch_size]),
        # One unit weight per sampled id, padded/trimmed to the batch size.
        weight=py_utils.PadOrTrimTo(
            tf.ones([num_ids], dtype=tf.float32), [batch_size]))
    if not py_utils.use_tpu():
        ret['sample_ids'] = sample_ids
    return ret
def _ProcessSingleInput(self, source_id, src, tgt):
    """Converts one (src, tgt) string pair to id features via the tokenizers.

    Args:
      source_id: Passed to `self._GetTaskIds` to obtain the source and target
        task ids.
      src: The source string; reshaped to [1] before tokenization.
      tgt: The target string; reshaped to [1] before tokenization.

    Returns:
      A `.NestedMap` with sub-maps `src` (`ids`, `ids_indicator`, `task_ids`)
      and `tgt` (`ids`, `labels`, `ids_indicator`, `task_ids`), with
      `tf.squeeze` applied to every entry. When not on TPU the input strings
      are also attached as `strs`.
    """
    _, src_labels, src_paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
        tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)

    # Mask positions to 0 where padding is 1 for consistency. We do this
    # because tokenizer implementation may use EOS token to pad.
    src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
    tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
    tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

    # ids_indicator is 1 if and only if the output from tokenizer has a
    # non-padded id. Unlike weights, it will not mutate and can be used for
    # determining actual sequence length, for example.
    src_features = py_utils.NestedMap(
        ids=src_labels, ids_indicator=1 - src_paddings)
    tgt_features = py_utils.NestedMap(
        ids=tgt_ids, labels=tgt_labels, ids_indicator=1 - tgt_paddings)

    src_task_id, tgt_task_id = self._GetTaskIds(source_id)
    # task_ids are padded with zeros.
    src_features.task_ids = tf.cast(
        src_features.ids_indicator, dtype=tf.int32) * src_task_id
    tgt_features.task_ids = tf.cast(
        tgt_features.ids_indicator, dtype=tf.int32) * tgt_task_id

    if not py_utils.use_tpu():
        src_features.strs = src
        tgt_features.strs = tgt

    features = py_utils.NestedMap(src=src_features, tgt=tgt_features)
    return features.Transform(tf.squeeze)
def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
    """Extend greedy search hyps for one step.

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed
        to the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
        NOTE(review): the in-place update below writes the slice at index
        `cur_step` of the leading dimension, which suggests the leading
        dimension is time -- confirm against the caller.
      hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
        counted.
      done_hyps: Whether or not a hyp has finished.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected
        that each of its member tensors are of rank >= 1. t[i, ...] is the
        state of the i-th hyp at the beginning of this search step.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
        callback. See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback`
        callback. See class header comments for more details.

    Returns:
      A tuple of following elements for the next greedy search step,
      (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
    """
    p = self.params

    # Hyps that are still running grow by one token; finished ones do not.
    still_running = 1 - tf.cast(done_hyps, tf.int32)
    hyp_lens = hyp_lens + still_running

    # The trailing 1 is num_hyps_per_beam: greedy search keeps one hyp.
    bs_results, new_other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, 1)

    # Pick the highest-scoring id and give it the same shape as step_ids.
    best_ids = tf.cast(tf.math.argmax(bs_results.log_probs, 1), tf.int32)
    new_step_ids = tf.reshape(best_ids, tf.shape(step_ids))

    final_other_states = post_beam_search_step_callback(
        theta, encoder_outputs, new_step_ids, new_other_states)

    # Stash new_step_ids into the right slot.
    flat_new_ids = tf.reshape(new_step_ids, [-1])
    hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                               flat_new_ids)

    # Update done_hyps if the current step_ids is the end of sequence token.
    emitted_eos = tf.equal(flat_new_ids, p.target_eos_id)
    done_hyps = tf.math.logical_or(done_hyps, emitted_eos)

    return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
            final_other_states)
def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step.

    Args:
      recurrent_theta: Holds `theta`, `encoder_outputs` and `random_seed`.
      state0: The previous state, with `ids`, `bs_state` and `timestep`.
      inputs: Unused.

    Returns:
      A tuple (state1, extras): state1 carries the advanced `timestep`, the
      step `logits`, the sampled `ids` and the updated `bs_state`; extras is
      an empty `.NestedMap`.
    """
    del inputs
    theta = recurrent_theta.theta
    encoder_outputs = recurrent_theta.encoder_outputs
    with tf.name_scope('single_sampler_step'):
        # Compute logits and states.
        prev_ids = tf.expand_dims(state0.ids, 1)  # [batch, 1].
        bs_result, bs_state1 = pre_step_callback(
            theta,
            encoder_outputs,
            prev_ids,
            state0.bs_state,
            num_hyps_per_beam=1)
        batch = tf.shape(bs_result.log_probs)[0]

        state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
        state1.logits = bs_result.log_probs

        # Sample ids from logits. [batch].
        # The seed combines the global seed with the timestep.
        sampled = tf.random.stateless_categorical(
            state1.logits / p.temperature,
            num_samples=1,
            seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
            dtype=state0.ids.dtype,
            name='sample_next_id')
        state1.ids = tf.reshape(sampled, [batch])

        # In the last chunk, a sampled end-of-chunk id becomes end-of-sequence.
        if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
            eoc_in_last_chunk = tf.math.logical_and(
                bs_result.is_last_chunk,
                tf.equal(state1.ids, p.target_eoc_id))
            state1.ids = tf.where(
                eoc_in_last_chunk,
                tf.fill(tf.shape(state1.ids), p.target_eos_id),
                state1.ids)

        state1.bs_state = post_step_callback(
            theta, encoder_outputs, state1.ids, bs_state1)
    return state1, py_utils.NestedMap()
def AddMultiCurveSubplot(fig, tensors, paddings, labels, xlabels=None,
                         **kwargs):
    """Adds a multi curve subplot to Matplotlib figure.

    Plots one line for each non-None entry in tensors and assigns a plot
    label legend. Entries of `tensors` that are None are skipped together
    with their label.

    Args:
      fig: The Matplotlib figure.
      tensors: List of tensors of shape [batch, length]
      paddings: Paddings for 'tensors' with shape [batch, length] with 0. in
        valid positions and 1. in invalid.
      labels: A list of tensor names (strings) of the same length as
        'tensors'.
      xlabels: A string tensor of shape [batch] with an xlabel per batch.
      **kwargs: With optional, title, xlabel, ylabel, fontsize.
    """
    present = [(t, l) for t, l in zip(tensors, labels) if t is not None]
    curves = [py_utils.ApplyPadding(paddings, t) for t, _ in present]
    row_labels = [l for _, l in present]

    batch, length = py_utils.GetShape(curves[0], 2)
    # Concatenating on the last axis and reshaping gives one row per curve:
    # [batch, num_curves * length] -> [batch, num_curves, length].
    stacked = tf.reshape(
        tf.concat(curves, -1), [batch, len(curves), length])

    plot_args = [stacked, py_utils.LengthsFromPaddings(paddings)]
    if xlabels is not None:
        plot_args.append(xlabels)
    fig.AddSubplot(
        plot_args,
        plot_func=_AddMultiCurveRowPlots,
        row_labels=row_labels,
        **kwargs)
def _GetWeight(self, theta):
    """Returns the filter weight, weight-normalized if `p.weight_norm` is set.

    Args:
      theta: Holds the filter `w` and, when weight norm is on, the gain `g`.

    Returns:
      `theta.w` unchanged, or `theta.w` L2-normalized over axes [0, 1, 2] and
      scaled by `g + 1.0` broadcast along the last dimension.
    """
    p = self.params
    if not p.weight_norm:
        return theta.w
    # Normalize along the last dim (standard conv): the norm is taken over
    # the first three axes, separately for each index of the last axis.
    unit_w = tf.nn.l2_normalize(theta.w, [0, 1, 2])
    gain = tf.reshape((theta.g + 1.0), [1, 1, 1, p.filter_shape[-1]])
    return unit_w * gain
def _GetWeight(self, theta):
    """Returns the filter weight, weight-normalized if `p.weight_norm` is set.

    Args:
      theta: Holds the filter `w` and, when weight norm is on, the gain `g`.

    Returns:
      `theta.w` unchanged, or `theta.w` L2-normalized over axes [0, 1] and
      scaled by `g + 1.0` reshaped to broadcast over the last two
      dimensions.
    """
    p = self.params
    if not p.weight_norm:
        return theta.w
    # Normalize along the last two dims: the norm is taken over the first
    # two axes, separately for each index pair of the last two axes.
    unit_w = tf.nn.l2_normalize(theta.w, [0, 1])
    gain = tf.reshape(
        (theta.g + 1.0), [1, 1, p.filter_shape[2], p.filter_shape[3]])
    return unit_w * gain
def _InputBatch(self):
    """Produces a batch of consecutive counter values shaped like `self.shape`.

    A 'CountingInputGenerator' stats counter is advanced by the number of
    elements in the batch, and the batch holds the values starting from the
    counter's value before that increment.

    Returns:
      A `.NestedMap` with `src_ids` (float32 values of shape `self.shape`)
      and `tgt_ids` (the sum of `src_ids` over axis 0).
    """
    num_elements = tf.reduce_prod(self.shape)
    counter = summary_utils.StatsCounter('CountingInputGenerator')
    # IncBy's result minus the increment gives the first value of this batch.
    first_value = tf.cast(
        counter.IncBy(num_elements), dtype=tf.int32) - num_elements
    first_value = tf.stop_gradient(first_value)
    flat_values = first_value + tf.range(num_elements)
    src_ids = tf.reshape(
        tf.cast(flat_values, dtype=tf.float32), self.shape)
    tgt_ids = tf.reduce_sum(src_ids, axis=0)
    return py_utils.NestedMap(src_ids=src_ids, tgt_ids=tgt_ids)
def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to `self.op` after a bfloat16 cast.

    Args:
      restored_tensors: Sequence whose first element is the restored value.
      restored_shapes: None, or a sequence whose first element is the shape
        the restored value is reshaped to before assignment.

    Returns:
      The `tf.assign` op writing the cast value into `self.op`.
    """
    value = restored_tensors[0]
    has_shapes = restored_shapes is not None
    if has_shapes:
        value = tf.reshape(value, restored_shapes[0])
    # Shape validation is only requested when no explicit shape was given
    # and the variable's static shape is fully defined.
    validate = (not has_shapes) and self.op.get_shape().is_fully_defined()
    return tf.assign(
        self.op, tf.cast(value, tf.bfloat16), validate_shape=validate)
def MakeCausalPadding(seq_len, block_size, left_context, right_context):
    """Makes the causal padding tensor for a full sequence.

    Args:
      seq_len: int or scalar int tensor. Sequence length.
      block_size: int. Number of time frames in a block.
      left_context: int. Left context size.
      right_context: int. Right context size.

    Returns:
      A tensor of [num_blocks, block_size, context_size] taking values in
      {0, 1}, where context_size = block_size + (left_context - 1) +
      right_context. Element b, i, j is zero if in the b-th block, the i-th
      frame can access the j-th frame in the context.
    """
    length_check = py_utils.assert_greater_equal(
        seq_len, 1, message='seq_len must be at least 1')
    seq_len = py_utils.with_dependencies([length_check], seq_len)

    # Number of blocks needed to cover seq_len (ceiling division).
    num_blocks = (seq_len + block_size - 1) // block_size
    padded_len = num_blocks * block_size
    context_size = block_size + (left_context - 1) + right_context

    # [num_blocks, block_size]: source positions in the original sequence.
    src_positions = tf.reshape(
        tf.range(padded_len), [num_blocks, block_size])
    # [num_blocks,]: source positions at the start of each block.
    block_starts = tf.range(0, padded_len, block_size)
    # [context_size]: positions relative to the block start.
    context_offsets = tf.range(context_size) - (left_context - 1)

    # [num_blocks, context_size]: target positions in the original sequence.
    tgt_positions = (
        block_starts[:, tf.newaxis] + context_offsets[tf.newaxis, :])

    # [num_blocks, block_size, context_size]: position differences between
    # source-target pairs.
    position_diff = (
        src_positions[:, :, tf.newaxis] - tgt_positions[:, tf.newaxis, :])

    # [num_blocks, block_size, context_size]: if attention is allowed between
    # source-target pairs.
    within_context = tf.math.logical_and(-right_context <= position_diff,
                                         position_diff < left_context)

    # [num_blocks, block_size]: if the source position is valid, not padded.
    valid_src = src_positions < seq_len
    # [num_blocks, context_size]: if the target position is valid, not padded.
    valid_tgt = tf.math.logical_and(0 <= tgt_positions,
                                    tgt_positions < seq_len)

    both_valid = tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                     valid_tgt[:, tf.newaxis, :])
    valid_atten = within_context & both_valid

    # Padding is the complement of the validity mask.
    return 1.0 - tf.cast(valid_atten, dtype=tf.float32)
def UnstackFeatures(self, src_inputs, src_paddings):
    """Unstacks src_input and src_paddings based off stack height.

    Args:
      src_inputs: A rank-4 tensor; its second dimension is multiplied by
        `p.stack_height` while the first and last dimensions are kept.
      src_paddings: Paddings for `src_inputs`, summed over axis 1 to get the
        per-example length.

    Returns:
      A tuple (unstacked inputs, int32 paddings rebuilt for the unstacked
      length).
    """
    stack_height = self.params.stack_height
    batch, stacked_len, _, channels = py_utils.GetShape(src_inputs)
    unstacked_len = stacked_len * stack_height
    unstacked_inputs = tf.reshape(
        src_inputs, [batch, unstacked_len, -1, channels])

    # Each non-padded stacked frame contributes stack_height frames.
    non_padded = 1 - src_paddings
    unstacked_lengths = tf.cast(
        stack_height * tf.reduce_sum(non_padded, axis=1), tf.int32)
    valid_mask = tf.sequence_mask(unstacked_lengths, maxlen=unstacked_len)
    unstacked_paddings = 1 - tf.cast(valid_mask, tf.int32)
    return unstacked_inputs, unstacked_paddings
def _AugmentationNetwork(self,
                         series_length,
                         inputs,
                         paddings,
                         global_seed,
                         domain_id_index=0):
    """Returns augmented features.

    Applies time warping, time masking and frequency masking, in that order.
    When `p.unstack` is set the features are unstacked first and reshaped
    back to `series_length` afterwards.

    Args:
      series_length: Total length of time series.
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      paddings: Batch of padding vectors of shape (batch_size, time_length).
      global_seed: an integer seed tensor for stateless random ops.
      domain_id_index: domain id index.

    Returns:
      Batch of output features of shape (batch_size, time_length, num_freq,
      channels) obtained by applying random augmentations to inputs.
    """
    p = self.params
    # Keyword arguments shared by all three augmentation steps.
    shared_kwargs = dict(
        global_seed=global_seed,
        dtype=p.dtype,
        domain_id_index=domain_id_index)

    # Unstack the features.
    if p.unstack:
        inputs, paddings = self.UnstackFeatures(inputs, paddings)

    lengths = tf.reduce_sum(1 - paddings, 1)
    augmented = self._TimeWarp(inputs, lengths, **shared_kwargs)
    augmented = self._TimeMask(
        augmented,
        lengths,
        noisify=p.use_noise,
        gaussian_noise=p.gaussian_noise,
        **shared_kwargs)
    augmented = self._FrequencyMask(augmented, **shared_kwargs)

    # Restack the features after applying specaugment.
    if p.unstack:
        restacked_shape = [
            tf.shape(augmented)[0], series_length, -1,
            tf.shape(augmented)[3]
        ]
        augmented = tf.reshape(augmented, restacked_shape)

    return augmented
def _ReshapeRetVal(name, t_shape):
    """Restore shape for tensors in microbatches.

    Args:
      name: Key of the tensor to fetch from `output_state`.
      t_shape: Per-micro-batch shape of the tensor, or None.

    Returns:
      None if `t_shape` is None; otherwise the tensor reshaped to `t_shape`
      with dimension `p.batch_dim` multiplied by `p.num_micro_batches`.
    """
    if t_shape is None:
        return None
    merged = output_state[name]
    batch_dim = p.batch_dim
    if batch_dim != 0:
        # Move axis 0 (presumably the micro-batch axis -- confirm) so that it
        # sits directly in front of the batch dimension before merging.
        axes_before = list(range(1, batch_dim + 1))
        axes_after = list(range(batch_dim + 1, t_shape.rank + 1))
        merged = tf.transpose(merged, perm=axes_before + [0] + axes_after)
    full_shape = t_shape.ToTensorShape().as_list()
    full_shape[batch_dim] *= p.num_micro_batches
    return tf.reshape(merged, full_shape)
def _GetExpertDist(self, theta, inputs, *args):
    """Computes the expert distribution from the averaged input embedding.

    Args:
      theta: Holds the projection weight `w`.
      inputs: Input tensor whose last dimension is `p.cond_dim`.
      *args: Unused.

    Returns:
      The sigmoid of the averaged embedding projected through `theta.w`.
    """
    # TODO(huangyp): support the more general case when batch size is not 1.
    # Input shape can be either [batch, length, dim] or [length, batch, dim]
    p = self.params
    flat_inputs = tf.reshape(inputs, [-1, p.cond_dim])
    if p.nonzeros_mean:
        # Average over the non-zero entries only; the epsilon guards against
        # division by zero when a column has no non-zero entry.
        mean_emb = tf.reduce_sum(flat_inputs, 0)
        num_nonzeros = tf.cast(
            tf.math.count_nonzero(flat_inputs, 0), dtype=tf.float32)
        mean_emb /= (num_nonzeros + 1e-10)
    else:
        mean_emb = tf.reduce_mean(flat_inputs, 0)
    projected = tf.einsum('i,ij->j', mean_emb, theta.w)
    return tf.nn.sigmoid(projected)
def SequenceLength(padding):
    """Computes the length of a sequence based on binary padding.

    Args:
      padding: A tensor of binary paddings shaped [batch, seqlen].

    Returns:
      seq_lens, A tensor of shape [batch] containing the non-padded length of
      each sequence along the batch dimension.
    """
    num_valid = tf.reduce_sum(1 - padding, axis=1)
    lengths = tf.cast(tf.round(num_valid), tf.int32)
    # Get rid of any extra dimensions.
    return tf.reshape(lengths, [tf.shape(padding)[0]], name='seq_lens')
def RelShift(x):
    """Performs relative shift on 4D tensor (first 2 axis are batching dims).

    Given input of shape [?, ?, W, W], this does "relative shifting" for the
    last two dims, s.t. output[b, n, i, j] = input[b, n, i, j-i] for j >= i.

    NOTE(review): with this pad-and-reshape trick only the first sub-diagonal
    (j == i - 1) is guaranteed to hold the padded zero; entries further below
    the diagonal receive values wrapped around from the previous row.
    Presumably callers mask those positions -- confirm.

    Args:
      x: A Tensor of shape [?, ?, W, W]

    Returns:
      A Tensor of the same shape as input with its content shifted (as
      described above).
    """
    batch, heads, width, _ = py_utils.GetShape(x)
    square = py_utils.HasShape(x, [-1, -1, width, width])
    # Append one zero column, then reinterpret the [W, W + 1] buffer as
    # [W + 1, W]: every row ends up shifted one step further than the
    # previous one.
    padded = tf.pad(square, ((0, 0), (0, 0), (0, 0), (0, 1)))
    skewed = tf.reshape(padded, [batch, heads, width + 1, width])
    # Drop the extra row to get back to [W, W].
    return skewed[:, :, :width, :]
def PrepareSequenceForPlot(tensor, padding, name):
    """Prepares a sequence feature for plotting.

    The sequence feature is transposed and channels are flattened.

    Args:
      tensor: A n-D Tensor of shape [batch, time, ...].
      padding: A Tensor of shape [batch, time].
      name: A string as the name of the reshaped Tensor, which will be used
        as the subcaption for plotting.

    Returns:
      A tuple of:
        reshaped_tensor: A 3-D Tensor of shape [batch, dim, time].
        sequence_length: A 1-D Tensor of shape [batch].
    """
    batch, time = py_utils.GetShape(tensor, 2)
    # Flatten any dimensions beyond the third into the third.
    flattened = tf.reshape(tensor, [batch, time, -1])
    # Swap to [batch, dim, time]; the op carries the caller-supplied name.
    reshaped_tensor = tf.transpose(flattened, [0, 2, 1], name=name)
    sequence_length = SequenceLength(padding)
    return (reshaped_tensor, sequence_length)
def _ProcessMASSInput(self, source_id, src):
    """Perform MASS input processing.

    Args:
      source_id: Passed to `self._GetTaskIds` to obtain the source and target
        language ids.
      src: The input string; reshaped to [1] before tokenization. It serves
        as both the source and the target text.

    Returns:
      A `.NestedMap` with `src` and `tgt` sub-maps built from the output of
      `self.mass_layer.Mask`, with `tf.squeeze` applied to every entry.
    """
    # TODO(yuancao): By doing so we assume that right now for monolingual
    # eval/dev sets (xx->xx) are in double-column format (since it bypasses
    # the Mass op). Ideally we should add a dedicated eval/dev processing
    # procedure for unsupervised MT cases, so that single-column eval/devs
    # sets are also supported. This should not be handled by any specific ops
    # like Mass, but inside the TextPackedInput class.
    assert not self.do_eval, 'MASS input can only be used for training.'

    _, labels, paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    weights = 1 - paddings
    actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
    src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

    mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

    src_features = py_utils.NestedMap(
        ids=mass_out.src.ids, paddings=paddings, weights=weights)
    # Source task ids are zero on padded positions.
    src_features.task_ids = tf.cast(
        src_features.weights, dtype=tf.int32) * src_lang_ids
    src_features.ids_indicator = weights

    tgt_features = py_utils.NestedMap(
        ids=mass_out.tgt.ids,
        labels=mass_out.tgt.labels,
        paddings=paddings,
        weights=mass_out.tgt.weights)
    # Target task ids are set at every position, padded ones included.
    tgt_features.task_ids = tf.ones_like(
        src_features.task_ids, dtype=tf.int32) * tgt_lang_ids
    tgt_features.ids_indicator = weights

    if not py_utils.use_tpu():
        src_features.strs = src
        tgt_features.strs = src

    features = py_utils.NestedMap(src=src_features, tgt=tgt_features)
    return features.Transform(tf.squeeze)
def ConvertToBlocks(x, block_size, padding_val=0.0):
    """Turns a sequence to non overlapping blocks.

    Args:
      x: a tensor of [batch, time, ...].
      block_size: int. Number of time frames in a block.
      padding_val: float. value on the padded frames.

    Returns:
      A tensor of [batch, num_blocks, block_size, ...], with necessary
      paddings, where output[:, i, ...] are
      x[:, i*block_size:(i+1)*block_size, ...].

    Raises:
      ValueError: If `block_size` is smaller than 1.
    """
    full_shape = py_utils.GetShape(x)
    batch, time = full_shape[:2]
    if block_size < 1:
        raise ValueError(
            'block_size must be at least 1, got {}'.format(block_size))
    # Pad time to be a multiple of block_size (ceiling division).
    num_blocks = (time + block_size - 1) // block_size
    padded = py_utils.PadSequenceDimension(
        x, num_blocks * block_size, padding_val)
    blocked_shape = [batch, num_blocks, block_size] + full_shape[2:]
    return tf.reshape(padded, blocked_shape)
def _BeamSearchDecodeIds(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
    """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A NestedMap computed by encoder.
      num_hyps_per_beam: Number of hyps per beam.
      init_beam_search_state: The InitBeamSearchState callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The PostBeamSearchStepCallback
        callback. Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A flat tuple made of, in this order: the shapes (from
      `py_utils.GetShape`) of `source_seq_lengths`, `hyps` and `atten_probs`;
      the eight tensors described below, each reshaped to rank 1; and the
      flattened final other states.

      hyps: A tensor of shape [time, b * k] with ids of the token selected.
      prev_hyps: A tensor of shape [time, b * k] with index to the previous
        hyps which was selected.
      done_hyps: A boolean tensor of shape [time, b * k] where value
        indicates if hyps was terminated.
      scores: A tensor of shape [time, b * k] with scores of the token
        selected.
      atten_probs: A tensor of shape [time, b * k, seq_len] which contain the
        attention probabilities over the source words against word in the
        previous hyps.
      eos_scores: A tensor of shape [time, b * k] with scores of the eos
        token selected.
      eos_atten_probs: A tensor of shape [time, b * k, seq_len] which contain
        the attention probabilities over the source words against word in the
        previous hyps.
      source_seq_lengths: A tensor of shape [time] containing the source
        seq_lengths.
      flat_final_other_states: A array of tensors that are part of other
        states.
    """
    p = self.params
    # Apply the documented default. Without it, a None max_steps would be
    # passed to tf.TensorArray(size=...) and to the `cur_step < max_steps`
    # loop conditions below.
    if max_steps is None:
        max_steps = p.target_seq_len
    source_paddings = encoder_outputs.padding

    initial_results, other_states = init_beam_search_state(
        theta, encoder_outputs, num_hyps_per_beam)

    num_hyps = tf.shape(initial_results.log_probs)[0]
    num_beams = num_hyps // num_hyps_per_beam

    # We cache the NestedMap as member variable so that we can use it to
    # pack the final outputs. Tpu rewrite methods forces us to strictly pass
    # in Tensors, and output Tensors
    self._other_states = other_states

    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))
    min_score = -1e36
    fprop_dtype = py_utils.FPropDtype(p)
    best_scores = (
        tf.zeros(shape=[num_beams], dtype=fprop_dtype) + min_score)
    cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
    histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
    in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
    in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
    in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
    in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
    in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
    in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
    in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
    cur_step = tf.constant(0, dtype=tf.int32)
    all_done = tf.constant(False, dtype=tf.bool)
    # States for beam search that are inputs into Beam search step.
    accum_bs_states = [best_scores, cumulative_scores, histories]
    # States that are not accumulators.
    non_accum_bs_states = [
        in_scores,
        in_hyps,
        in_prev_hyps,
        in_done_hyps,
        in_atten_probs,
        in_eos_scores,
        in_eos_atten_probs,
    ]
    core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

    flat_other_states = other_states.Flatten()

    # If there is an optimized implementation for short sequence,
    # LoopBodyShort will run first for short_seq_limit steps (after which the
    # LoopBodyShort does not have performance benefit). Then LoopBodyLong
    # (the default implementation) is used to continue the rest of the steps.
    # For decoders which do not have the short sequence specific
    # implementation, only the LoopBodyLong (the default implementation) will
    # run.

    if p.short_seq_limit > 0:

        def LoopContinueShort(cur_step, all_done, unused_step_ids,
                              unused_core_bs_states,
                              unused_other_states_list):
            """Use short_seq optimization when cur_step is smaller than limit."""
            return tf.math.logical_and(cur_step < p.short_seq_limit,
                                       tf.math.logical_not(all_done))

        def LoopBodyShort(cur_step, unused_all_done, step_ids,
                          core_bs_states, other_states_list):
            """Loop body of short_seq optimization.

            Instead of doing computation for the entire padded sequence,
            while loop with early exit is used within each _BeamSearchStep to
            do computation for only the actual sequence
            (seq_length <= cur_step). use_short_seq_opt is used as the flag
            to pass this information down to the decoder implementation.

            Args:
              cur_step: A scalar int tensor, the current time step, 0-based.
              unused_all_done: A tf.bool, indicating whether the decoding
                finishes.
              step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids
                to the current search step.
              core_bs_states: A tuple of core beam search states.
              other_states_list: A flattened NestedMap of other beam search
                states.

            Returns:
              The updated input tuple, with the same shape.
            """
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta,
                 encoder_outputs,
                 cur_step,
                 step_ids,
                 core_bs_states,
                 other_states.Pack(other_states_list),
                 num_hyps_per_beam,
                 pre_beam_search_step_callback,
                 post_beam_search_step_callback,
                 use_short_seq_opt=True)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        (cur_step, all_done, step_ids, core_bs_states,
         flat_other_states) = tf.while_loop(
             LoopContinueShort,
             LoopBodyShort,
             loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                        flat_other_states),
             parallel_iterations=10,
             back_prop=False,
             swap_memory=False,
             shape_invariants=(
                 tf.TensorShape(cur_step.get_shape()),
                 tf.TensorShape(all_done.get_shape()),
                 tf.TensorShape(step_ids.get_shape()),
                 tuple(
                     list(_GetShapes(accum_bs_states)) +
                     list(_GetShapes(non_accum_bs_states, none_shapes=True))),
                 _GetShapes(flat_other_states, none_shapes=True)),
             maximum_iterations=max_steps)

    def LoopContinueLong(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
        """Continue default implementation until decoding finishes."""
        return tf.math.logical_and(cur_step < max_steps,
                                   tf.math.logical_not(all_done))

    def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
        """Loop body of default long_seq implementation."""
        (cur_step, all_done, new_step_ids, new_bs_states,
         new_other_states) = self._BeamSearchStep(
             theta,
             encoder_outputs,
             cur_step,
             step_ids,
             core_bs_states,
             other_states.Pack(other_states_list),
             num_hyps_per_beam,
             pre_beam_search_step_callback,
             post_beam_search_step_callback,
             use_short_seq_opt=False)
        return (cur_step, all_done, new_step_ids, new_bs_states,
                new_other_states.Flatten())

    _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
        LoopContinueLong,
        LoopBodyLong,
        loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                   flat_other_states),
        parallel_iterations=10,
        back_prop=False,
        swap_memory=False,
        shape_invariants=(
            tf.TensorShape(cur_step.get_shape()),
            tf.TensorShape(all_done.get_shape()),
            tf.TensorShape(step_ids.get_shape()),
            tuple(
                list(_GetShapes(accum_bs_states)) +
                list(_GetShapes(non_accum_bs_states, none_shapes=True))),
            _GetShapes(flat_other_states, none_shapes=False)),
        maximum_iterations=max_steps)

    # When the paddings come as a NestedMap, only its first flattened entry
    # is used to derive the source lengths.
    if isinstance(source_paddings, py_utils.NestedMap):
        source_seq_lengths = tf.cast(
            tf.round(
                tf.reduce_sum(
                    1.0 - tf.transpose(source_paddings.Flatten()[0]), 1)),
            dtype=tf.int32)
    else:
        source_seq_lengths = tf.cast(
            tf.round(
                tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
            dtype=tf.int32)

    # Concatenate all outputs on axis=0.
    # Indices 0-2 of final_bs_states are the accumulators (best_scores,
    # cumulative_scores, histories); the TensorArrays start at index 3, in
    # the order of non_accum_bs_states.
    scores = final_bs_states[3].stack()
    hyps = final_bs_states[4].stack()
    prev_hyps = final_bs_states[5].stack()
    done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
    atten_probs = final_bs_states[7].stack()
    eos_scores = final_bs_states[8].stack()
    eos_atten_probs = final_bs_states[9].stack()
    rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
            eos_atten_probs, source_seq_lengths)

    # TODO(rohananil): Only send a single R1 tensor to host instead of 3
    # after b/111131551 is resolved.
    # Canonical shapes for tensors of various ranks.
    r_shapes = [
        py_utils.GetShape(source_seq_lengths),
        py_utils.GetShape(hyps),
        py_utils.GetShape(atten_probs)
    ]
    # Reshape all tensors to [-1] to avoid cost of copy due to padding.
    rets_r1 = [tf.reshape(r, [-1]) for r in rets]

    return tuple(r_shapes) + tuple(rets_r1) + tuple(flat_final_other_states)
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Merges beam search hyps from multiple decoders.

    Args:
      max_hyps_per_beam: the number of top hyps in the merged results. Must
        be less than or equal to total number of input hyps.
      beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must
        share the same source_batch and max sequence length.

    Returns:
      A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
      per beam.

    Raises:
      ValueError: If a field is set in some but not all of the outputs, or if
        an output carries a field this function does not know how to reshape.
    """
    source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]

    # Collect each field from every decoder, reshaped to
    # [source_batch, hyps_per_beam, -1].
    per_field = {}
    for decoder_output in beam_search_outputs:
        batch_check = py_utils.assert_equal(
            source_batch, tf.shape(decoder_output.topk_hyps)[0])
        hyps_per_beam = py_utils.with_dependencies(
            [batch_check], tf.shape(decoder_output.topk_hyps)[1])
        for field, value in six.iteritems(decoder_output._asdict()):
            if value is None:
                continue
            # done_hyps is transposed so that hyps come first, like the
            # other fields.
            if field == 'done_hyps':
                value = tf.transpose(value)
            per_field.setdefault(field, []).append(
                tf.reshape(value, [source_batch, hyps_per_beam, -1]))

    # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
    concatenated = {}
    for field, values in six.iteritems(per_field):
        if len(values) != len(beam_search_outputs):
            raise ValueError('Incomplete values for %s: %s' %
                             (field, beam_search_outputs))
        concatenated[field] = tf.concat(values, axis=1)

    # Empty hyps (length 0) get a very low score so they are picked last.
    scores = concatenated['topk_scores']
    is_empty = tf.equal(concatenated['topk_lens'], 0)
    scores = tf.where(is_empty, tf.fill(tf.shape(scores), -1e6), scores)
    scores = tf.squeeze(scores, -1)

    # Select top max_hyps_per_beam indices per beam.
    _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
    batch_ids = tf.tile(
        tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2]
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    # Gather the merged top hyps according to 'gather_indices'.
    merged = beam_search_outputs[0]._asdict()
    total_hyps = source_batch * max_hyps_per_beam
    flat_fields = ('topk_lens', 'topk_scores', 'topk_decoded')
    for field, value in six.iteritems(concatenated):
        selected = tf.gather_nd(value, gather_indices)
        if field == 'done_hyps':
            selected = tf.transpose(tf.reshape(selected, [total_hyps, -1]))
        elif field == 'topk_hyps':
            selected = tf.reshape(selected,
                                  [source_batch, max_hyps_per_beam])
        elif field == 'topk_ids':
            selected = tf.reshape(selected, [total_hyps, -1])
        elif field in flat_fields:
            selected = tf.reshape(selected, [total_hyps])
        else:
            raise ValueError('Unexpected field: %s' % field)
        merged[field] = selected
    return BeamSearchDecodeOutput(**merged)
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Runs beam search decoding with the supplied callbacks.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks. Mostly opaque to this method, except that it is read for
      a 'seq_lengths' field and, when that is absent, for a 'padding' field
      (a tensor, or a NestedMap whose first flattened tensor is used) that is
      transposed and summed to obtain the source lengths.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to override
      `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput` built from the done hyps, the top-k hyps and
    their unpacked ids / lens / scores, and the final other states.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    # No initial ids supplied: every hyp starts from the sos id.
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  # Initial core beam search state; the per-step buffers are all
  # [max_steps, num_hyps, ...].
  min_score = -1e36
  best_scores = tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  source_len = tf.shape(initial_results.atten_probs)[1]
  bs_atten_probs = tf.zeros([max_steps, num_hyps, source_len], dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(step, done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    """Keeps looping while steps remain and some beam is still active."""
    return tf.math.logical_and(step < max_steps, tf.math.logical_not(done))

  def LoopBody(step, unused_done, ids, core_states, other_states_list):
    """Runs one beam search step on the re-packed other states."""
    (next_step, done, next_ids, next_core_states,
     next_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, step, ids, core_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (next_step, done, next_ids, next_core_states,
            next_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  loop_results = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  _, _, _, final_bs_states, flat_final_other_states = loop_results

  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Source lengths come from `seq_lengths` when the encoder provides it.
  # Otherwise they are derived from `padding`, which is transposed before
  # being summed along axis 1.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      source_paddings = source_paddings.Flatten()[0]
    encoded_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)

  # [num_beams * num_hyps_per_beam, ...].
  # A tensor-valued max_steps cannot be passed as a static length, so 0 is
  # passed to unpack_hyp in that case.
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  flat_topk_hyps = tf.reshape(topk_hyps, [-1])
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      flat_topk_hyps, max_seq_length=max_seq_length)

  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                    core_bs_states, other_states, num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback):
  """Advances every beam search hyp by one step.

  Notation:

  | num_beams = Number of source sequences to be decoded.
  | num_hyps_per_beam = Number of hyps to keep per source sequence.
  | num_hyps = num_beams * num_hyps_per_beam
  | src_seq_len = Number of time steps in the source sequence.
  | src_batch = Number of examples in the source sequence.
  | tgt_seq_len = Maximum allowed time steps in the target sequence.
  | tgt_batch = num_hyps_per_beam * src_batch

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of the seven core beam search states maintained by
      this helper class.
    other_states: A `.NestedMap` of other beam search states, managed and
      updated by the client. Each member tensor is expected to be of
      rank >= 1, with t[i, ...] being the state of the i-th hyp at the
      beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      See class header comments for more details.

  Returns:
    A tuple of following elements for the next beam search step,
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  step_results, other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
   in_done_hyps, in_atten_probs) = core_bs_states

  log_probs = tf.cast(step_results.log_probs, dtype=p.dtype)
  atten_probs = tf.cast(step_results.atten_probs, dtype=p.dtype)
  # An empty list is passed when the model does not use an eoc id.
  is_last_chunk = (
      step_results.is_last_chunk if self._model_uses_eoc_id else [])

  (next_best_scores, next_cumulative_scores, next_scores, next_hyps,
   next_prev_hyps, next_done_hyps, next_atten_probs,
   all_done) = ops.beam_search_step(
       log_probs,
       atten_probs,
       best_scores,
       cumulative_scores,
       in_scores,
       in_hyps,
       in_prev_hyps,
       in_done_hyps,
       in_atten_probs,
       is_last_chunk,
       cur_step,
       eoc_id=p.target_eoc_id,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
       ensure_full_beam=p.ensure_full_beam,
       force_eos_in_last_step=p.force_eos_in_last_step,
       local_eos_threshold=p.local_eos_threshold)

  # The ids chosen at this step become the inputs of the next step.
  new_step_ids = tf.reshape(next_hyps[cur_step, :], tf.shape(step_ids))
  new_step_ids.set_shape(step_ids.get_shape())

  # For each hyp, the index of the hyp it was extended from.
  prev_hyp_row = tf.slice(next_prev_hyps, begin=[cur_step, 0], size=[1, -1])
  old_hyp_ids = tf.reshape(prev_hyp_row, [-1])

  if p.batch_major_compute:
    # Transformed the indices into the key/value cache for fast decoding
    # (prefix_states in other_states) due to the num_hyps dimension of
    # cache is computed as num_beams by num_hyps_per_beam, which is different
    # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
    # Both transpose and recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    ids_by_hyp = tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])
    transposed_ids = tf.reshape(tf.transpose(ids_by_hyp), [-1])
    old_hyp_ids_in_cache_order = (
        (transposed_ids % num_beams) * num_hyps_per_beam +
        transposed_ids // num_beams)

  new_bs_states = (next_best_scores, next_cumulative_scores, next_scores,
                   next_hyps, next_prev_hyps, next_done_hyps,
                   next_atten_probs)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    is_ranked_tensor = (
        isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
        x_in.shape.ndims > 0)
    if not is_ranked_tensor:
      # Non-tensors, scalars and unknown-rank tensors pass through untouched.
      return x_in
    if x_in.shape.ndims > 2 and not p.batch_major_state:
      # Use corrected indices only here for batch major compute as key/value
      # caches are the states being affected.
      gather_ids = (
          old_hyp_ids_in_cache_order if p.batch_major_compute else old_hyp_ids)
      x_out = tf.gather(x_in, gather_ids, axis=1)
    else:
      x_out = tf.gather(x_in, old_hyp_ids)
    x_out.set_shape(x_in.get_shape())
    return x_out

  reordered_other_states = other_states.Transform(ReOrderHyps)

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, reordered_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
def TileForBeamAndFlatten(tensor):
  """Repeats `tensor` num_hyps_per_beam times and flattens it to [tgt_batch].

  `num_hyps_per_beam` and `step_ids` are not parameters: they are read from
  the enclosing scope.

  Args:
    tensor: The tensor to tile; it is first flattened into a single row.

  Returns:
    A rank-1 tensor whose length is the first dimension of `step_ids`.
  """
  # [1, src_batch]
  as_row = tf.reshape(tensor, [1, -1])
  # [num_hyps_per_beam, src_batch]
  tiled_rows = tf.tile(as_row, [num_hyps_per_beam, 1])
  # num_hyps_per_beam * src_batch
  tgt_batch = tf.shape(step_ids)[0]
  return tf.reshape(tiled_rows, [tgt_batch])
def _BeamSearchStep(self,
                    theta,
                    encoder_outputs,
                    cur_step,
                    step_ids,
                    core_bs_states,
                    other_states,
                    num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback,
                    use_short_seq_opt=False):
  """Extend beam search hyps for one step.

  num_beams = Number of source sequences to be decoded.
  num_hyps_per_beam = Number of hyps to keep per source sequence.
  num_hyps = num_beams * num_hyps_per_beam
  src_seq_len = Number of time steps in the source sequence.
  tgt_seq_len = Maximum allowed time steps in the target sequence.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap computed by encoder.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of the ten core beam search states maintained by
      this helper class. The per-step members are written with `.write()`,
      i.e. they are TensorArray-like.
    other_states: A NestedMap of other beam search states. This NestedMap is
      managed and updated by the client. It is expected that each of its
      member tensors are of rank >= 1. t[i, ...] is the state of the i-th hyp
      at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    use_short_seq_opt: A bool, whether using short sequence optimization. When
      True it is forwarded to the pre-step callback as an extra positional
      argument.

  Returns:
    A tuple of following elements for the next beam search step:
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  # The extra argument is only passed when requested, so callbacks that do
  # not accept it keep working.
  if use_short_seq_opt:
    bs_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam,
        use_short_seq_opt)
  else:
    bs_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, histories, in_scores, in_hyps,
   in_prev_hyps, in_done_hyps, in_atten_probs, in_eos_scores,
   in_eos_atten_probs) = core_bs_states

  (out_best_scores, out_cumulative_scores, out_scores, out_eos_scores,
   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs,
   out_eos_atten_probs, all_done,
   out_histories) = beam_search_tpu_ops.beam_search_step(
       bs_results.log_probs,
       bs_results.atten_probs,
       best_scores,
       cumulative_scores,
       histories,
       cur_step,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_beams=tf.shape(best_scores)[0],
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       eoc_id=p.target_eoc_id if p.merge_paths else -1,
       is_last_chunk=bs_results.get('is_last_chunk'))

  # Write out values into TensorArray's corresponding to each output.
  arr_scores = in_scores.write(cur_step, out_scores)
  arr_eos_scores = in_eos_scores.write(cur_step, out_eos_scores)
  arr_hyps = in_hyps.write(cur_step, out_hyps)
  arr_prev_hyps = in_prev_hyps.write(cur_step, out_prev_hyps)
  # TODO(rohananil): Change the implementation of TensorArray write for
  # tf.bool from false += current_value to logical_and(true, current_value) as
  # addition operator for bool is not defined.
  arr_done_hyps = in_done_hyps.write(cur_step,
                                     tf.cast(out_done_hyps, tf.int32))
  arr_atten_probs = in_atten_probs.write(cur_step, out_atten_probs)
  arr_eos_atten_probs = in_eos_atten_probs.write(cur_step,
                                                 out_eos_atten_probs)

  # New beam search states.
  new_bs_states = (out_best_scores, out_cumulative_scores, out_histories,
                   arr_scores, arr_hyps, arr_prev_hyps, arr_done_hyps,
                   arr_atten_probs, arr_eos_scores, arr_eos_atten_probs)

  old_hyp_ids = tf.reshape(out_prev_hyps, [-1])

  if p.batch_major_compute:
    # Transformed the indices into the key/value cache for fast decoding
    # (prefix_states in other_states) due to the num_hyps dimension of
    # cache is computed as num_beams by num_hyps_per_beam, which is different
    # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
    # Both transpose and recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    # `ndims` is None for a tensor of unknown rank, and `None > 0` raises
    # TypeError on Python 3. Test truthiness instead, so that non-tensors,
    # scalars and unknown-rank tensors are all returned unchanged.
    rank = x_in.shape.ndims if isinstance(x_in, tf.Tensor) else None
    if not rank:
      return x_in

    # For rank > 1 tensors we make use of an efficient matmul based gather
    # on tpu that takes in account the range of the values. For R1, we
    # rely on the tf.gather and xla to optimize it efficiently for R1
    # layout.
    if rank > 1:
      if p.batch_major_state:
        num_hyps = tf.shape(old_hyp_ids)[0]
        x_out = beam_search_tpu_ops.fast_gather(
            x_in,
            old_hyp_ids,
            num_hyps,
            max_value=None,
            batch_major_state=p.batch_major_state)
      else:
        # Use corrected indices only here for batch major compute as
        # key/value caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)

        def _GatherStep(x_in, t):
          """Gather for one time step.

          Args:
            x_in: in the shape of [T, B, ...] we first get slice(t) from the
              tensors, then gather old_hyp_ids from the slice and write the
              interpolated slice inplace to update the original x_in.
            t: current time step

          Returns:
            Updated x_in and time step
          """
          x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
          return inplace_ops.alias_inplace_update(x_in, t, x), t + 1

        # Only the time steps up to and including cur_step are reordered.
        x_out, _ = tf.while_loop(lambda _, t: t <= cur_step, _GatherStep,
                                 (x_in, tf.zeros([], tf.int32)))
    else:
      x_out = tf.gather(x_in, old_hyp_ids)
    x_out.set_shape(x_in.get_shape())
    return x_out

  new_other_states = other_states.Transform(ReOrderHyps)
  new_step_ids = tf.reshape(out_hyps, [-1, 1])

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
def FProp(self, theta, input_batch):
  """Embeds the source ids and runs them through the TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - segment_ids, segment_pos: Read only when p.packed_input is set.

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  emb_dim = p.token_emb.embedding_dim
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None

    # ids and paddings must have matching shapes, and ids must be rank 2.
    shapes_match = py_utils.assert_shape_match(
        tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ids_are_rank2 = py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    input_ids = py_utils.with_dependencies([shapes_match, ids_are_rank2],
                                           input_batch.ids)

    truncate_inputs = (
        not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs)
    if truncate_inputs:
      # Cut the batch down to its longest unpadded sequence, after checking
      # that everything beyond that point is padding.
      longest = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      tail_is_padding = py_utils.assert_equal(
          tf.constant(True, tf.bool),
          tf.reduce_all(input_batch.paddings[:, longest:] > 0.5))
      paddings = py_utils.with_dependencies([tail_is_padding],
                                            input_batch.paddings)
      input_ids = input_ids[:, :longest]
      paddings = paddings[:, :longest]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :longest]
        src_segment_pos = input_batch.segment_pos[:, :longest]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings. With shared embeddings
    # the lookup goes through the softmax layer.
    flat_ids = tf.reshape(input_ids, [-1])
    if p.shared_emb:
      input_embs = self.softmax.EmbLookup(theta.softmax, flat_ids)
    else:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, flat_ids)

    input_embs = tf.reshape(input_embs, [-1, max_time, emb_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs, [1, max_time, emb_dim])
    input_embs = input_embs + position_embs

    if p.task_emb:
      input_embs = input_embs + self.task_emb.EmbLookup(
          theta.task_emb, input_batch.task_ids)

    # Project the embeddings when their width differs from the model width.
    if p.model_dim != emb_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # From here on paddings (and segment ids) are time-major.
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    pad_dtype = paddings.dtype
    is_masked = tf.equal(input_ids, p.source_mask_id)
    source_mask = tf.where(is_masked,
                           tf.ones_like(input_ids, dtype=pad_dtype),
                           tf.zeros_like(input_ids, dtype=pad_dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)
def _ReshapeBackToHigherRank(inps, r_shape):
  """Reshapes every tensor in `inps` to `r_shape`, in place.

  Args:
    inps: A list of tensors. Its entries are overwritten with the reshaped
      tensors.
    r_shape: The shape passed to tf.reshape for every entry.

  Returns:
    The same list object `inps`, now holding the reshaped tensors.
  """
  for idx, flat_tensor in enumerate(inps):
    inps[idx] = tf.reshape(flat_tensor, r_shape)
  return inps
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  `x` is the groundtruth sequence (i.e., English sequence). A rollin
  (observed) canvas is sampled from it (i.e., a random subset of the English
  sequence), and the targets (indices) an insertion-based model should
  predict given that observed subset are returned alongside it.

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid. Defaults to all zeros when None.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  def _BatchIds(width):
    """Returns a [batch_size, width] tensor whose row i is filled with i."""
    return tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, width])

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  # Only the uniform policy is implemented, for both rollin and oracle.
  oracle_policy = p.oracle_policy
  rollin_policy = p.rollin_policy
  if rollin_policy == 'oracle':
    rollin_policy = oracle_policy

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    scaled_len = tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32)
    c_len = tf.minimum(scaled_len, x_len - 1) + 1
  else:
    scaled_len = tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32)
    c_len = tf.minimum(scaled_len, x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example. Positions at or beyond
  # `x_len` get a hugely negative logit.
  is_past_end = (
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1))
  z_logits = tf.cast(is_past_end, tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by add +LARGE_NUMBER to the logits.
    is_last_token = tf.equal(
        tf.expand_dims(tf.range(time_dim), 0), tf.expand_dims(x_len - 1, 1))
    z_logits += tf.cast(is_last_token, tf.float32) * 1e9

  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  uniform_noise = tf.random.uniform([batch_size, time_dim],
                                    seed=p.random_seed)
  z = -tf.math.log(-tf.math.log(uniform_noise))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  is_within_c_len = (
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1))
  c_indices = tf.where(
      is_within_c_len, c_indices,
      tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  gather_rows = tf.reshape(_BatchIds(c_len_max), [-1])
  gather_cols = tf.reshape(c_indices, [-1])
  c = tf.gather_nd(x, tf.stack([gather_rows, gather_cols], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings, and zero out the padded canvas entries.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  # Scatter ones at the sampled (batch, position) pairs of `x`.
  scatter_rows = tf.reshape(
      _BatchIds(c_len_max), [batch_size * c_len_max, 1])
  scatter_cols = tf.reshape(c_indices, [batch_size * c_len_max, 1])
  indices = tf.concat([scatter_rows, scatter_cols], 1)
  observed_counts = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(observed_counts, 1, exclusive=True)

  is_observed = tf.cast(observed_counts, tf.bool)
  # Shifted right by one position; the first position counts as preceded by
  # an observed token.
  prev_is_observed = tf.pad(
      is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  is_observed = tf.reshape(is_observed, [-1])
  prev_is_observed = tf.reshape(prev_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      is_observed, tf.fill(py_utils.GetShape(target_indices), eos_id),
      target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  eos_shares_slot = tf.math.logical_and(
      is_observed, tf.math.logical_not(prev_is_observed))
  target_weights = tf.where(eos_shares_slot, tf.zeros_like(target_weights),
                            target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_batch_ids = tf.reshape(
      _BatchIds(time_dim), [batch_size * time_dim, 1])
  target_slot_ids = tf.reshape(x_segments, [-1, 1])
  target_indices = tf.concat(
      [target_batch_ids, target_slot_ids, target_indices], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` object containing: ids - The inputs tensor of
      shape [batch, time]. paddings - The ids' paddings of shape [batch,
      time]. When p.packed_input is set, segment_ids and segment_pos are read
      as well.

  Returns:
    A '.NestedMap' object containing:
      encoded - The encoded features of shape [time, batch, dim] or [batch,
        time, dim], depending p.output_data_format.
      padding - The encoded features' padding of shape [time, batch] or
        [batch, time].
      seq_lengths - int32 lengths obtained by summing `1 - padding` over
        axis 1 of the stack's batch-major padding output; used by
        beam_search_helper.
      segment_id - The segmentation of packed inputs of shape [time, batch] or
        [batch, time] if it is supported by the model, or None otherwise.
      embedded_inputs - The embedded inputs tokens without positional
        encodings of shape [time, batch, dim] or [batch, time, dim].
  """
  p = self.params
  with tf.name_scope(p.name):
    # [batch, time]
    input_ids = input_batch.ids
    # [batch, time]
    paddings = input_batch.paddings

    # [batch, time]
    segment_ids = input_batch.segment_ids if p.packed_input else None

    batch = py_utils.GetShape(input_ids)[0]
    time = py_utils.GetShape(input_ids)[1]

    # Embedding layer.
    # [batch, time, dim]
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
    orig_input_embs = input_embs

    # [1, time, dim]
    if p.packed_input:
      positions = input_batch.segment_pos
      position_embs = tf.expand_dims(
          self.position_emb.FPropWithPosition(theta.position_emb, positions),
          0)
    else:
      position_embs = tf.expand_dims(
          self.position_emb.FProp(theta.position_emb, time), 0)

    # [batch, time, dim]
    # Cast the position embeddings to whatever dtype the token embeddings
    # have. A hard-coded bfloat16 cast makes the addition fail with a dtype
    # mismatch whenever the token embeddings are not bfloat16.
    input_embs += tf.cast(position_embs, input_embs.dtype)

    if p.input_dropout_tpl.fprop_dtype:
      input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
      paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
    # [batch, time, dim]
    transformer_input = input_embs
    # Explicitly set the input shape of Transformer layers, to avoid
    # unknown shape error occurred to tf.einsum on nonTPU devices.
    transformer_input = tf.reshape(transformer_input,
                                   [batch, time, p.model_dim])

    # Compute self-attention segment mask once.
    if p.packed_input:
      segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids, dtype=transformer_input.dtype)
    else:
      segment_mask = tf.zeros([batch, 1, time, time])

    encoded, padding = self.transformer_stack.FProp(theta.transformer_stack,
                                                    transformer_input,
                                                    paddings, segment_mask)

    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)

    # Computed before any transpose below, while padding is still
    # [batch, time].
    seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32)

    if p.output_data_format == 'TBC':
      encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
      padding = tf.transpose(padding)  # [time, batch]
      segment_ids = tf.transpose(segment_ids) if p.packed_input else None
      orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        seq_lengths=seq_lengths,  # used by beam_search_helper.
        segment_id=segment_ids,
        embedded_inputs=orig_input_embs)
def BeamSearchDecodePostProcess(self, num_hyps_per_beam, max_steps, r1_shape,
                                r2_shape, r3_shape, hyps, prev_hyps,
                                done_hyps, scores, atten_probs, eos_scores,
                                eos_atten_probs, source_seq_lengths,
                                *flat_final_other_states):
  """Beam search post processing functions on CPUs.

  The tensor arguments are reshaped with `r1_shape`, `r2_shape` or `r3_shape`
  before being turned into hypotheses.

  Args:
    num_hyps_per_beam: Number of hyps per beam.
    max_steps: Maximum number of beam search steps.
    r1_shape: The shape `source_seq_lengths` is reshaped to.
    r2_shape: The shape `hyps`, `prev_hyps`, `done_hyps`, `scores` and
      `eos_scores` are reshaped to; described by the original author as
      [time, b * k].
    r3_shape: The shape `atten_probs` and `eos_atten_probs` are reshaped to;
      described by the original author as [time, b * k, seq_len].
    hyps: Ids of the token selected.
    prev_hyps: Index to the previous hyps which was selected.
    done_hyps: Indicates whether a hyp was terminated.
    scores: Scores of the token selected.
    atten_probs: The attention probabilities over the source words against
      word in the previous hyps.
    eos_scores: Scores of the eos token selected.
    eos_atten_probs: The attention probabilities over the source words
      against word in the previous hyps, for the eos token.
    source_seq_lengths: The source seq_lengths.
    *flat_final_other_states: A array of tensors that are part of other
      states. They are returned unchanged.

  Returns:
    A tuple starting with
    (final_done_hyps, topk_hyps, topk_ids, topk_lens, topk_scores) and
    followed by every tensor of `flat_final_other_states`:

    final_done_hyps: The terminated hyps built from the beam search outputs.
    topk_hyps, topk_ids, topk_lens, topk_scores: Top K terminated Hyps.
  """
  p = self.params

  # Reshape all tensors back to original shapes of rank 1, 2 and 3.
  source_seq_lengths = tf.reshape(source_seq_lengths, r1_shape)
  hyps, prev_hyps, done_hyps, scores, eos_scores = [
      tf.reshape(flat, r2_shape)
      for flat in (hyps, prev_hyps, done_hyps, scores, eos_scores)
  ]
  atten_probs, eos_atten_probs = [
      tf.reshape(flat, r3_shape) for flat in (atten_probs, eos_atten_probs)
  ]

  final_done_hyps = ops.hyps_from_beam_search_outs(
      hyps,
      prev_hyps,
      done_hyps,
      scores,
      atten_probs,
      eos_scores,
      eos_atten_probs,
      eos_id=p.target_eos_id,
      num_hyps_per_beam=num_hyps_per_beam)
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      topk_hyps, max_seq_length=max_steps)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  topk_outputs = (final_done_hyps, topk_hyps, topk_ids, topk_lens,
                  topk_scores)
  return topk_outputs + tuple(flat_final_other_states)