def FProp(self, theta, input_batch): """Encodes source as represented by `inputs` and `paddings`. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. Returns: A NestedMap containing: - encoded: The encoded features, a tensor of shape [time, batch, depth] - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. """ p = self.params src_segment_id = None with tf.name_scope(p.name): # Now the rnn layers. inputs = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]), py_utils.assert_shape_match(tf.shape(input_batch.ids), tf.shape(input_batch.paddings)) ], tf.transpose(input_batch.ids)) paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2) xs = self.emb.EmbLookup(theta.emb, inputs) xs = self.ApplyClipping(theta, xs) self._emb_out = xs ps = paddings # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell # with the same cc_schedule so that the RNN layer output is within # clipping range. xs = self.rnn[0].FProp(theta.rnn[0], xs, ps) xs = self.dropout.FProp(theta.dropout, xs) for i in range(1, p.num_lstm_layers): layer = self.rnn[i] ys, _ = layer.FProp(theta.rnn[i], xs, ps) ys = self.dropout.FProp(theta.dropout, ys) if hasattr(layer.params, 'cell'): layer_params = layer.params.cell else: layer_params = layer.params if layer_params.num_input_nodes == layer_params.num_output_nodes: xs += ys # Residual skip xs = self.ApplyClipping(theta, xs) else: # When cc_schedule is specified, make sure lstm_tpl is # QuantizedLSTMCell with the same cc_schedule so that the RNN layer # output is within clipping range. xs = ys return py_utils.NestedMap(encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, input_batch):
  """Encodes source ids and paddings, with optional packed-input support."""
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    if p.packed_input:
      src_segment_id = tf.expand_dims(
          tf.transpose(input_batch.segment_ids), 2)
    else:
      src_segment_id = None
    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys = layer.FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])
    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, input_batch, state0=None):
  """Encodes source ids and paddings, threading streaming RNN state."""
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    # Reshape to [t, b].
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    # Setup streaming states.
    if not state0:
      state0 = self.zero_state(theta, tf.shape(inputs)[1])
    state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, state1.rnn[i] = layer.FProp(
          theta.rnn[i], xs, ps, state0=state0.rnn[i])
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id,
        state=state1)
def _matmul_gather(self, values, axis=0, batch_major_state=True):
  """Returns values gathered.

  Args:
    values: Values to gather from.
    axis: Axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or
      not. Defaults to True. For Transformer model, batch_major_state is set
      to False (time is the major dim).

  Returns:
    Gathered values.

  Raises:
    NotImplementedError: If axis is neither 0 nor 1.
  """
  dtype = values.dtype
  if dtype != tf.float32 and dtype != tf.bfloat16:
    values = tf.cast(values, tf.float32)
  if axis == 0:
    if values.shape.rank is not None and values.shape.rank > 2:
      if not batch_major_state:
        values = tf.transpose(values, [1, 0, 2])
      results = tf.cast(
          tf.gather(values, tf.cast(self._ids, tf.int32)), dtype)
      # pylint: disable=g-long-ternary
      return (tf.transpose(results, [1, 0, 2])
              if not batch_major_state else results)
      # pylint: enable=g-long-ternary
    else:
      one_hot_ids = tf.one_hot(self._ids, self._ids_size, dtype=values.dtype)
      return tf.cast(tf.matmul(one_hot_ids, values), dtype)
  elif axis == 1:
    one_hot_ids = tf.one_hot(
        self._ids, self._ids_size, dtype=values.dtype, axis=0)
    return tf.cast(tf.matmul(values, one_hot_ids), dtype)
  else:
    raise NotImplementedError('Only row/col-wise gather implemented.')
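# A minimal standalone sketch (not part of the class above) of why the
# one-hot matmul path is equivalent to a row gather: multiplying a one-hot
# matrix by `values` selects rows, lowering the gather to a matmul, which is
# often friendlier to TPUs than a dynamic gather. Names and values here are
# illustrative only.
import tensorflow as tf

values = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # [3, 2]
ids = tf.constant([2, 0])

# Gather-based selection of rows 2 and 0.
gathered = tf.gather(values, ids)

# Equivalent matmul-based selection: one_hot(ids) has shape [2, 3], so the
# product [2, 3] x [3, 2] picks out the same rows.
one_hot_ids = tf.one_hot(ids, 3, dtype=values.dtype)
matmul_gathered = tf.matmul(one_hot_ids, values)

print(gathered.numpy())         # [[5. 6.] [1. 2.]]
print(matmul_gathered.numpy())  # [[5. 6.] [1. 2.]]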
def _ReshapeRetVal(name, t_shape):
  """Restore shape for tensors in microbatches."""
  if t_shape is None:
    return None
  output_tensor = output_state[name]
  if p.batch_dim != 0:
    perm = list(range(1, p.batch_dim + 1)) + [0]
    perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
    output_tensor = tf.transpose(output_tensor, perm=perm)
  output_shape = t_shape.ToTensorShape().as_list()
  output_shape[p.batch_dim] *= p.num_micro_batches
  output_tensor = tf.reshape(output_tensor, output_shape)
  return output_tensor
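# Illustrative sketch of the reshape above, under the assumption that a
# GPipe-style loop stacks per-microbatch outputs on a leading axis: a tensor
# logically shaped [..., batch, ...] arrives as [num_micro_batches, ...,
# micro_batch, ...]. Moving the microbatch axis next to the batch axis and
# reshaping merges them back. Shapes here are hypothetical.
import tensorflow as tf

num_micro_batches, micro_batch, time, feat = 4, 2, 5, 3

# batch_dim == 0: the two leading axes are already adjacent, so a single
# reshape restores [num_micro_batches * micro_batch, feat].
stacked = tf.random.normal([num_micro_batches, micro_batch, feat])
merged = tf.reshape(stacked, [num_micro_batches * micro_batch, feat])
print(merged.shape)  # (8, 3)

# batch_dim == 1 (e.g. time-major [time, batch, feat]): transpose the
# microbatch axis to sit just before the batch axis, then merge the two.
stacked_tm = tf.random.normal([num_micro_batches, time, micro_batch, feat])
transposed = tf.transpose(stacked_tm, perm=[1, 0, 2, 3])  # [time, nmb, mb, feat]
merged_tm = tf.reshape(
    transposed, [time, num_micro_batches * micro_batch, feat])
print(merged_tm.shape)  # (5, 8, 3)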
def PrepareSequenceForPlot(tensor, padding, name):
  """Prepares a sequence feature for plotting.

  The sequence feature is transposed and channels are flattened.

  Args:
    tensor: An n-D Tensor of shape [batch, time, ...].
    padding: A Tensor of shape [batch, time].
    name: A string as the name of the reshaped Tensor, which will be used as
      the subcaption for plotting.

  Returns:
    A tuple of:
      reshaped_tensor: A 3-D Tensor of shape [batch, dim, time].
      sequence_length: A 1-D Tensor of shape [batch].
  """
  # Flatten all dimensions beyond the second into a single third dimension.
  batch_size, max_len = py_utils.GetShape(tensor, 2)
  plot_tensor = tf.reshape(tensor, [batch_size, max_len, -1])
  plot_tensor = tf.transpose(plot_tensor, [0, 2, 1], name=name)
  return (plot_tensor, SequenceLength(padding))
def AddAttentionSummary(name,
                        attention_tensors,
                        src_paddings,
                        tgt_paddings,
                        transcripts=None,
                        max_outputs=3):
  """Adds an image summary showing the attention probability matrix and state.

  Tensors are in sequence tensor format with the batch dimension in axis 1.

  Args:
    name: Summary name.
    attention_tensors: A list of 3D tensors shaped [target_len, batch_size,
      source_len] where attention[i, j, k] is the probability for the i-th
      output attending to the k-th input for element j in the batch.
    src_paddings: A tensor of binary paddings shaped [source_len, batch] for
      the source sequence, or a list of such tensors of the same length as
      attention_tensors with separate paddings for each entry in
      attention_tensors.
    tgt_paddings: A tensor of binary paddings shaped [target_len, batch] for
      the target sequence, or a list of such tensors of the same length as
      attention_tensors with separate paddings for each entry in
      attention_tensors.
    transcripts: Optional, transcripts shaped [batch, source_len] for the
      source sequence.
    max_outputs: Integer maximum number of elements of the batch to plot.
  """

  def Transpose(paddings):
    paddings = paddings if isinstance(paddings, list) else [paddings]
    return [tf.transpose(p) for p in paddings]

  AddAttentionSummaryBatchMajor(
      name, [tf.transpose(a, [1, 0, 2]) for a in attention_tensors],
      Transpose(src_paddings), Transpose(tgt_paddings), transcripts,
      max_outputs)
def _BeamSearchDecodeIds(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap computed by encoder.
    num_hyps_per_beam: Number of hyps per beam.
    init_beam_search_state: The InitBeamSearchState callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    max_steps: Maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    hyps: A tensor of shape [time, b * k] with ids of the token selected.
    prev_hyps: A tensor of shape [time, b * k] with indices to the previous
      hyps which were selected.
    done_hyps: A boolean tensor of shape [time, b * k] indicating whether the
      hyp was terminated.
    scores: A tensor of shape [time, b * k] with scores of the token selected.
    atten_probs: A tensor of shape [time, b * k, seq_len] which contains the
      attention probabilities over the source words against words in the
      previous hyps.
    eos_scores: A tensor of shape [time, b * k] with scores of the eos token
      selected.
    eos_atten_probs: A tensor of shape [time, b * k, seq_len] which contains
      the attention probabilities over the source words against words in the
      previous hyps.
    source_seq_lengths: A tensor of shape [time] containing the source
      seq_lengths.
    flat_final_other_states: An array of tensors that are part of other
      states.
  """
  p = self.params
  source_paddings = encoder_outputs.padding

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  # We cache the NestedMap as a member variable so that we can use it to
  # pack the final outputs. TPU rewrite methods force us to strictly pass
  # in and output Tensors.
  self._other_states = other_states

  step_ids = tf.fill([num_hyps, 1],
                     tf.constant(p.target_sos_id, dtype=tf.int32))
  min_score = -1e36
  fprop_dtype = py_utils.FPropDtype(p)
  best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
  histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
  in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)

  # States for beam search that are inputs into the beam search step.
  accum_bs_states = [best_scores, cumulative_scores, histories]
  # States that are not accumulators.
  non_accum_bs_states = [
      in_scores,
      in_hyps,
      in_prev_hyps,
      in_done_hyps,
      in_atten_probs,
      in_eos_scores,
      in_eos_atten_probs,
  ]
  core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

  flat_other_states = other_states.Flatten()

  # If there is an optimized implementation for short sequences, LoopBodyShort
  # will run first for short_seq_limit steps (after which LoopBodyShort no
  # longer has a performance benefit). Then LoopBodyLong (the default
  # implementation) is used to continue the rest of the steps. For decoders
  # which do not have a short-sequence-specific implementation, only
  # LoopBodyLong (the default implementation) will run.
  if p.short_seq_limit > 0:

    def LoopContinueShort(cur_step, all_done, unused_step_ids,
                          unused_core_bs_states, unused_other_states_list):
      """Use short_seq optimization when cur_step is smaller than limit."""
      return tf.math.logical_and(cur_step < p.short_seq_limit,
                                 tf.math.logical_not(all_done))

    def LoopBodyShort(cur_step, unused_all_done, step_ids, core_bs_states,
                      other_states_list):
      """Loop body of short_seq optimization.

      Instead of doing computation for the entire padded sequence, a while
      loop with early exit is used within each _BeamSearchStep to do
      computation for only the actual sequence (seq_length <= cur_step).
      use_short_seq_opt is used as the flag to pass this information down to
      the decoder implementation.

      Args:
        cur_step: A scalar int tensor, the current time step, 0-based.
        unused_all_done: A tf.bool, indicating whether the decoding finishes.
        step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids to
          the current search step.
        core_bs_states: A tuple of core beam search states.
        other_states_list: A flattened NestedMap of other beam search states.

      Returns:
        The updated input tuple, with the same shape.
      """
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta,
           encoder_outputs,
           cur_step,
           step_ids,
           core_bs_states,
           other_states.Pack(other_states_list),
           num_hyps_per_beam,
           pre_beam_search_step_callback,
           post_beam_search_step_callback,
           use_short_seq_opt=True)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    (cur_step, all_done, step_ids, core_bs_states,
     flat_other_states) = tf.while_loop(
         LoopContinueShort,
         LoopBodyShort,
         loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                    flat_other_states),
         parallel_iterations=10,
         back_prop=False,
         swap_memory=False,
         shape_invariants=(
             tf.TensorShape(cur_step.get_shape()),
             tf.TensorShape(all_done.get_shape()),
             tf.TensorShape(step_ids.get_shape()),
             tuple(
                 list(_GetShapes(accum_bs_states)) +
                 list(_GetShapes(non_accum_bs_states, none_shapes=True))),
             _GetShapes(flat_other_states, none_shapes=True)),
         maximum_iterations=max_steps)

  def LoopContinueLong(cur_step, all_done, unused_step_ids,
                       unused_core_bs_states, unused_other_states_list):
    """Continue default implementation until decoding finishes."""
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                   other_states_list):
    """Loop body of default long_seq implementation."""
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta,
         encoder_outputs,
         cur_step,
         step_ids,
         core_bs_states,
         other_states.Pack(other_states_list),
         num_hyps_per_beam,
         pre_beam_search_step_callback,
         post_beam_search_step_callback,
         use_short_seq_opt=False)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinueLong,
      LoopBodyLong,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(
          tf.TensorShape(cur_step.get_shape()),
          tf.TensorShape(all_done.get_shape()),
          tf.TensorShape(step_ids.get_shape()),
          tuple(
              list(_GetShapes(accum_bs_states)) +
              list(_GetShapes(non_accum_bs_states, none_shapes=True))),
          _GetShapes(flat_other_states, none_shapes=False)),
      maximum_iterations=max_steps)

  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)),
        dtype=tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        dtype=tf.int32)

  # Concatenate all outputs on axis=0.
  scores = final_bs_states[3].stack()
  hyps = final_bs_states[4].stack()
  prev_hyps = final_bs_states[5].stack()
  done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
  atten_probs = final_bs_states[7].stack()
  eos_scores = final_bs_states[8].stack()
  eos_atten_probs = final_bs_states[9].stack()
  rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
          eos_atten_probs, source_seq_lengths)

  # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
  # b/111131551 is resolved.
  # Canonical shapes for tensors of various ranks.
  r_shapes = [
      py_utils.GetShape(source_seq_lengths),
      py_utils.GetShape(hyps),
      py_utils.GetShape(atten_probs)
  ]

  # Reshape all tensors to [-1] to avoid cost of copy due to padding.
  rets_r1 = [tf.reshape(r, [-1]) for r in rets]

  return tuple(r_shapes) + tuple(rets_r1) + tuple(flat_final_other_states)
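# Small sketch of the padding-to-length conversion used above: paddings are
# 0.0 for real tokens and 1.0 for padded positions, so summing (1 - padding)
# over the time axis recovers each sequence's length. The transpose is needed
# because the paddings here are time-major. Values are illustrative.
import tensorflow as tf

# [time=4, batch=2]; sequence 0 has length 3, sequence 1 has length 2.
paddings = tf.constant([[0., 0.],
                        [0., 0.],
                        [0., 1.],
                        [1., 1.]])
seq_lengths = tf.cast(
    tf.round(tf.reduce_sum(1.0 - tf.transpose(paddings), 1)), tf.int32)
print(seq_lengths.numpy())  # [3 2]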
def _ConstructWarpMatrix(self, batch_size, matrix_size, origin, destination,
                         choose_range, dtype):
  """Returns warp matrices according to origin, destination and choose_range.

  This function constructs a batch of warp matrices which maps the batch of
  origin points to the batch of destination points with fixed boundary
  coordinates at 0 and choose_range.

  The warping function, defined by the origin anchor point `origin`, the
  destination of the origin anchor point `destination` and the length of the
  domain in the warping axis `choose_range`, is a piecewise linear map that
  fixes the points 0 and `choose_range` and maps `origin` to `destination`.

  For the warping matrix to be non-singular, destination must lie in the
  range 1 <= destination <= choose_range - 1, so a destination out of this
  range is adjusted to be in this range before the warping matrix is
  constructed.

  The warping map can be explicitly written by first defining the slopes:
    1) slope_0 = origin / destination.
    2) slope_1 = (choose_range - origin) / (choose_range - destination).
    3) slope_2 = 1.0.

  Then the origin point orig_i of the mapped coordinate i is given by:
    1) i < destination: orig_i = slope_0 * i.
    2) destination <= i < choose_range:
       orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
    3) i >= choose_range: orig_i = i.

  Denoting n_i = ceil(orig_i), the warp matrix element warp[i][j] is given
  by:
    1) j = n_i: 1 - n_i + orig_i.
    2) j = n_i - 1: n_i - orig_i.
    3) Otherwise: 0.

  Applying the warp matrix to an array of pixels, i.e.,
  warped_pixel[i] = sum_j warp[i][j] * pixel[j], one would get
  warped_pixel[i] = (n_i - orig_i) pixel[n_i - 1]
                    + (1 - n_i + orig_i) pixel[n_i].

  Args:
    batch_size: Batch size. Integer number.
    matrix_size: Dimension of the vector space the warp matrix is applied to.
      Integer number.
    origin: Origin anchor point for warping. Tensor of shape (batch_size,)
      and data type dtype.
    destination: Destination of the origin anchor point upon warping. Tensor
      of shape (batch_size,) and data type dtype.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,) and data type dtype.
    dtype: Data type of origin, destination, choose_range and the output warp
      matrix.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size).
  """
  p = self.params

  # Entries of destination must be in the range
  # 1 <= destination <= choose_range - 1
  # for the warp matrix to have non-singular values.
  destination = tf.minimum(tf.maximum(destination, 1.0), choose_range - 1.0)

  # Construct a piece-wise linear function fixing the boundary points
  # specified by zero, choose_range and matrix size, and mapping the origin
  # anchor point to the destination.
  destination_bc = tf.broadcast_to(destination, (matrix_size, batch_size))
  destination_bc = tf.transpose(destination_bc)
  choose_range_bc = tf.broadcast_to(choose_range, (matrix_size, batch_size))
  choose_range_bc = tf.transpose(choose_range_bc)

  # Slopes of the piece-wise linear function.
  slope_0 = origin / destination
  slope_1 = (choose_range - origin) / (choose_range - destination)
  slope_2 = 1.0

  # x is a batch of origin matrices.
  # The origin matrix is the matrix such that
  # origin[i][j] = Origin coordinate of coordinate i for the warp map.
  # Denoting the destination of the origin anchor point in the
  # warp map as "dest," the origin coordinate of point i is given by:
  # 1) i < dest: slope_0 * i.
  # 2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
  # 3) i >= choose_range: i.
  x = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype), (batch_size, matrix_size))
  x = (self.EinsumBBmBm(slope_0, x) +
       self.EinsumBBmBm(slope_1 - slope_0, tf.nn.relu(x - destination_bc)) +
       self.EinsumBBmBm(slope_2 - slope_1, tf.nn.relu(x - choose_range_bc)))
  x = tf.broadcast_to(x, (matrix_size, batch_size, matrix_size))
  x = tf.transpose(x, perm=[1, 2, 0])

  # y is a batch of coordinate matrices.
  # A coordinate matrix is a matrix such that
  # coordinate[i][j] = j.
  y = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype),
      (batch_size, matrix_size, matrix_size))

  # The warp matrix is obtained by applying the hat function elementwise to
  # (x - y). Denoting the origin point of i under the warp map as orig_i,
  # and n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
  # 1) j = n_i: 1 - n_i + orig_i.
  # 2) j = n_i - 1: n_i - orig_i.
  # 3) Otherwise: 0.
  # Applying the warp matrix to pixels, i.e.,
  # warped_pixel[i] = sum_j warp[i][j] * original_pixel[j], one would get
  # warped_pixel[i] = (n_i - orig_i) * original_pixel[n_i-1]
  #                   + (1 - n_i + orig_i) * original_pixel[n_i].
  warp_matrix = x - y
  warp_matrix = _hat(warp_matrix)
  if p.fprop_dtype is not None and p.fprop_dtype != dtype:
    warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)

  return warp_matrix
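# Hedged standalone sketch of the warp construction above for one example
# (batch_size=1), without the Einsum helpers: build the piecewise-linear
# origin coordinates, subtract the coordinate grid, and apply the hat
# function hat(t) = max(0, 1 - |t|). Each row of the result then holds the
# two linear-interpolation weights for its source pixels. Values are
# illustrative.
import tensorflow as tf

matrix_size, origin, destination, choose_range = 8, 2.0, 4.0, 6.0
i = tf.cast(tf.range(matrix_size), tf.float32)

slope_0 = origin / destination
slope_1 = (choose_range - origin) / (choose_range - destination)
orig_i = (slope_0 * i + (slope_1 - slope_0) * tf.nn.relu(i - destination) +
          (1.0 - slope_1) * tf.nn.relu(i - choose_range))

# warp[i][j] = hat(orig_i - j); rows sum to 1 and pick out the two
# neighboring source pixels of each warped coordinate.
warp = tf.nn.relu(1.0 - tf.abs(orig_i[:, None] - i[None, :]))

pixels = tf.range(matrix_size, dtype=tf.float32)
warped = tf.linalg.matvec(warp, pixels)  # linear interpolation of pixels
print(warped.numpy())  # equals orig_i here, since pixels[j] == j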
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0),
      tf.fill(tf.shape(scores), -1e6), scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
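# Illustrative sketch of the per-beam selection above: tf.nn.top_k picks the
# best hyp indices within each beam, and pairing them with their beam (row)
# ids lets tf.gather_nd pull out all the winners in one shot. Shapes and
# values are hypothetical.
import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.5],
                      [0.7, 0.2, 0.4]])  # [source_batch=2, total_hyps=3]
max_hyps_per_beam = 2

_, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)  # [[1 2] [0 2]]
batch_ids = tf.tile(
    tf.expand_dims(tf.range(2), -1), [1, max_hyps_per_beam])  # [[0 0] [1 1]]
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)  # [2, 2, 2]

print(tf.gather_nd(scores, gather_indices).numpy())  # [[0.9 0.5] [0.7 0.4]]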
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks. Mostly opaque to BeamSearchHelper, except that it should
      contain either a 'seq_lengths' field of shape [source_batch_size] or a
      'paddings' field of shape [source_max_lengths, source_batch_size].
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: Maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids,
                   unused_core_bs_states, unused_other_states_list):
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Assume that `paddings` has shape [source_max_lengths, source_batch_size]
  # by default, and compute `encoded_seq_lengths` accordingly. This can be
  # overridden by directly passing `seq_lengths` in the `encoder_outputs`
  # NestedMap.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      encoded_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(
                  1.0 - tf.transpose(source_paddings.Flatten()[0]), 1)),
          tf.int32)
    else:
      encoded_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                    core_bs_states, other_states, num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback):
  """Extend beam search hyps for one step.

  | num_beams = Number of source sequences to be decoded.
  | num_hyps_per_beam = Number of hyps to keep per source sequence.
  | num_hyps = num_beams * num_hyps_per_beam
  | src_seq_len = Number of time steps in the source sequence.
  | src_batch = Number of examples in the source sequence.
  | tgt_seq_len = Maximum allowed time steps in the target sequence.
  | tgt_batch = num_hyps_per_beam * src_batch

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed
      to the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of core beam search states. This list is
      maintained by this helper class.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. See class header comments for more details.

  Returns:
    A tuple of following elements for the next beam search step,
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  bs_results, other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
   in_done_hyps, in_atten_probs) = core_bs_states

  (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
   out_prev_hyps, out_done_hyps, out_atten_probs,
   all_done) = ops.beam_search_step(
       tf.cast(bs_results.log_probs, dtype=p.dtype),
       tf.cast(bs_results.atten_probs, dtype=p.dtype),
       best_scores,
       cumulative_scores,
       in_scores,
       in_hyps,
       in_prev_hyps,
       in_done_hyps,
       in_atten_probs,
       bs_results.is_last_chunk if self._model_uses_eoc_id else [],
       cur_step,
       eoc_id=p.target_eoc_id,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
       ensure_full_beam=p.ensure_full_beam,
       force_eos_in_last_step=p.force_eos_in_last_step,
       local_eos_threshold=p.local_eos_threshold)

  new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
  new_step_ids.set_shape(step_ids.get_shape())

  old_hyp_ids = tf.reshape(
      tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

  if p.batch_major_compute:
    # Transform the indices into the key/value cache for fast decoding
    # (prefix_states in other_states): the num_hyps dimension of the cache
    # is laid out as num_beams * num_hyps_per_beam, which differs from the
    # old_hyp_ids assumption (num_hyps_per_beam * num_beams). Both a
    # transpose and a recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
        x_in.shape.ndims > 0):
      if x_in.shape.ndims > 2 and not p.batch_major_state:
        # Use corrected indices only here for batch major compute as
        # key/value caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)
        x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
      else:
        x_out = tf.gather(x_in, old_hyp_ids)
      x_out.set_shape(x_in.get_shape())
      return x_out
    else:
      return x_in

  new_other_states = other_states.Transform(ReOrderHyps)

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
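# Tiny numeric sketch of the hyp-id remapping above. Beam search emits
# old_hyp_ids in hyp-major order (hyp h of beam b lives at index
# h * num_beams + b), while a batch-major key/value cache stores entry
# (b, h) at b * num_hyps_per_beam + h. The transpose reorders the ids into
# cache-slot order; the arithmetic then rewrites each id value into the
# cache's coordinate system. Values are hypothetical.
import tensorflow as tf

num_beams, num_hyps_per_beam = 2, 3
old_hyp_ids = tf.constant([2, 0, 1, 5, 3, 4])  # hyp-major predecessor ids

reordered = tf.reshape(
    tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
in_cache_order = ((reordered % num_beams) * num_hyps_per_beam +
                  reordered // num_beams)
print(in_cache_order.numpy())  # [1 3 4 0 5 2]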
def Transpose(paddings):
  """Transposes each padding tensor, wrapping a bare tensor in a list."""
  paddings = paddings if isinstance(paddings, list) else [paddings]
  return [tf.transpose(p) for p in paddings]
def FProp(self, theta, input_batch): """Embeds source ids and transforms with TransformerStack. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. - task_ids: If p.task_emb is provided, must contain per-token task ids of shape [batch, time]. Returns: A NestedMap containing - encoded: The encoded features, either a tensor of shape [time, batch, depth], or a list of tensors if is_transparent is set in transformer_stack. - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. - embedded_inputs: [time, batch, depth] embedded inputs tokens without positional encodings. """ p = self.params with tf.name_scope(p.name): src_segment_id = None src_segment_pos = None input_ids = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(input_batch.ids), tf.shape(input_batch.paddings)), py_utils.assert_equal(tf.rank(input_batch.ids), 2) ], input_batch.ids) if (not py_utils.use_tpu() and tf.flags.FLAGS.transformer_encoder_truncates_inputs): max_seq_length = tf.cast( tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)), tf.int32) paddings = py_utils.with_dependencies([ py_utils.assert_equal( tf.constant(True, tf.bool), tf.reduce_all( input_batch.paddings[:, max_seq_length:] > 0.5)) ], input_batch.paddings) input_ids = input_ids[:, :max_seq_length] paddings = paddings[:, :max_seq_length] if p.packed_input: src_segment_id = input_batch.segment_ids[:, : max_seq_length] src_segment_pos = input_batch.segment_pos[:, : max_seq_length] else: paddings = input_batch.paddings if p.packed_input: src_segment_id = input_batch.segment_ids src_segment_pos = input_batch.segment_pos max_time = tf.shape(input_ids)[1] # Input token embeddings + positional embeddings if not p.shared_emb: input_embs = self.token_emb.EmbLookup( theta.token_emb, tf.reshape(input_ids, [-1])) else: input_embs = self.softmax.EmbLookup( theta.softmax, tf.reshape(input_ids, [-1])) input_embs = tf.reshape(input_embs, [-1, max_time, p.token_emb.embedding_dim]) # [time, batch, dim] orig_input_embs = tf.transpose(input_embs, [1, 0, 2]) if p.packed_input: position_embs = self.position_emb.FPropWithPosition( theta.position_emb, src_segment_pos) else: position_embs = self.position_emb.FProp( theta.position_emb, max_time) position_embs = tf.reshape( position_embs, [1, max_time, p.token_emb.embedding_dim]) input_embs += position_embs if p.task_emb: input_embs += self.task_emb.EmbLookup(theta.task_emb, input_batch.task_ids) if p.model_dim != p.token_emb.embedding_dim: input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs) paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p)) if p.packed_input: src_segment_id = tf.transpose(src_segment_id) input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs) # [time, batch, dim] transformer_input = tf.transpose(input_embs, [1, 0, 2]) if not self.do_eval and p.apply_source_mask: # Augment padding for masked source word positions. dtype = paddings.dtype source_mask = tf.where(tf.equal(input_ids, p.source_mask_id), tf.ones_like(input_ids, dtype=dtype), tf.zeros_like(input_ids, dtype=dtype)) # Make sure padding is between 0 and 1. 
paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0, 1.0) encoded, padding, segment_id = self.transformer_stack.FProp( theta.transformer_stack, transformer_input, paddings, src_segment_id) return py_utils.NestedMap(encoded=encoded, padding=padding, segment_id=segment_id, embedded_inputs=orig_input_embs)
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Performs greedy-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: Maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is the same
    as src_batch_size.

    - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
      token is encountered during search.
    - hyp_lens: [num_hyps].
    - done_hyps: [num_hyps], whether or not an eos is encountered.
  """
  p = self.params
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta,
      encoder_outputs,
      1  # num_hyps_per_beam
  )

  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  cur_step = tf.constant(0, dtype=tf.int32)
  done_hyps = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  hyp_lens = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  hyp_ids = inplace_ops.empty(
      shape=[max_steps, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                   unused_hyp_lens, done_hyps, unused_other_states_list):
    return tf.math.logical_and(
        cur_step < max_steps, tf.math.logical_not(tf.reduce_all(done_hyps)))

  def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
               other_states_list):
    (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
     new_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
         done_hyps, other_states.Pack(other_states_list),
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        tf.TensorShape(hyp_ids.get_shape()),
                        tf.TensorShape(hyp_lens.get_shape()),
                        tf.TensorShape(done_hyps.get_shape()),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # Transpose hyp_ids so it matches BeamSearchDecode's output.
  final_hyp_ids = tf.transpose(final_hyp_ids)
  return final_hyp_ids, final_hyp_lens, final_done_hyps
def FProp(self, theta, input_batch): """Embeds source ids and transforms with TransformerStack. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` object containing: ids - The inputs tensor of shape [batch, time]. paddings - The ids' paddings of shape [batch, time]. Returns: A '.NestedMap' object containing: encoded - The encoded features of shape [time, batch, dim] or [batch, time, dim], depending p.output_data_format. padding - The encoded features' padding of shape [time, batch] or [batch, time]. segment_id - The segmentation of packed inputs of shape [time, batch] or [batch, time] if it is supported by the model, or None otherwise. embedded_inputs - The embedded inputs tokens without positional encodings of shape [time, batch, dim] or [batch, time, dim]. """ p = self.params with tf.name_scope(p.name): # [batch, time] input_ids = input_batch.ids # [batch, time] paddings = input_batch.paddings # [batch, time] segment_ids = input_batch.segment_ids if p.packed_input else None batch = py_utils.GetShape(input_ids)[0] time = py_utils.GetShape(input_ids)[1] # Embedding layer. # [batch, time, dim] if not p.shared_emb: input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids) else: input_embs = self.softmax.EmbLookup(theta.softmax, input_ids) orig_input_embs = input_embs # [1, time, dim] if p.packed_input: positions = input_batch.segment_pos position_embs = tf.expand_dims( self.position_emb.FPropWithPosition( theta.position_emb, positions), 0) else: position_embs = tf.expand_dims( self.position_emb.FProp(theta.position_emb, time), 0) # [batch, time, dim] input_embs += tf.cast(position_embs, tf.bfloat16) if p.input_dropout_tpl.fprop_dtype: input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype) paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype) input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs) # [batch, time, dim] transformer_input = input_embs # Explicitly set the input shape of Transformer layers, to avoid # unknown shape error occurred to tf.einsum on nonTPU devices. transformer_input = tf.reshape(transformer_input, [batch, time, p.model_dim]) # Compute self-attention segment mask once. if p.packed_input: segment_mask = batch_major_attention.SegmentMask( segment_ids, segment_ids, dtype=transformer_input.dtype) else: segment_mask = tf.zeros([batch, 1, time, time]) encoded, padding = self.transformer_stack.FProp( theta.transformer_stack, transformer_input, paddings, segment_mask) if p.final_layer_norm: encoded = self.final_ln.FProp(theta.final_ln, encoded) seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32) if p.output_data_format == 'TBC': encoded = tf.transpose(encoded, [1, 0, 2]) # [time, batch, dim] padding = tf.transpose(padding) # [time, batch] segment_ids = tf.transpose( segment_ids) if p.packed_input else None orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2]) return py_utils.NestedMap( encoded=encoded, padding=padding, seq_lengths=seq_lengths, # used by beam_search_helper. segment_id=segment_ids, embedded_inputs=orig_input_embs)
def Sample(self, decoder_theta, encoder_outputs, random_seed,
           init_state_callback, pre_step_callback, post_step_callback):
  """Samples target sequences, one target sequence per source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    encoder_outputs: the outputs of the encoder, to be passed to callbacks.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.

  Returns:
    A NestedMap containing the following tensors

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but possibly ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a
      padded timestep.
  """
  p = self.params
  assert p.temperature > 0
  if getattr(encoder_outputs, 'segment_id', 1) is None:
    # Remove None values, which are not supported by recurrent.
    del encoder_outputs['segment_id']
  # init_state_callback may modify 'encoder_outputs', e.g., by inserting
  # 'packed_src'.
  bs_result, bs_state = init_state_callback(
      decoder_theta, encoder_outputs, num_hyps_per_beam=1)
  # 'recurrent_theta' represents all cross-timestep information used by the
  # recurrent loop below, including layer theta and encoder outputs.
  recurrent_theta = py_utils.NestedMap(
      theta=decoder_theta,
      random_seed=random_seed,
      encoder_outputs=encoder_outputs)
  batch = tf.shape(bs_result.log_probs)[0]
  recurrent_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=bs_result.log_probs,
      # Start with target_sos_id.
      ids=tf.fill([batch], tf.cast(p.target_sos_id, tf.int32)),
      bs_state=bs_state)
  inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    del inputs
    with tf.name_scope('single_sampler_step'):
      # Compute logits and states.
      bs_result, bs_state1 = pre_step_callback(
          recurrent_theta.theta,
          recurrent_theta.encoder_outputs,
          tf.expand_dims(state0.ids, 1),  # [batch, 1].
          state0.bs_state,
          num_hyps_per_beam=1)
      batch = tf.shape(bs_result.log_probs)[0]
      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = bs_result.log_probs
      # Sample ids from logits. [batch].
      state1.ids = tf.reshape(
          tf.random.stateless_categorical(
              state1.logits / p.temperature,
              num_samples=1,
              seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
              dtype=state0.ids.dtype,
              name='sample_next_id'), [batch])
      if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
        state1.ids = tf.where(
            tf.math.logical_and(bs_result.is_last_chunk,
                                tf.equal(state1.ids, p.target_eoc_id)),
            tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
      state1.bs_state = post_step_callback(recurrent_theta.theta,
                                           recurrent_theta.encoder_outputs,
                                           state1.ids, bs_state1)
    return state1, py_utils.NestedMap()

  accumulated_states, _ = recurrent.Recurrent(
      recurrent_theta,
      recurrent_state0,
      inputs,
      Step,
      allow_implicit_capture=True)
  result = py_utils.NestedMap(
      logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
      ids=tf.transpose(accumulated_states.ids))
  result.paddings = tf.cast(
      _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
  # Force ids to be eos_id if the timestep is padded.
  result.ids = tf.where(
      tf.equal(result.paddings, 0), result.ids,
      tf.fill(tf.shape(result.ids), p.target_eos_id))
  static_batch_size = bs_result.log_probs.shape[0]
  result.ids.set_shape([static_batch_size, p.target_seq_len])
  result.paddings.set_shape([static_batch_size, p.target_seq_len])
  return result
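# Quick illustration of the temperature-scaled sampling step above: dividing
# logits by a temperature < 1 sharpens the distribution, > 1 flattens it,
# and tf.random.stateless_categorical makes the draw reproducible given the
# same (seed, timestep) pair. Values are arbitrary.
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])  # [batch=1, vocab=3]
temperature = 0.7
seed = tf.stack([tf.constant(1234), tf.constant(0)])  # (random_seed, step)

sampled_id = tf.random.stateless_categorical(
    logits / temperature, num_samples=1, seed=seed)
print(sampled_id.numpy())  # deterministic for a fixed seed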
def FProp(self,
          theta,
          source_input,
          source_paddings,
          target_input=None,
          target_paddings=None,
          source_segment_id=None,
          target_segment_id=None,
          labels=None,
          label_weights=None,
          source_pos_id=None,
          target_pos_id=None,
          source_task_id=None,
          target_task_id=None):
  """Transforms source sequence of Tensors with Transformers layers.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_input: A sequence of ints indicating source input ids of [time,
      batch] shape or [batch, time] if batch_dim is 0.
    source_paddings: A sequence of 0s and 1s indicating input paddings of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_input: A sequence of ints indicating target input ids of [time,
      batch] shape or [batch, time] if batch_dim is 0.
    target_paddings: [target_time, target_batch] or [target_batch,
      target_time] if batch_dim is 0.
    source_segment_id: A sequence of ints indicating source segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_segment_id: A sequence of ints indicating target segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    labels: A sequence of ints indicating label ids of [time, batch] shape,
      or [batch, time] if batch_dim is 0.
    label_weights: A sequence of floats indicating label weights of [time,
      batch] shape, or [batch, time] if batch_dim is 0.
    source_pos_id: A sequence of ints indicating source position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    target_pos_id: A sequence of ints indicating target position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    source_task_id: A sequence of ints indicating source task ids of [time,
      batch] shape, or [batch, time] if batch_dim is 0.
    target_task_id: A sequence of ints indicating target task ids of [time,
      batch] shape, or [batch, time] if batch_dim is 0.

  Returns:
    transformer_output with shape [time, batch, dim] or [batch, time, dim]
    if batch_dim is 0.
  """
  p = self.params
  if p.num_decoder_layers > 0:
    assert target_input is not None
    assert target_paddings is not None
  if p.packed_input:
    assert source_segment_id is not None, (
        'Need to specify src_segment_id if packed input is supported.')
    assert source_pos_id is not None, (
        'Need to specify src_pos_id for packed input and embeddings.')

  logits = super(GPipeTransformerStack,
                 self).FProp(theta, source_input, source_paddings,
                             target_input, target_paddings,
                             source_segment_id, target_segment_id,
                             source_pos_id, target_pos_id, source_task_id,
                             target_task_id)
  if not p.softmax_tpl:
    return logits
  label_weights = tf.reshape(label_weights, [-1])
  target_probs = None
  if p.label_smoothing:
    if p.batch_dim:  # Time-major
      target_probs = tf.transpose(
          self.smoother.FProp(
              theta.smoother,
              tf.transpose(target_paddings),
              tf.transpose(labels),
              target_ids=None), [1, 0, 2])
    else:
      target_probs = self.smoother.FProp(
          theta.smoother, target_paddings, labels, target_ids=None)
    target_probs = tf.reshape(target_probs, [-1, p.softmax_tpl.num_classes])
  reshaped_logits = tf.reshape(logits, [-1, p.softmax_tpl.num_classes])
  tgt_labels = tf.reshape(labels, [-1])
  num_splits = len(p.splits)
  softmax = self.children['cell_{}'.format(num_splits - 1)].softmax
  softmax_theta = theta['cell_{}'.format(num_splits - 1)].softmax
  per_example_xent, _ = softmax.XentLossFromLogits(
      softmax_theta,
      reshaped_logits,
      class_weights=tf.reshape(label_weights, [-1]),
      class_ids=tgt_labels,
      class_probabilities=target_probs)
  xent_shape = tf.shape(logits)[:2]
  per_example_xent = tf.reshape(per_example_xent, xent_shape)
  return per_example_xent, logits
def _BeamSearchStep(self,
                    theta,
                    encoder_outputs,
                    cur_step,
                    step_ids,
                    core_bs_states,
                    other_states,
                    num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback,
                    use_short_seq_opt=False):
  """Extend beam search hyps for one step.

  num_beams = Number of source sequences to be decoded.
  num_hyps_per_beam = Number of hyps to keep per source sequence.
  num_hyps = num_beams * num_hyps_per_beam
  src_seq_len = Number of time steps in the source sequence.
  tgt_seq_len = Maximum allowed time steps in the target sequence.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap computed by encoder.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of core beam search states. This list is
      maintained by this helper class.
    other_states: A NestedMap of other beam search states. This NestedMap is
      managed and updated by the client. It is expected that each of its
      member tensors are of rank >= 1. t[i, ...] is the state of the i-th hyp
      at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    use_short_seq_opt: A bool, whether to use the short sequence
      optimization.

  Returns:
    A tuple of following elements for the next beam search step:
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  if use_short_seq_opt:
    bs_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam,
        use_short_seq_opt)
  else:
    bs_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, histories, in_scores, in_hyps,
   in_prev_hyps, in_done_hyps, in_atten_probs, in_eos_scores,
   in_eos_atten_probs) = core_bs_states

  (out_best_scores, out_cumulative_scores, out_scores, out_eos_scores,
   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs,
   out_eos_atten_probs, all_done,
   out_histories) = beam_search_tpu_ops.beam_search_step(
       bs_results.log_probs,
       bs_results.atten_probs,
       best_scores,
       cumulative_scores,
       histories,
       cur_step,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_beams=tf.shape(best_scores)[0],
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       eoc_id=p.target_eoc_id if p.merge_paths else -1,
       is_last_chunk=bs_results.get('is_last_chunk'))

  # Write out values into the TensorArrays corresponding to each output.
  arr_scores = in_scores.write(cur_step, out_scores)
  arr_eos_scores = in_eos_scores.write(cur_step, out_eos_scores)
  arr_hyps = in_hyps.write(cur_step, out_hyps)
  arr_prev_hyps = in_prev_hyps.write(cur_step, out_prev_hyps)
  # TODO(rohananil): Change the implementation of TensorArray write for
  # tf.bool from false += current_value to logical_and(true, current_value)
  # as the addition operator for bool is not defined.
  arr_done_hyps = in_done_hyps.write(cur_step,
                                     tf.cast(out_done_hyps, tf.int32))
  arr_atten_probs = in_atten_probs.write(cur_step, out_atten_probs)
  arr_eos_atten_probs = in_eos_atten_probs.write(cur_step,
                                                 out_eos_atten_probs)

  # New beam search states.
  new_bs_states = (out_best_scores, out_cumulative_scores, out_histories,
                   arr_scores, arr_hyps, arr_prev_hyps, arr_done_hyps,
                   arr_atten_probs, arr_eos_scores, arr_eos_atten_probs)

  old_hyp_ids = tf.reshape(out_prev_hyps, [-1])

  if p.batch_major_compute:
    # Transform the indices into the key/value cache for fast decoding
    # (prefix_states in other_states): the num_hyps dimension of the cache
    # is laid out as num_beams * num_hyps_per_beam, which differs from the
    # old_hyp_ids assumption (num_hyps_per_beam * num_beams). Both a
    # transpose and a recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
      # For rank > 1 tensors we make use of an efficient matmul-based gather
      # on TPU that takes into account the range of the values. For R1, we
      # rely on tf.gather and XLA to optimize it efficiently for the R1
      # layout.
      if x_in.shape.ndims > 1:
        if p.batch_major_state:
          num_hyps = tf.shape(old_hyp_ids)[0]
          x_out = beam_search_tpu_ops.fast_gather(
              x_in,
              old_hyp_ids,
              num_hyps,
              max_value=None,
              batch_major_state=p.batch_major_state)
        else:
          # Use corrected indices only here for batch major compute as
          # key/value caches are the states being affected.
          correct_old_hyp_ids = (
              old_hyp_ids_in_cache_order
              if p.batch_major_compute else old_hyp_ids)

          def _GatherStep(x_in, t):
            """Gathers for one time step.

            Args:
              x_in: in the shape of [T, B, ...]; we first get slice(t) from
                the tensors, then gather old_hyp_ids from the slice and write
                the interpolated slice inplace to update the original x_in.
              t: current time step

            Returns:
              Updated x_in and time step
            """
            x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
            return inplace_ops.alias_inplace_update(x_in, t, x), t + 1

          x_out, _ = tf.while_loop(lambda _, t: t <= cur_step, _GatherStep,
                                   (x_in, tf.zeros([], tf.int32)))
      else:
        x_out = tf.gather(x_in, old_hyp_ids)
      x_out.set_shape(x_in.get_shape())
      return x_out
    else:
      return x_in

  new_other_states = other_states.Transform(ReOrderHyps)
  new_step_ids = tf.reshape(out_hyps, [-1, 1])
  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)