def cond_fn(step_num, prev_ids, *unused_states):
  """Loop condition: decide whether decoding should run one more step.

  The loop keeps going until either `step_num` equals `num_steps`, or every
  sequence in `prev_ids` contains `eos_id` somewhere along `length_dim`.
  `num_steps`, `eos_id` and `length_dim` are free names resolved outside this
  function.

  Args:
    step_num: a mtf.Tensor holding the current step index
    prev_ids: a mtf.Tensor of the ids decoded so far
    *unused_states: additional loop variables, ignored here

  Returns:
    a boolean mtf.Tensor that is true when another iteration should run
  """
  out_of_steps = mtf.equal(step_num, num_steps)
  is_eos = mtf.equal(prev_ids, eos_id)
  sequence_done = mtf.reduce_any(is_eos, reduced_dim=length_dim)
  everything_done = mtf.reduce_all(sequence_done)
  should_stop = mtf.logical_or(out_of_steps, everything_done)
  return mtf.logical_not(should_stop)
def attention_mask_ignore_padding(inputs, dtype=tf.float32):
  """Additive bias for encoder-decoder attention that masks zero-valued ids.

  The length dimension of `inputs` is first renamed via
  `rename_length_to_memory_length`. Positions whose value equals 0 then
  receive a bias of -1e9; all other positions receive a bias of zero.

  Args:
    inputs: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype for the returned bias

  Returns:
    a mtf.Tensor with shape [..., memory_length_dim]
  """
  memory_inputs = rename_length_to_memory_length(inputs)
  is_zero = mtf.equal(memory_inputs, 0)
  zero_indicator = mtf.cast(is_zero, dtype)
  return zero_indicator * -1e9
def body_fn(step_num, ids, *states):
  """Loop body: pick the next id for every sequence and write it into `ids`.

  `logits_fn`, `temperature`, `forced_ids` and `length_dim` are free names
  resolved outside this function.

  Args:
    step_num: a mtf.Tensor holding the current step index
    ids: a mtf.Tensor of the ids decoded so far
    *states: additional mtf.Tensors passed through to `logits_fn`

  Returns:
    a list [new_step_num, new_ids, *new_states]
  """
  logits, new_states = logits_fn(step_num, ids, states)
  # The last dimension of the logits is taken to be the vocabulary.
  sampled = mtf.sample_with_temperature(
      logits, logits.shape.dims[-1], temperature)
  if forced_ids is not None:
    # Where forced_ids holds a nonzero value at this step, that value wins;
    # the sampled id is kept only where the forced value is 0.
    forced = mtf.gather(forced_ids, step_num, length_dim)
    keep_sampled = mtf.to_int32(mtf.equal(forced, 0))
    sampled = forced + sampled * keep_sampled
  # Add the chosen ids into position `step_num` along length_dim.
  ids += sampled * mtf.one_hot(step_num, length_dim, dtype=tf.int32)
  return [step_num + 1, ids] + new_states
def grow_topk(i, alive_seq, alive_log_probs, states=None):
  """Extend the alive beams by one token and keep the best 2*beam candidates.

  Every [beam, vocab] continuation is scored, the beam and vocab dimensions
  are flattened together, and the top 2*beam entries are selected. Twice the
  beam size is kept so that, even if beam_size of the selected candidates end
  in <EOS>, there are still alive sequences left to continue. This relies on
  the vocab size being larger than the beam size.

  Scores are the accumulated log probs divided by the length penalty
  ((5 + (i + 1)) / 6) ** alpha; see https://arxiv.org/abs/1609.08144.

  `logits_fn`, `alpha`, `length_dim` and `eos_id` are free names resolved
  outside this function.

  Args:
    i: loop index
    alive_seq: Topk sequences decoded so far [batch, beam, length]
    alive_log_probs: log probabilities of these sequences. [batch, beam]
    states: optional list of mtf.Tensor

  Returns:
    Tuple of
      (Topk sequences extended by the next word,
       The log probs of these sequences,
       The scores with length penalty of these sequences,
       Flags indicating which of these sequences have finished decoding,
       list of transformed decoding states; `states` is returned unchanged
       when it is empty or None)
  """
  logits, new_states = logits_fn(i, alive_seq, states)
  batch_dim, beam_dim, vocab_dim = logits.shape.dims
  vocab_size = vocab_dim.size

  # Dimensions used to flatten [beam, vocab] and to hold the 2*beam winners.
  double_beam = mtf.Dimension("double_beam", beam_dim.size * 2)
  beam_and_vocab_dim = mtf.Dimension(
      "beam_and_vocab", beam_dim.size * vocab_size)
  flat_shape = mtf.Shape([batch_dim, beam_and_vocab_dim])

  # Normalized log probs of each next token, accumulated onto the log probs
  # of the beam they extend: [batch, beam, vocab] + [batch, beam].
  log_probs = mtf.log_softmax(logits, vocab_dim) + alive_log_probs

  decoded_length = mtf.to_float(i + 1)
  length_penalty = mtf.pow(((5. + decoded_length) / 6.), alpha)
  # Length-normalized scores, shape [batch, beam, vocab].
  curr_scores = log_probs / length_penalty

  # Rank all beam * vocab candidates together.
  flat_curr_scores = mtf.reshape(curr_scores, flat_shape)
  top_ids, top_scores = mtf.top_k(
      flat_curr_scores,
      reduced_dim=beam_and_vocab_dim,
      new_dim=double_beam)

  # Undo the length normalization to recover the raw log probs.
  top_log_probs = top_scores * length_penalty

  # Split each flat index into (source beam, token id).
  top_beam_index = top_ids // vocab_size
  top_ids %= vocab_size

  def select_source_beams(tensor):
    """Gathers `tensor` along beam_dim, yielding double_beam in its place."""
    out_dims = []
    for dim in tensor.shape.dims:
      out_dims.append(double_beam if dim == beam_dim else dim)
    return mtf.gather(
        tensor, top_beam_index, beam_dim, output_shape=mtf.Shape(out_dims))

  # Pull out the sequences (and states) belonging to the selected beams.
  top_seq = select_source_beams(alive_seq)

  if states:
    states = [select_source_beams(new_state) for new_state in new_states]

  # Write the newly chosen token into position i of each sequence.
  top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
  top_finished = mtf.equal(top_ids, eos_id)

  return top_seq, top_log_probs, top_scores, top_finished, states
