def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
  """Additive attention bias for a local block and the memory preceding it.

  The bias is assembled from two pieces joined along the memory dimension:
    * an all-zero piece for the memory positions that do not overlap the
      query block, which may always be attended to, and
    * a piece for the positions aligned with the query block, in which a
      query position may not look at anything to its right.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension
    memory_length: a mtf.Dimension; only its name is used, each of the two
      pieces gets its size from block_length
    dtype: a tf.dtype used for the intermediate masks

  Returns:
    a float32 mtf.Tensor indexed by block_length and the memory dimension,
    zero where attention is allowed and -1e9 where it is disallowed.
  """
  # Both pieces span block_length.size positions under the memory name.
  piece_dim = mtf.Dimension(memory_length.name, block_length.size)
  always_visible = mtf.zeros(mesh, [block_length, piece_dim], dtype=dtype)

  query_positions = mtf.range(mesh, block_length, dtype=dtype)
  memory_positions = mtf.range(mesh, piece_dim, dtype=dtype)
  # 1 wherever the memory position lies strictly right of the query position.
  right_of_query = mtf.cast(
      mtf.less(query_positions, memory_positions), dtype=dtype)

  combined = mtf.concat([always_visible, right_of_query], piece_dim.name)
  return mtf.cast(combined, dtype=tf.float32) * -1e9
def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
  """Additive attention bias that forbids a block from attending rightwards.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension indexing query positions
    memory_length: a mtf.Dimension indexing memory positions
    dtype: a tf.dtype used for the position ranges and the intermediate mask

  Returns:
    a float32 mtf.Tensor over [block_length, memory_length], zero where the
    memory position is not to the right of the query position and -1e9
    elsewhere.
  """
  query_positions = mtf.range(mesh, block_length, dtype=dtype)
  memory_positions = mtf.range(mesh, memory_length, dtype=dtype)
  # 1 wherever the memory position lies strictly right of the query position.
  right_of_query = mtf.cast(
      mtf.less(query_positions, memory_positions), dtype=dtype)
  return mtf.cast(right_of_query, dtype=tf.float32) * -1e9
def attention_mask_autoregressive(query_pos, dtype=tf.float32):
  """Self-attention bias that blocks attention to later positions.

  Args:
    query_pos: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype for the returned bias

  Returns:
    a mtf.Tensor with shape [..., length_dim, memory_length_dim] that is
    -1e9 where the query position is smaller than the memory position and
    zero elsewhere.
  """
  memory_pos = rename_length_to_memory_length(query_pos)
  attends_to_future = mtf.less(query_pos, memory_pos)
  return mtf.cast(attends_to_future, dtype) * -1e9
def attention_mask_ignore_padding(inputs, dtype=tf.float32):
  """Encoder-decoder attention bias that hides padding positions.

  Positions whose value in `inputs` equals 0 are treated as padding.

  Args:
    inputs: a mtf.Tensor with shape [..., length_dim]
    dtype: a tf.dtype for the returned bias

  Returns:
    a mtf.Tensor with shape [..., memory_length_dim] that is -1e9 at padding
    positions and zero elsewhere.
  """
  memory_inputs = rename_length_to_memory_length(inputs)
  is_padding = mtf.equal(memory_inputs, 0)
  return mtf.cast(is_padding, dtype) * -1e9
def attention_mask_same_segment(
    query_segment, memory_segment=None, dtype=tf.float32):
  """Bias for attention where attention between segments is disallowed.

  Args:
    query_segment: a mtf.Tensor with shape [..., length_dim]
    memory_segment: an optional mtf.Tensor with shape
      [..., memory_length_dim]. When None, query_segment is used for the
      memory side as well.
    dtype: a tf.dtype for the returned bias

  Returns:
    a mtf.Tensor with shape [..., length_dim, memory_length_dim] that is
    -1e9 where the query and memory segment values differ and zero
    elsewhere.
  """
  # Compare against None explicitly: whether the argument was supplied must
  # not depend on how a tensor object evaluates in a boolean context.
  if memory_segment is None:
    memory_segment = query_segment
  memory_segment = rename_length_to_memory_length(memory_segment)
  return mtf.cast(mtf.not_equal(query_segment, memory_segment), dtype) * -1e9
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
  """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.  The
  number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences.  This would allow more freedom
    for individual sequences to be unbalanced.  Unfortunately, that would
    slow down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation.  Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.  We also want to integrate this
  gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
  batch_dim, length_dim, input_dim = inputs.shape.dims

  # Each sequence sends expert_capacity positions to each expert.  The
  # factor of 2 reflects that a position goes to at most two experts; the
  # capacity is capped at the sequence length.
  expert_capacity = min(
      length_dim.size,
      int((length_dim.size * 2 * overhead) / experts_dim.size))
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

  # Same sizes as experts_dim / batch_dim but under different names, so the
  # reshapes below can move between the two layouts.
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

  # This is the learned gating function.
  # shape = [batch_dim, length_dim, experts_dim_unsplit]
  gates = mtf.softmax(dense(inputs, experts_dim_unsplit), experts_dim_unsplit)

  assignment_shape = mtf.TensorShape(
      [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

  backward_assignment = mtf.slicewise(
      functools.partial(
          _truncated_top_2_gating, expert_capacity=expert_capacity),
      [gates],
      output_shape=assignment_shape,
      splittable_dims=[batch_dim],
      name="backward_assignment")

  # Binarized copy of the combination weights, used to route the inputs.
  forward_assignment = mtf.cast(
      mtf.cast(backward_assignment, tf.bool), inputs.dtype)

  # put num_experts dimension first to make split easier in alltoall
  expert_inputs = mtf.einsum(
      [inputs, forward_assignment],
      mtf.TensorShape(
          [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.TensorShape(
          [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

  # Now feed the expert inputs through the experts.
  h = dense(expert_inputs, hidden_dim, expert_dims=[experts_dim],
            activation=mtf.relu, name="x0")
  expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

  # The second dense layer projects onto output_dim, so that is the trailing
  # dimension to carry through this reshape (not input_dim, which is only
  # equivalent when the layer maps a dimension onto itself).
  expert_output = mtf.reshape(
      expert_output,
      mtf.TensorShape(
          [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

  # Combine the expert outputs back per position using the gate weights.
  output = mtf.einsum(
      [expert_output, backward_assignment],
      mtf.TensorShape([batch_dim, length_dim, output_dim]))

  # Total gate weight each expert receives per batch entry; its squared
  # coefficient of variation is the balancing loss.
  importance = mtf.reduce_sum(
      backward_assignment,
      output_shape=mtf.TensorShape([batch_dim, experts_dim_unsplit]))

  loss = cv_squared(importance) * loss_coef
  return output, loss
def _sample(self, features, mesh):
  """Build the decoding graph that produces output ids for `features`.

  When the model has inputs, the encoder is run once and the per-layer
  encoder-decoder attention keys and values are precomputed.  Otherwise any
  sequence found in features["inputs"] or features["targets"] is used as a
  forced prefix of the output.  Decoding is greedy/sampled when
  hparams.beam_size == 1 and uses beam search otherwise.

  Args:
    features: a dictionary of tf.Tensors; "inputs" is required when
      self.has_input is true, otherwise "inputs" and "targets" are optional.
    mesh: the mesh passed to the embedding, import and attention helpers.

  Returns:
    The result of mtf_beam_search.greedy_decode when hparams.beam_size == 1,
    otherwise the entry at index 0 of the beam dimension of the beams
    returned by mtf_beam_search.beam_search.
  """
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if self.has_input:
    inputs = features["inputs"]
    # Drop extra axes until the tensor is rank 2.
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    # Pad up to the static [batch_size, max_length] from the hparams.
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf_layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.num_encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Precompute, per decoder layer, the encoder-decoder attention keys and
    # values so that logits_fn does not recompute them at every step.
    encdec_tensors = []
    for layer_num in xrange(hparams.num_decoder_layers):
      with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
        q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
            mesh, self.heads_dim, self.model_dim,
            self.kv_dim, self.activation_dtype)
        k = mtf.einsum(
            [encoder_output, k_var],
            mtf.Shape(
                [self.batch_dim, self.heads_dim,
                 self.memory_length_dim, self.kv_dim]))
        v = mtf.einsum(
            [encoder_output, v_var],
            mtf.Shape(
                [self.batch_dim, self.heads_dim,
                 self.memory_length_dim, self.kv_dim]))
      encdec_tensors.append((q_var, o_var, k, v))
    partial_targets = None
  else:
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      # Pad up to the static [batch_size, max_length] from the hparams.
      partial_targets = tf.pad(
          partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                            [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)

  # Beam search carries an extra beam dimension on the ids and the
  # self-attention states.
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
    kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
    kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  # One zero tensor repeated 2 * num_decoder_layers times: logits_fn treats
  # the first half as self-attention keys and the second half as values.
  initial_kv_states = (
      [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
      * (2 * hparams.num_decoder_layers))

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    self_attention_k = states[:hparams.num_decoder_layers]
    self_attention_v = states[hparams.num_decoder_layers:]
    # The token fed at this step is the id at the previous position.
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_self_attention_k, new_self_attention_v = (
          self._decoder_layer_stack_incremental(
              x,
              step_num,
              encdec_tensors,
              self_attention_k,
              self_attention_v,
              encdec_attention_mask=encoder_attention_mask))
    logits = mtf.matmul(x, softmax_var)
    # Presumably list concatenation, mirroring how `states` is split above
    # -- confirm against _decoder_layer_stack_incremental.
    return logits, new_self_attention_k + new_self_attention_v

  if hparams.beam_size == 1:
    # A temperature of 0.0 is passed for the "argmax" sampling method.
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf_beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_kv_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if self.has_input:
      # Count the non-zero input ids per sequence and derive the decode
      # length from the longest one.
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf_beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_kv_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu)
    # Keep only the entry at index 0 of the beam dimension.
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def _top_2_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, importance=None):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py .  Once this code moves out of "research", we should
  pass the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The returned dispatch_tensor is used to map (via einsum) from the inputs
  to the expert_inputs.  Likewise, the returned combine_tensor is used to
  map (via einsum) from the expert outputs to the outputs.  Both tensors are
  mostly zeros.  The shapes of the tensors are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  raw_gates = mtf.softmax(mtf_layers.dense(
      inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims), experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    # Only inputs with importance exactly 1.0 keep their first choice here.
    # NOTE(review): density_1_proxy aliases raw_gates above; this relies on
    # `*=` rebinding the name rather than mutating raw_gates -- confirm.
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  # Zero out the first choice so that top_1 now finds the runner-up.  Where
  # mask_1 was cleared above, the runner-up search sees the full gates.
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    # Inputs with importance 0.0 are sent to no expert.
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  # Renormalize the two gate values to add up to 1; the epsilon guards
  # against dividing by zero when both gates are zero.
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  # Debugging aid: attaches a print of the per-expert mean density.
  density_1 = mtf.Print(
      density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
      "density_1", summarize=1000)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  policy = (
      hparams.moe_second_policy_train if train else
      hparams.moe_second_policy_eval)
  threshold = (
      hparams.moe_second_threshold_train if train else
      hparams.moe_second_threshold_eval)
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts for all examples.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)
  # Debugging aid: attaches a print of the per-expert mean of mask_2.
  mask_2 = mtf.Print(
      mask_2, [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
      "density_2", summarize=1000)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  # Second choices are placed after all of the expert's first choices.
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  # Remove the second choices that don't fit.
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  # [batch, group] - ones where the second choice was kept
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  # Weight assigned to second expert.  [batch, group]
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  # Each kept choice contributes its gate weight at (expert, slot).
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  # Cast the float32 internals back to the dtype of the inputs.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  # Binarized copy of combine_tensor: 1 wherever a combine weight is
  # non-zero.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
def _top_2_gating(
    inputs, experts_dim, expert_capacity_dim, max_experts, hparams, train):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py .  Once this code moves out of "research", we should
  pass the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  max_experts is a float tensor with shape [batch_dim, group_dim]
  indicating at most how many experts to use per example.  This can be
  used to prevent padding from going to experts.

  The returned forward assignment is a tensor used to map (via einsum) from
  the inputs to the expert_inputs.  Likewise, the returned backward_assignment
  is used to map (via einsum) from the expert outputs to the outputs.  Both
  the forward and backward assignments are mostly zeros.  The shapes of all
  of these are as follows.

  inputs: [batch_dim, group_dim, input_dim]
  forward_assignment: [batch_dim, group_dim, experts_dim, expert_capacity_dim]
  expert_inputs: [batch_dim, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [batch_dim, experts_dim, expert_capacity_dim, output_dim]
  backward_assignment: [batch_dim, group_dim, experts_dim, expert_capacity_dim]
  outputs: [batch_dim, group_dim, output_dim]

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, group_dim, input_dim]
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    max_experts: optional mtf.Tensor with shape [batch_dim, group_dim]
    hparams: model hyperparameters.
    train: a boolean

  Returns:
    forward_assignment: a Tensor with shape
      [batch_dim, group_dim, experts_dim, expert_capacity_dim]
    backward_assignment: a Tensor with shape
      [batch_dim, group_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
  unused_batch_dim, group_dim, unused_input_dim = inputs.shape.dims

  raw_gates = mtf.softmax(mtf_layers.dense(
      inputs, experts_dim, use_bias=False), experts_dim)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  # Zero out the first choice so that top_1 now finds the runner-up.
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)

  if max_experts is not None:
    # Positions allowed fewer than 1 (resp. 2) experts lose their first
    # (resp. second) choice, and the matching gates are zeroed so they do
    # not contribute to the balancing losses below.
    geq1 = mtf.to_float(mtf.greater_equal(max_experts, 1.0))
    geq2 = mtf.to_float(mtf.greater_equal(max_experts, 2.0))
    mask_1 *= geq1
    mask_2 *= geq2
    raw_gates *= geq1
    gates_without_top_1 *= geq2

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_dim)
  # Debugging aid: attaches a print of the per-expert mean density.
  density_1 = mtf.Print(
      density_1, [mtf.reduce_mean(density_1, output_shape=[experts_dim])],
      "density_1", summarize=1000)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  policy = (
      hparams.moe_second_policy_train if train else
      hparams.moe_second_policy_eval)
  threshold = (
      hparams.moe_second_threshold_train if train else
      hparams.moe_second_threshold_eval)
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts for all examples.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)
  # Debugging aid: attaches a print of the per-expert mean of mask_2.
  mask_2 = mtf.Print(
      mask_2, [mtf.reduce_mean(mask_2, output_shape=[experts_dim])],
      "density_2", summarize=1000)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(mask_1, group_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  # Second choices are placed after all of the expert's first choices.
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  # Remove the second choices that don't fit.
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  # [batch, group] - ones where the second choice was kept
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  # Weight assigned to second expert.  [batch, group]
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # renormalize the two gate values to add up to 1
  # (done after dropping, so a surviving single choice gets weight ~1)
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # [batch, group, experts, expert_capacity]
  # Each kept choice contributes its gate weight at (expert, slot).
  backward_assignment = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  # Binarized copy of backward_assignment: 1 wherever a weight is non-zero.
  forward_assignment = mtf.cast(
      mtf.cast(backward_assignment, tf.bool), backward_assignment.dtype)

  return forward_assignment, backward_assignment, loss