def cond_fn(step_num, prev_ids, *unused_states):
  """Should we run another loop iteration."""
  overflow = mtf.equal(step_num, num_steps)
  has_eos = mtf.reduce_any(
      mtf.equal(prev_ids, eos_id), reduced_dim=length_dim)
  all_has_eos = mtf.reduce_all(has_eos)
  return mtf.logical_not(mtf.logical_or(overflow, all_has_eos))
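

# Illustrative note (not from the original source): with eos_id=1 and three
# partially decoded rows [5, 2, 1], [4, 3, 2], [1, 0, 0], has_eos is
# [True, False, True], so all_has_eos is False and decoding continues.
# The loop only stops once every row contains eos_id, or once step_num
# reaches num_steps.

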
def attention_mask_ignore_padding(inputs, dtype=tf.float32):
  """Bias for encoder-decoder attention.

  Args:
    inputs: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [..., memory_length_dim]
  """
  inputs = rename_length_to_memory_length(inputs)
  return mtf.cast(mtf.equal(inputs, 0), dtype) * -1e9
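

# Illustrative example (not from the original source): for a padded token-id
# row [3, 7, 0, 0], the returned bias along memory_length is
# [0, 0, -1e9, -1e9]. Adding this bias to the attention logits before the
# softmax effectively removes attention to padding positions (token id 0 is
# assumed to be padding).

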
def masked_local_attention_1d_incremental(x,
                                          prev_k,
                                          prev_v,
                                          step_num,
                                          master_dtype,
                                          slice_dtype,
                                          name=None):
  """Incremental local self-attention (one decode step).

  Incremental version of masked_local_attention_1d()

  Args:
    x: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape
      [batch..., heads, window_length, kv_channels]
    prev_v: mtf.Tensor with shape
      [batch..., heads, window_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape
      [batch..., heads, window_length, kv_channels]
    new_v: mtf.Tensor with shape
      [batch..., heads, window_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = x.shape.dims[:-1]
  io_channels = x.shape.dims[-1]
  heads, window_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        x.mesh, heads, io_channels, kv_channels,
        master_dtype, slice_dtype, x.dtype)
    q = mtf.einsum([x, q_var], mtf.Shape(batch_dims + [heads, kv_channels]))
    k = mtf.einsum([x, k_var], mtf.Shape(batch_dims + [heads, kv_channels]))
    v = mtf.einsum([x, v_var], mtf.Shape(batch_dims + [heads, kv_channels]))
    current_position = mtf.equal(
        mtf.range(x.mesh, window_length, dtype=tf.int32),
        mtf.mod(step_num, window_length.size))
    k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
    v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
    o = dot_product_attention(q, k, v, mask=None)
    y = mtf.einsum([o, o_var], x.shape)
    return y, k, v
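

# Sketch of the ring-buffer behavior (hedged, illustrative only): prev_k and
# prev_v hold the last window_length keys and values. With window_length=4
# and step_num=6, mtf.mod(step_num, 4) == 2, so the new key/value pair
# overwrites slot 2 while the other three slots are carried over unchanged.
# A caller would typically initialize the buffers to zeros before the first
# decode step, e.g. (names here are assumptions, not from this source):
#
#   kv_shape = mtf.Shape(batch_dims + [heads, window_length, kv_channels])
#   prev_k = mtf.zeros(mesh, kv_shape, dtype=master_dtype)
#   prev_v = mtf.zeros(mesh, kv_shape, dtype=master_dtype)

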
def body_fn(step_num, ids, *states):
  """Body function for greedy decoding.

  Args:
    step_num: a mtf.Tensor
    ids: a mtf.Tensor
    *states: additional mtf.Tensors

  Returns:
    new_step_num, new_ids, *new_states
  """
  logits, new_states = logits_fn(step_num, ids, states)
  vocab_dim = logits.shape.dims[-1]
  new_ids = mtf.sample_with_temperature(logits, vocab_dim, temperature)
  if forced_ids is not None:
    # Force the new ids to equal the partial targets where specified
    # (positions where partial_targets contain nonzero values).
    forced = mtf.gather(forced_ids, step_num, length_dim)
    new_ids = forced + new_ids * mtf.to_int32(mtf.equal(forced, 0))
  ids += new_ids * mtf.one_hot(step_num, length_dim, dtype=tf.int32)
  new_step_num = step_num + 1
  return [new_step_num, ids] + new_states
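

# Usage sketch (hedged, assumed wiring, not from this source): cond_fn and
# body_fn are intended to be passed to mtf.while_loop by the enclosing
# greedy-decode function. Assuming initial_ids holds the start tokens and
# initial_states the incremental decoder states, the loop would look roughly
# like:
#
#   initial_step_num = mtf.constant(mesh, 0, dtype=tf.int32)
#   outputs = mtf.while_loop(
#       cond_fn, body_fn, [initial_step_num, initial_ids] + initial_states)
#   final_ids = outputs[1]

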
def grow_topk(i, alive_seq, alive_log_probs, states=None):
  r"""Inner beam search loop.

  This function takes the current alive sequences and grows them to topk
  sequences where k = 2*beam. We use 2*beam because we could have beam_size
  sequences that hit <EOS>, leaving no alive sequences to continue; with
  2*beam_size this cannot happen. This relies on the assumption that the
  vocab size is > beam size. If that is true, we'll have at least beam_size
  non-<EOS> extensions if we extract the next top 2*beam words.
  Length penalty is given by ((5 + len(decode)) / 6) ^ -\alpha. Please refer
  to https://arxiv.org/abs/1609.08144.

  Args:
    i: loop index
    alive_seq: Topk sequences decoded so far [batch, beam, length]
    alive_log_probs: probabilities of these sequences. [batch, beam]
    states: optional list of mtf.Tensor

  Returns:
    Tuple of
      (Topk sequences extended by the next word,
       The log probs of these sequences,
       The scores with length penalty of these sequences,
       Flags indicating which of these sequences have finished decoding,
       list of transformed decoding states)
  """
  logits, new_states = logits_fn(i, alive_seq, states)
  batch_dim, beam_dim, vocab_dim = logits.shape.dims

  # Convert logits to normalized log probs.
  candidate_log_probs = mtf.log_softmax(logits, vocab_dim)

  # Multiply the probabilities by the current probabilities of the beam.
  # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
  log_probs = candidate_log_probs + alive_log_probs

  length_penalty = mtf.pow(((5. + mtf.cast(i + 1, logits.dtype)) / 6.), alpha)

  # scores have shape [batch, beam, vocab]
  curr_scores = log_probs / length_penalty

  beam_and_vocab_dim = mtf.Dimension(
      "beam_and_vocab", beam_dim.size * vocab_dim.size)
  flat_shape = mtf.Shape([batch_dim, beam_and_vocab_dim])
  double_beam = mtf.Dimension("double_beam", beam_dim.size * 2)

  # Flatten out (beam_size, vocab_size) probs into a list of possibilities.
  flat_curr_scores = mtf.reshape(curr_scores, flat_shape)

  top_ids, top_scores = mtf.top_k(
      flat_curr_scores, reduced_dim=beam_and_vocab_dim, new_dim=double_beam)

  # Recovering the log probs because we will need to send them back.
  top_log_probs = top_scores * length_penalty

  # Work out what beam the top probs are in.
  top_beam_index = top_ids // vocab_dim.size
  top_ids %= vocab_dim.size  # Unflatten the ids

  def my_gather(tensor):
    return mtf.gather(
        tensor, top_beam_index, beam_dim,
        output_shape=mtf.Shape(
            [double_beam if d == beam_dim else d
             for d in tensor.shape.dims]))

  # Gather up the most probable 2*beams, both for the ids and the
  # finished_in_alive bools.
  top_seq = my_gather(alive_seq)

  if states:
    states = [my_gather(state) for state in new_states]

  # Append the most probable alive token.
  top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
  top_finished = mtf.equal(top_ids, eos_id)

  return top_seq, top_log_probs, top_scores, top_finished, states
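

# Worked example (illustrative, not from the original source): with
# beam_size=4 and vocab_size=100, the scores are flattened to a length-400
# "beam_and_vocab" axis and the top 2*beam_size = 8 entries are kept. A flat
# index of 237 decodes back to beam 237 // 100 == 2 and token 237 % 100 == 37.
# At loop index i=5 with alpha=0.6, the length penalty is
# ((5 + 6) / 6) ** 0.6 ≈ 1.44; dividing the (negative) log probs by it pulls
# longer sequences' scores closer to zero, which is the GNMT-style length
# normalization cited in the docstring above.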