def _top_2_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, importance=None):
  """Compute gating for mixture-of-experts in TensorFlow.

  Each position is routed to its top-scoring expert and, depending on the
  second-place policy, to a second expert. Positions that exceed an expert's
  capacity are dropped for that expert.

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py. Once this code moves out of "research", we should pass
  the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The policy string (the `_train` variant when `train` is true, otherwise the
  `_eval` variant) must be one of:
    "all": keep the second-place expert for every position.
    "none": never use the second-place expert.
    "threshold": keep the second-place expert where gate_2 > threshold.
    "random": keep the second-place expert with probability
      min(1.0, gate_2 / threshold).

  The returned forward assignment is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  forward and backward assignments are mostly zeros.  The shapes of the tensors
  are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: if the selected second-place policy is not one of
      "all", "none", "threshold" or "random".
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  # Gate probabilities: a bias-free dense projection of the inputs onto
  # experts_dim, followed by a softmax over experts_dim.
  raw_gates = mtf.softmax(
      mtf_layers.dense(inputs, experts_dim, use_bias=False,
                       expert_dims=outer_expert_dims), experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    # Only positions with importance exactly 1.0 keep a first-choice expert.
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  # Zero out the first-choice gate so the next top_1 yields the runner-up.
  # Where mask_1 was zeroed by importance above, nothing is removed here.
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    # Positions with importance 0.0 get no second-choice expert either.
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  # Renormalize the two selected gates; the 1e-9 keeps the denominator
  # nonzero when both gates are zero.
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  # NOTE(review): presumably emits the per-expert mean density as debug
  # output while passing density_1 through unchanged -- confirm against
  # mtf.Print.
  density_1 = mtf.Print(
      density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
      "density_1", summarize=1000)
  # Mean of (proxy * density), scaled by the square of the number of experts.
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    # The second-place loss is added at half weight.
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  policy = (
      hparams.moe_second_policy_train if train else
      hparams.moe_second_policy_eval)
  threshold = (
      hparams.moe_second_threshold_train if train else
      hparams.moe_second_threshold_eval)
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts for all examples.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    # The max(..., 1e-9) avoids dividing by a zero threshold.
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)
  # NOTE(review): presumably debug output of the per-expert second-choice
  # density, analogous to the "density_1" print above -- confirm.
  mask_2 = mtf.Print(
      mask_2, [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
      "density_2", summarize=1000)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  # Second choices are placed after the first choices already assigned to
  # the same expert, hence the mask_1_count offset.
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  # Drop second choices that do not fit within the expert's capacity.
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  # NOTE(review): gate_1 / gate_2 were already multiplied by mask_1_flat /
  # mask_2_flat above, so the extra factors here look redundant if those
  # masks only hold zeros and ones -- confirm before simplifying.
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  # Hand results back in the dtype of the inputs (internals ran in float32).
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  # Round-trip through bool: dispatch_tensor is 1 wherever combine_tensor is
  # nonzero and 0 elsewhere.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss