def get_indices(self, keys: mtf.Tensor,
                query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
    """Score the query against both key halves and pick the best key pairs.

    The query is contracted with the keys, the `self.knn` best entries are
    kept for each of the two halves, every (first-half, second-half) pair is
    scored by summing the two half scores, and finally the `self.knn` best
    pairs are selected.

    Args:
      keys: key tensor; its dimension at position 2 is used as the scored
        axis (n_keys according to the inline shape notes).
      query: query tensor; its last dimension is contracted away by the
        einsum.

    Returns:
      A `(scores, indices)` tuple: the top `self.knn` pair scores and the
      matching flat pair indices, encoded as
      `index_in_first_half * self.n_keys + index_in_second_half`.
    """
    # Similarity of the query with every key of each half.
    sim_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
    similarities = mtf.einsum(
        [query, keys], output_shape=sim_shape)  # [b, l, h, 2, n_keys]

    # Keep the knn best candidates per half.
    top_dim = mtf.Dimension("knn", self.knn)
    half_scores, half_indices = mtf.top_k(
        similarities, reduced_dim=sim_shape.dims[-1],
        k_dim=top_dim)  # [b, l, h, 2, knn]

    # Cartesian product of the two candidate lists: knn * knn pairs.
    pair_dim = mtf.Dimension("knn_square_dim", self.knn**2)

    first_scores, second_scores = mtf.unstack(
        half_scores, half_scores.shape.dims[-2])
    # Give the second half its own dimension name so the add broadcasts
    # into a [..., knn, knn2] grid instead of adding element-wise.
    second_scores = mtf.rename_dimension(second_scores, "knn", "knn2")
    grid_shape = mtf.Shape(
        first_scores.shape.dims + second_scores.shape.dims[-1:])
    pair_scores = mtf.add(
        first_scores, second_scores, output_shape=grid_shape)
    # Flatten the [knn, knn2] grid into a single axis of knn**2 candidates.
    pair_scores = mtf.replace_dimensions(
        pair_scores, grid_shape[-2:], pair_dim)

    # Same grid for the indices, using a flat encoding of the pair.
    first_indices, second_indices = mtf.unstack(
        half_indices, half_indices.shape.dims[-2])
    first_indices = mtf.multiply(first_indices, self.n_keys)
    second_indices = mtf.rename_dimension(second_indices, "knn", "knn2")
    pair_indices = mtf.add(
        first_indices, second_indices, output_shape=grid_shape)
    pair_indices = mtf.replace_dimensions(
        pair_indices, grid_shape[-2:], pair_dim)

    # Best knn pairs overall, then look up their flat pair indices.
    best_scores, best_positions = mtf.top_k(
        pair_scores, reduced_dim=pair_scores.shape.dims[-1], k_dim=top_dim)
    return best_scores, mtf.gather(pair_indices, best_positions, pair_dim)
def testTopK(self):
    """Check mtf.top_k values, indices and gradients on a 2x2 mesh.

    A [a=6, b=2] input is reduced along `a` keeping k=2 entries; the
    upstream gradient `d_values` is then back-propagated to the input.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Dimensions: `a` is the one reduced by top_k, `k` is the output axis.
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    k_dim = mtf.Dimension("k", 2)
    reduced_dim = a_dim

    # Input data and the upstream gradient fed into mtf.gradients.
    inputs = tf.constant(
        [[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]],
        dtype=tf.float32)
    d_values = tf.constant([[11, 12], [13, 14]], dtype=tf.float32)

    # Expected results, laid out as [b, k] (values/indices) and [a, b]
    # (input gradient).
    expected_values = tf.constant([[6, 5], [10, 9]], dtype=tf.float32)
    expected_indices = tf.constant([[5, 4], [0, 1]])
    expected_d_inputs = tf.constant(
        [[0, 13], [0, 14], [0, 0], [0, 0], [12, 0], [11, 0]],
        dtype=tf.float32)

    # Build the mtf graph.
    mtf_inputs = mtf.import_fully_replicated(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_d_values = mtf.import_tf_tensor(
        mesh, d_values, shape=mtf.Shape([b_dim, k_dim]))
    mtf_values, mtf_indices = mtf.top_k(
        mtf_inputs,
        reduced_dim=reduced_dim,
        k_dim=k_dim,
        name="test_nth_smallest")
    (mtf_d_inputs,) = mtf.gradients(
        [mtf_values], [mtf_inputs], [mtf_d_values])

    # Lower onto a 2x2 placement mesh and export everything back to TF.
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="rows:2,cols:2",
        layout="a:rows,b:cols",
        devices=[""] * 4)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_values = lowering.export_to_tf_tensor(mtf_values)
    actual_indices = lowering.export_to_tf_tensor(mtf_indices)
    actual_d_inputs = lowering.export_to_tf_tensor(mtf_d_inputs)
    actual_inputs = lowering.export_to_tf_tensor(mtf_inputs)

    # The inputs must round-trip unchanged; the rest must match the
    # hand-computed expectations.
    checks = (
        (actual_inputs, inputs),
        (actual_values, expected_values),
        (actual_indices, expected_indices),
        (actual_d_inputs, expected_d_inputs),
    )
    for actual, expected in checks:
        self.assertAllEqual(self.evaluate(actual), self.evaluate(expected))
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
    """Build switch gating with "no-token-left-behind" routing.

    A dense router produces per-expert probabilities; the top
    `hparams.moe_switch_top_k` experts of each token are tried in order, and
    a token only falls through to its next choice when the previous expert
    is already at capacity.

    Args:
      inputs: router input tensor. `shape[0]`, `shape[1]` and `shape[-2]`
        are used as the outer-batch, batch and group-size dimensions.
      outer_expert_dims: passed to the router dense layer as `expert_dims`.
      experts_dim: dimension enumerating the experts.
      expert_capacity_dim: dimension whose size is the per-expert capacity.
      hparams: reads `moe_rand_1_policy_train`, `moe_rand_1_policy_eval`,
        `moe_rand_1_jitter` and `moe_switch_top_k`.
      train: whether the training policy and summaries are used.
      variable_dtype: forwarded to the router dense layer.
      importance: optional tensor; entries different from 1.0 zero out the
        mask, the gates and the gate density.
      name: name of the router dense layer and prefix of the entropy summary.
      num_microbatches: if greater than 1, the loss is divided by it.

    Returns:
      A `(dispatch_tensor, combine_tensor, loss)` tuple, all cast to
      `inputs.dtype`. `dispatch_tensor` is the non-zero indicator of
      `combine_tensor`.
    """
    # ---- Expert selection ----
    policy = (hparams.moe_rand_1_policy_train
              if train else hparams.moe_rand_1_policy_eval)

    # Optional multiplicative noise on the router inputs (training only).
    if train and policy == "input_jitter":
        inputs = mtf.layers.multiplicative_jitter(
            inputs, hparams.moe_rand_1_jitter)

    router_logits = mtf.layers.dense(
        inputs,
        experts_dim,
        use_bias=False,
        expert_dims=outer_expert_dims,
        variable_dtype=variable_dtype,
        name=name)
    # Everything below runs in float32; bfloat16 seems to reduce quality.
    router_probs = mtf.to_float(
        mtf.softmax(router_logits, reduced_dim=experts_dim))

    # Top-k experts for every token.
    k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
    chosen_gate, chosen_expert = mtf.top_k(
        router_probs, reduced_dim=experts_dim, k_dim=k_dim)
    chosen_mask = mtf.one_hot(chosen_expert, experts_dim)

    # ---- Load-balancing loss ----
    outer_batch_dim = inputs.shape[0]
    batch_dim = inputs.shape[1]
    group_size_dim = inputs.shape[-2]

    # The assignment density is taken before the importance mask is applied.
    assignment_density = mtf.reduce_mean(
        chosen_mask, reduced_dim=group_size_dim)
    probability_density = mtf.reduce_mean(
        router_probs, reduced_dim=group_size_dim)

    if importance is not None:

        def _importance_mask():
            # 1.0 where importance == 1.0, else 0.0, in the gate dtype.
            return mtf.cast(
                mtf.equal(importance, 1.0), dtype=router_probs.dtype)

        chosen_mask *= _importance_mask()
        chosen_gate *= _importance_mask()
        probability_density *= _importance_mask()

    loss = (
        mtf.reduce_mean(probability_density * assignment_density) *
        float(experts_dim.size * experts_dim.size))
    if num_microbatches and num_microbatches > 1:
        tf.logging.info(
            "Dividing load-balance loss by num_microbatches={}".format(
                num_microbatches))
        loss /= num_microbatches

    # ---- Summaries ----
    if train:
        negated_probs = -router_probs
        # The epsilon keeps log() finite for zero probabilities.
        log_probs = mtf.log(router_probs + 1e-9)
        token_entropy = mtf.reduce_sum(
            negated_probs * log_probs, reduced_dim=experts_dim)
        mtf.scalar_summary(name + "/entropy", mtf.reduce_mean(token_entropy))

        tokens_per_expert = mtf.reduce_sum(
            chosen_mask, output_shape=[experts_dim])
        routed_total = mtf.reduce_sum(tokens_per_expert)
        fraction_per_expert = mtf.to_float(tokens_per_expert / routed_total)
        # One scalar summary per expert.
        per_expert = mtf.split(
            fraction_per_expert,
            split_dim=experts_dim,
            num_or_size_splits=experts_dim.size)
        for fraction in per_expert:
            mtf.scalar_summary(
                "experts/" + fraction.name.replace(":", "/"),
                mtf.reduce_mean(fraction))
        mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

    # ---- Assignment of tokens to experts ----
    # Iteratively route tokens (no-token-left-behind): place as many tokens
    # as possible on their i-th choice before trying choice i+1.
    def _split_along_k(tensor):
        return mtf.split(
            tensor, split_dim=k_dim, num_or_size_splits=k_dim.size)

    masks_by_rank = _split_along_k(chosen_mask)
    gates_by_rank = _split_along_k(chosen_gate)
    experts_by_rank = _split_along_k(chosen_expert)

    # Accumulators carried across the routing rounds.
    combine_tensor = mtf.constant(
        inputs.mesh,
        value=0,
        shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
    tokens_in_expert = mtf.constant(
        inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
    unrouted = mtf.constant(
        inputs.mesh,
        value=1.,
        shape=[outer_batch_dim, batch_dim, group_size_dim])

    capacity = float(expert_capacity_dim.size)
    per_token_shape = [outer_batch_dim, batch_dim, group_size_dim, experts_dim]

    for rank_mask, rank_gate, rank_expert in zip(
            masks_by_rank, gates_by_rank, experts_by_rank):
        rank_mask = mtf.reshape(rank_mask, new_shape=per_token_shape)
        # Only tokens that are still unrouted take part in this round.
        rank_mask *= unrouted

        # Running token count per expert, including earlier rounds.
        running_count = mtf.cumsum(rank_mask, group_size_dim)
        running_count = tokens_in_expert + running_count

        # 1.0 where the token still fits under the expert's capacity.
        within_capacity = mtf.to_float(
            mtf.less_equal(running_count, capacity))
        accepted = rank_mask * within_capacity

        # Carry the bookkeeping over to the next round.
        tokens_in_expert += mtf.reduce_sum(
            accepted, reduced_dim=group_size_dim)
        unrouted -= mtf.reduce_sum(accepted, reduced_dim=experts_dim)

        # Contribution of this round to the combine tensor.
        accepted_per_token = mtf.reduce_sum(accepted, reduced_dim=experts_dim)
        slot_in_expert = running_count - 1
        contribution = rank_gate * accepted_per_token
        contribution = contribution * mtf.one_hot(rank_expert, experts_dim)
        contribution = contribution * mtf.one_hot(
            mtf.to_int32(slot_in_expert), expert_capacity_dim)
        combine_tensor += contribution

    # Match the inputs dtype.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)
    # The bool round-trip turns every non-zero combine weight into 1.
    dispatch_tensor = mtf.cast(
        mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss