def layer_prepostprocess_dropout(x, hparams):
    batch_dim = x.shape.dims[0]
    model_dim = x.shape.dims[-1]
    return mtf.dropout(
        x,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([batch_dim, model_dim]))
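# Illustrative usage sketch (not from the original source). It assumes the same
# mesh_tensorflow version as the function above, i.e. an mtf.dropout that accepts
# `keep_prob` and `noise_shape`. Because noise_shape omits the length dimension,
# the same feature channels are dropped at every position of a sequence.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch_dim = mtf.Dimension("batch", 2)
length_dim = mtf.Dimension("length", 8)
model_dim = mtf.Dimension("d_model", 4)
x = mtf.import_tf_tensor(
    mesh, tf.ones([2, 8, 4]), mtf.Shape([batch_dim, length_dim, model_dim]))
# Dropout mask has shape [batch, d_model] and is broadcast over length.
y = mtf.dropout(x, keep_prob=0.9, noise_shape=mtf.Shape([batch_dim, model_dim]))
# (Lowering the graph with a mesh implementation is still required to execute it.)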
def __init__(self, config, is_training, input_ids, input_mask=None,
             token_type_ids=None, scope=None, mesh_shape="", layout=""):
    self.config = copy.deepcopy(config)
    del config
    if not is_training:
        self.config.layer_output_dropout_prob = 0.0
        self.config.attention_probs_dropout_prob = 0.0
        self.config.feedforward_intermediate_dropout_prob = 0.0
    input_shape = input_ids.shape
    assert input_shape.ndims == 2

    self._seq_dim = input_shape.dims[1]
    self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
    self._extra_losses = []
    mesh = input_ids.mesh

    if token_type_ids is None:
        token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            self.embedding_table = mtf.get_variable(
                mesh, "word_embeddings",
                mtf.Shape([self.vocab_dim, self.model_dim]),
                initializer=self.embedding_initializer)
            self.word_embedding_output = mtf.gather(
                self.embedding_table, input_ids, self.vocab_dim)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            self.embedding_output = self.word_embedding_output

            token_type_table = mtf.get_variable(
                mesh, "token_type_embeddings",
                mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
                initializer=self.embedding_initializer)
            if token_type_ids is not None:
                self.embedding_output += mtf.gather(
                    token_type_table, token_type_ids, self.token_type_vocab_dim)
            if self.config.position_signal == "embedding":
                full_position_table = mtf.get_variable(
                    mesh, "position_embeddings",
                    mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                short_position_table = mtf.rename_dimension(
                    mtf.slice(full_position_table, 0, self.seq_dim.size,
                              self.max_position_embeddings_dim.name),
                    self.max_position_embeddings_dim.name, self.seq_dim.name)
                self.embedding_output += short_position_table
            self.embedding_output = self.normalize(self.embedding_output)
            self.embedding_output = mtf.dropout(
                self.embedding_output,
                keep_prob=1.0 - self.config.layer_output_dropout_prob)

        with tf.variable_scope("encoder"):
            attention_biases = []
            if input_mask:
                # [batch_dim, memory_seq_dim]
                attention_biases.append(
                    (1.0 - mtf.to_float(mtf.replace_dimensions(
                        input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
            if self.config.position_signal == "relative_attention_bias":
                buckets_dim = mtf.Dimension("buckets", 32)
                rp_bucket = _relative_position_bucket(
                    mtf.range(mesh, self.memory_seq_dim, tf.int32)
                    - mtf.range(mesh, self.seq_dim, tf.int32),
                    num_buckets=buckets_dim.size)
                bias_var = mtf.get_variable(
                    mesh, "relative_attention_bias",
                    [self.num_heads_dim, buckets_dim],
                    initializer=tf.zeros_initializer())
                attention_biases.append(
                    mtf.gather(bias_var, rp_bucket, buckets_dim))

            attention_bias = mtf.add_n(attention_biases)

            prev_layer_output = self.embedding_output
            self.all_encoder_layers = []
            for block_num in range(self.config.num_blocks):
                with tf.variable_scope("block_%d" % block_num):
                    for layer_idx, layer_type in enumerate(self.config.block_layers):
                        layer_name = layer_type
                        count = self.config.block_layers[:layer_idx].count(layer_type)
                        if count:
                            layer_name += "_%d" % count
                        with tf.variable_scope(layer_name):
                            x = prev_layer_output
                            if self.config.residual_structure == "direct":
                                x = self.normalize(x)
                            if layer_type == "attention":
                                x = self.self_attention(x, attention_bias)
                            elif layer_type == "feedforward":
                                x = self.feedforward(x)
                            elif layer_type == "moe":
                                x = self.moe(x, layout, mesh_shape, input_mask,
                                             is_training)
                            else:
                                raise ValueError("unknown layer type " + layer_type)
                            x = mtf.dropout(
                                x,
                                keep_prob=1.0 - self.config.layer_output_dropout_prob)
                            layer_output = prev_layer_output + x
                            if self.config.residual_structure == "original":
                                layer_output = self.normalize(layer_output)
                            prev_layer_output = layer_output
                            self.all_encoder_layers.append(layer_output)

            self.sequence_output = prev_layer_output
            if self.config.residual_structure == "direct":
                self.sequence_output = self.normalize(self.sequence_output)

        # The "pooler" converts the encoded sequence tensor of shape
        # [batch_dim, seq_dim, hidden_size] to a tensor of shape
        # [batch_dim, hidden_size]. This is necessary for segment-level
        # (or segment-pair-level) classification tasks where we need a fixed
        # dimensional representation of the segment.
        with tf.variable_scope("pooler"):
            # We "pool" the model by simply taking the hidden state corresponding
            # to the first token. We assume that this has been pre-trained.
            first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
            self.pooled_output = mtf.layers.dense(
                first_token_tensor,
                reduced_dims=[self.model_dim],
                new_dims=[self.model_dim],
                activation=mtf.tanh,
                kernel_initializer=self.dense_initializer,
                use_bias=self.config.use_bias)
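# Toy numeric illustration (plain NumPy, added for clarity; not part of the model
# above) of the additive attention mask built in the encoder: padded positions get
# a large negative bias, so softmax gives them essentially zero weight.
import numpy as np

input_mask = np.array([1, 1, 1, 0, 0], dtype=np.float32)  # 1 = real token, 0 = padding
attention_bias = (1.0 - input_mask) * -10000.0            # [0, 0, 0, -1e4, -1e4]
logits = np.array([1.0, 2.0, 0.5, 3.0, 3.0]) + attention_bias
probs = np.exp(logits) / np.exp(logits).sum()             # padding positions get ~0 probability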
def transformer_moe_layer_v1(inputs, output_dim, hparams, train, variable_dtype,
                             layout=None, mesh_shape=None, nonpadding=None,
                             activation=mtf.relu):
    """Local mixture of experts that works well on TPU.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a hyperparameters
    dictionary in order not to complicate the interface in mtf_transformer.py .
    Once this code moves out of "research", we should pass the hyperparameters
    separately.

    Hyperparameters used:
      hparams.moe_num_experts: number of experts
      hparams.moe_hidden_size: size of hidden layer in each expert
      hparams.moe_group_size: size of each "group" for gating purposes
      hparams.moe_capacity_factor_train: a float
      hparams.moe_capacity_factor_eval: a float
      hparams.moe_gating: a string
      + all hyperparameters used by _top_2_gating()

    The number of parameters in the gating network is:
      (input_dim.size * hparams.num_experts) +

    The number of parameters in the experts themselves is:
      (hparams.num_experts
       * (input_dim.size + output_dim.size)
       * hparams.moe_hidden_size)

    The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
    of the representations of all positions in a batch of sequences.

    Each position of each sequence is sent to 0-2 experts.  The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding)
      that each sequence send the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      but due to our hacked-up gather-by-matmul implementation, we need to divide
      the batch into "groups".  For each group, the same number of elements
      are sent to each expert.

    TODO(noam): Factor this code better.  We want to be able to substitute
    different code for the experts themselves.

    Dimensions cheat sheet:
    B: batch dim(s)
    L: original sequence length
    M: input depth
    N: output depth
    G: number of groups
    S: group size
    E: number of experts
    C: expert capacity

    Args:
      inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
        and the same dtype as inputs, consisting of ones(nonpadding)
        and zeros(padding).
      activation: a function.

    Returns:
      outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
      loss: a mtf scalar

    Raises:
      ValueError: on unrecognized hparams.moe_gating
    """
    # pylint: disable=line-too-long
    #
    # O outer_batch dimension can be used for expert replication, e.g.
    # outer_batch=4 for placing 128 experts on 512 cores with 4 replicas of each
    # expert.
    #
    # E.g. 16x16 basic example:
    # moe_num_experts=512, num_groups=1024, batch=4096, length=256, d_model=1024
    # ---
    # Below ` indicates common way of splitting along mesh dimension.
    #
    # orig_inputs      OB`LM Tensor
    #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
    #                  v  (reshaped)
    # inputs           OG`SM
    #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
    #
    # combine_tensor,
    # dispatch_tensor  OG`SEC
    #                  Shape[outer_batch=1, batch=1024, group=1024, expert_unsplit=512, expert_capacity=4]
    #
    # (dispatched inputs)
    # expert_inputs    OEG`CM
    #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
    #                  v  (re-split via ReshapeOperation)
    #                  OE`GCM
    #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
    #
    # (hidden representation)
    # h                OE`GCH
    #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, expert_hidden=8192]
    #
    # expert_output    OE`GCM
    #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
    #                  v  (re-split via ReshapeOperation)
    #                  OEG`CM
    #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
    #
    # (combined expert_output)
    # output           OG`SM
    #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
    #                  v  (reshape)
    #                  OB`LM
    #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
    #
    # pylint: enable=line-too-long
    orig_inputs = inputs

    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims per replica).
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                                mesh_dim_size)

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)

    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh, batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])

    if hparams.moe_gating == "top_2":
        # combine_tensor,
        # dispatch_tensor  OG`SEC Tensors
        # (G is generally split along mesh dim)
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([outer_batch_dim, experts_dim_unsplit,
                                          num_groups_dim, expert_capacity_dim,
                                          input_dim]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([outer_batch_dim, experts_dim, batch_dim_unsplit,
                   expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = mtf.layers.dense_product(
        expert_inputs,
        reduced_dims=expert_inputs.shape.dims[-1:],
        new_dims=[hidden_dim],
        expert_dims=[experts_dim],
        activation_functions=activation,
        use_bias=False,
        variable_dtype=variable_dtype,
        name="wi")

    if train and hparams.moe_dropout_rate != 0.0:
        h = mtf.dropout(h, 1.0 - hparams.moe_dropout_rate)

    expert_output = mtf.layers.dense(
        h, output_dim,
        expert_dims=[experts_dim],
        use_bias=False,
        reduced_dims=h.shape.dims[-1:],
        variable_dtype=variable_dtype,
        name="wo")

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([outer_batch_dim, experts_dim_unsplit, num_groups_dim,
                   expert_capacity_dim, output_dim]))

    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])

    return output, loss * hparams.moe_loss_coef
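# Worked example (added) of the grouping/capacity arithmetic above, using the sizes
# from the "16x16 basic example" in the comments. The capacity factor of 2.0 is an
# assumed value for illustration only.
batch, length = 4096, 256
num_experts, group_size = 512, 1024
n = batch * length                        # 1,048,576 positions per replica
num_groups = n // group_size              # 1,024 groups
capacity_factor = 2.0                     # illustrative
expert_capacity = min(group_size,
                      int(group_size * capacity_factor / num_experts))
assert expert_capacity == 4               # matches expert_capacity=4 in the diagram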
def hybrid_attention(q, k, v, context, memory_length_dim, key_dim, value_dim,
                     bias=None, dropout_rate=0.0, dropout_broadcast_dims=None,
                     extra_logit=None):
    """Dot-product attention - doesn't use positional dimensions.

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      context: context of the attention layer.
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias

    query_length_dim = mtf.Dimension("length", memory_length_dim.size)
    doubly_coeff = mtf.get_variable(
        context.mesh, "doubly_coeff", [],
        initializer=tf.constant_initializer(0.5),
        dtype=context.variable_dtype)
    doubly_coeff = mtf.maximum(mtf.minimum(doubly_coeff, 1.), 0.)

    upper_weights = mtf.softmax(
        logits, memory_length_dim, extra_logit=extra_logit)

    lower_log_weights = mtf.log_softmax(
        logits, query_length_dim, extra_logit=extra_logit)
    doubly_weights = mtf.softmax(
        lower_log_weights, memory_length_dim, extra_logit=extra_logit)

    weights = doubly_coeff * doubly_weights + (1. - doubly_coeff) * upper_weights
    if dropout_rate != 0.0:
        weights = mtf.dropout(
            weights, 1.0 - dropout_rate,
            noise_shape=weights.shape - dropout_broadcast_dims)

    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
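# Plain-NumPy sketch (added) of the "hybrid" weighting above: a clamped scalar mixes
# the usual softmax over memory positions with a doubly-normalized variant that first
# log-softmaxes over query positions and then softmaxes over memory positions.
import numpy as np

def _softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

logits = np.random.randn(4, 4)                 # [query, memory]
upper = _softmax(logits, axis=1)               # standard attention weights
doubly = _softmax(np.log(_softmax(logits, axis=0)), axis=1)
coeff = min(max(0.5, 0.0), 1.0)                # doubly_coeff, clamped to [0, 1]
weights = coeff * doubly + (1.0 - coeff) * upper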
def attention(q, k, v, memory_length_dim, key_dim, value_dim,
              bias=None, dropout_rate=0.0, dropout_broadcast_dims=None,
              extra_logit=None, context=None, float32_logits=True):
    """Dot-product attention - doesn't use positional dimensions.

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor
      context: an optional Transformer.Context
      float32_logits: a boolean - if True, then compute logits in float32 to
        avoid numerical issues with bfloat16

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    orig_q_shape = q.shape
    q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
        context, q, k, v, bias, [key_dim, value_dim])
    if float32_logits:
        k = mtf.cast(k, tf.float32)
        q = mtf.cast(q, tf.float32)
    logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += mtf.cast(bias, logits.dtype)
    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    weights = mtf.cast(weights, v.dtype)
    if dropout_rate != 0.0:
        weights = mtf.dropout(
            weights, 1.0 - dropout_rate,
            noise_shape=weights.shape - dropout_broadcast_dims)
    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim)
    return outputs
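# Plain-NumPy reference sketch (added) of the same dot-product attention, with no
# sharding, casting, or dropout; it only clarifies the einsum/softmax structure of
# the function above.
import numpy as np

def numpy_attention(q, k, v, bias=None):
    # q: [..., length, d_k]; k: [..., memory_length, d_k]; v: [..., memory_length, d_v]
    logits = np.einsum("...qd,...md->...qm", q, k)
    if bias is not None:
        logits = logits + bias
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("...qm,...md->...qd", weights, v)  # [..., length, d_v]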
def synthetic_attention(q, k, v, memory_length_dim, key_dim, value_dim,
                        bias=None, dropout_rate=0.0, dropout_broadcast_dims=None,
                        extra_logit=None, synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16, max_length=512, context=None):
    """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor
      synthesize: flag to use synthetic attention or not
      synthesize_mode: which variant of synthesizer to use
      factorized_dim: factorized dim for synthesizers
      max_length: max length of input sequence
      context: context since we need context mode

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    if synthesize:
        num_heads = v.shape.get_dim_by_name("heads")
        tf.logging.info("Using synthesizer")
        if synthesize_mode == "random":
            tf.logging.info("Using Random Synthesizers")
            r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", max_length)])
            r = mtf.get_variable(context.mesh, "R", r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                r = mtf.gather(r, context.position,
                               r.shape.get_dim_by_name("length"))
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            logits = r
            r_shape = logits.shape
        elif synthesize_mode == "factorized":
            tf.logging.info("Using Factorized Random Synthesizers")
            k = factorized_dim
            r1_shape = mtf.Shape([mtf.Dimension("tmp", k),
                                  mtf.Dimension("heads", num_heads.size),
                                  mtf.Dimension("memory_length", 512)])
            r2_shape = mtf.Shape([mtf.Dimension("tmp", k),
                                  mtf.Dimension("heads", num_heads.size),
                                  mtf.Dimension("memory_length", 512)])
            r_shape = mtf.Shape([mtf.Dimension("length", 512),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", 512)])
            r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r = mtf.einsum([r1, r2], r_shape)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                r = mtf.gather(r, context.position,
                               r.shape.get_dim_by_name("length"))
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            logits = r
        elif synthesize_mode == "dense_minus":
            # Dense Synthesizer Model
            tmp_dim = mtf.Dimension("memory_length", max_length)
            logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                      use_bias=False,
                                      name="pi",
                                      reduced_dims=[key_dim],
                                      variable_dtype=None)
            logits = mtf.slice(logits, 0, memory_length_dim.size,
                               memory_length_dim.name)
            if context.mode == "incremental":
                pass
            else:
                length_dim = q.shape.get_dim_by_name("length")
                logits = mtf.slice(logits, 0, length_dim.size, "length")
        elif synthesize_mode == "random_plus_alpha" or \
                synthesize_mode == "random_plus":
            # Mixture Random Synthesizer with learnable Alpha
            tf.logging.info("Using Random Plus Alpha")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            num_heads = logits.shape.get_dim_by_name("heads")
            r_shape = mtf.Shape([mtf.Dimension("length", 512),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", 512)])
            r = mtf.get_variable(context.mesh, "R", r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                r = mtf.gather(r, context.position,
                               r.shape.get_dim_by_name("length"))
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, length_dim.name)
            if "alpha" in synthesize_mode:
                alpha = mtf.get_variable(context.mesh, "alpha",
                                         mtf.Shape([mtf.Dimension("alpha", 1)]),
                                         initializer=tf.zeros_initializer(),
                                         dtype=context.variable_dtype)
                alpha = mtf.sigmoid(alpha)
                logits = ((1 - alpha) * logits) + (alpha * r)
            else:
                logits = logits + r
        elif synthesize_mode == "dense_plus_alpha" or \
                synthesize_mode == "dense_plus":
            # Mixture Dense Synthesizer with learnable alpha
            tf.logging.info("Using Dense Plus Alpha Scaling")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            tmp_dim = mtf.Dimension("memory_length", 512)
            r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                 use_bias=False,
                                 name="pi",
                                 reduced_dims=[key_dim],
                                 variable_dtype=None)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                pass
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            if "alpha" in synthesize_mode:
                alpha = mtf.get_variable(context.mesh, "alpha",
                                         mtf.Shape([mtf.Dimension("alpha", 1)]),
                                         initializer=tf.zeros_initializer(),
                                         dtype=context.variable_dtype)
                alpha = mtf.sigmoid(alpha)
                logits = ((1 - alpha) * logits) + (alpha * r)
            else:
                logits = logits + r

    if bias is not None:
        logits += bias

    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)

    if dropout_rate != 0.0:
        weights = mtf.dropout(
            weights, 1.0 - dropout_rate,
            noise_shape=weights.shape - dropout_broadcast_dims)

    if synthesize and "plus" not in synthesize_mode:
        if synthesize_mode == "dense_minus":
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
        else:
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
    else:
        outputs_shape = q.shape - [key_dim] + value_dim

    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
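# Toy NumPy illustration (added) of the "random_plus_alpha" mixing above: a learned
# scalar alpha, squashed through a sigmoid, interpolates between content-based
# logits (q.k) and a learned, content-independent logit table R.
import numpy as np

qk_logits = np.random.randn(4, 4)        # content-based attention logits
r = np.random.randn(4, 4)                # random-synthesizer logits
alpha = 1.0 / (1.0 + np.exp(-0.0))       # sigmoid(0) = 0.5 at initialization
mixed_logits = (1.0 - alpha) * qk_logits + alpha * r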
def layer_prepostprocess_dropout(x):
    if is_incremental:
        return x
    return mtf.dropout(
        x,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow."""
    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(
        mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(
            mesh, "wpe",
            mtf.Shape([embed_sequence_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(
        mesh, "wte",
        mtf.Shape([vocab_dim, embd_dim]),
        initializer=tf.random_normal_initializer(stddev=0.02),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not \
            is_incremental_inference(context) else (context.position - 1)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"],
                                  name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(params["share_parameters"]) and \
            params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(
            block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim,
                            variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else \
            mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte],
                                output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss \
            else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch,
                                   mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
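# NumPy sketch (added) of the padding masking applied to the per-token loss above
# when the model is non-causal (masked-LM style): tokens equal to padding_id
# contribute zero loss. padding_id=0 mirrors the default used in the code.
import numpy as np

labels = np.array([5, 17, 0, 0])
loss_batch = np.array([1.2, 0.7, 0.9, 0.4])
masked = np.where(labels != 0, loss_batch, 0.0)   # [1.2, 0.7, 0.0, 0.0]
mean_loss = masked.mean()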
def fn(x):
    with tf.variable_scope(scope):
        nx = x.shape[-1]  # Grab last dimension from input

        if use_rezero:
            prenorm = identity
        elif use_scale_norm:
            prenorm = scale_norm
        else:
            prenorm = layer_norm

        pre_residual_fn = rezero if use_rezero else identity

        attention_type = params["attention_types"][layer_num]

        if macaron_attention:
            mult = 0.5
            mlp_fn = mlp_glu if use_mlp_glu else mlp
            intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
            # Define intermediate layer of mlp - to split
            dim_intermediate_expanded = mtf.Dimension("intermediate_expanded",
                                                      intermediate_size)
            m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded,
                       variable_dtype=variable_dtype, params=params)
            x = x + (m * mult)
        else:
            mult = 1

        if attention_type != "none":
            res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params)
            a = attn(res_x, "attn", nx, attention_type=attention_type,
                     params=params, bias=bias, dim_seq=sequence_dim,
                     memory_length_dim=memory_length_dim,
                     variable_dtype=variable_dtype, context=context)
        else:
            a = x

        x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)

        res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params)

        if use_moe:
            moe_params = mtf.transformer.moe.HParams()
            mtf.transformer.moe.set_default_moe_hparams(moe_params)
            moe_params.add_hparam("moe_min_expert_capacity", 1)
            moe_params.add_hparam("moe_use_experts_attention", False)

            # Override defaults
            for k, v in params["moe_params"].items():
                moe_params.add_hparam(k, v)

            moe_train = params["mode"] == "train"

            m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(
                res_x, x.shape[-1], moe_params,
                train=moe_train,
                mesh_shape=params["mesh_shape"],
                layout=params["layout"],
                activation=params.get("moe_activation", "relu"),
                variable_dtype=variable_dtype,
                num_microbatches=params["num_microbatches"])
            m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout")
        else:
            mlp_fn = mlp_glu if use_mlp_glu else mlp
            intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
            # Define intermediate layer of mlp - to split
            dim_intermediate_expanded = mtf.Dimension("intermediate_expanded",
                                                      intermediate_size)
            m = mlp_fn(res_x, "mlp", dim_intermediate_expanded,
                       variable_dtype=variable_dtype, params=params)
            aux_loss = mtf.zeros(x.mesh, mtf.Shape([]),
                                 dtype=variable_dtype.slice_dtype)

        x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype)
        return x, aux_loss
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq,
         memory_length_dim, variable_dtype, context=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head",
                               params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype)
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            one_hot = mtf.one_hot(context.position - 1, dim_seq,
                                  dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we
                # don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)
                if is_incremental_inference(context):
                    q *= one_hot
                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )
                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)
            elif attention_type == "global":
                # TODO: pass in fake context
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(
                            bias,
                            [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built
                        # that masks out all key/values that are greater than the
                        # current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(
                            bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch,
                                             dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if \
                    params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate)
            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] \
                    else linear_attention
                a = linear_attn_fn(q, k, v)
            else:
                raise NotImplementedError(
                    "Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(
                x.mesh, "o_b", [dim_embd],
                initializer=tf.constant_initializer(0),
                master_dtype=variable_dtype.master_dtype,
                slice_dtype=variable_dtype.slice_dtype,
                activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
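# NumPy sketch (added) of the incremental-inference KV update used above: a one-hot
# vector at the current position writes the freshly computed key/value into the
# cached tensors while leaving all other positions untouched.
import numpy as np

seq_len, d = 4, 2
old_k = np.zeros((seq_len, d))                 # cached keys
new_k = np.ones((seq_len, d))                  # new key, broadcast over positions
position = 2
one_hot = np.eye(seq_len)[position][:, None]   # [seq_len, 1]
k = old_k * (1.0 - one_hot) + new_k * one_hot  # only row 2 is overwritten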
def _rand_1_gating(
        inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
        hparams, train, variable_dtype, importance=None, name="rand_1_gating",
        num_microbatches=None):
    """Compute a random top-1 gating."""
    # SELECT EXPERT
    if train:
        policy = hparams.moe_rand_1_policy_train
    else:
        policy = hparams.moe_rand_1_policy_eval

    # The internals of this function run in float32.
    # bfloat16 seems to reduce quality.
    gate_inputs = mtf.to_float(inputs)

    # Input perturbations
    if train and policy == "input_dropout":
        gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
    elif train and policy == "input_jitter":
        gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                       hparams.moe_rand_1_jitter)

    gate_logits = mtf.layers.dense(
        gate_inputs,
        experts_dim,
        use_bias=False,
        expert_dims=outer_expert_dims,
        variable_dtype=variable_dtype,
        name=name)
    raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

    if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
        expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
    elif policy == "sample":
        expert_index = mtf.sample_with_temperature(
            gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
        expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
    else:
        raise ValueError("Unknown rand_1 policy %s" % policy)

    expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

    # LOAD BALANCING LOSS
    # TODO(liamfedus): Check entropy loss.
    group_size_dim = inputs.shape[-2]
    density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
    density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
    if importance is not None:
        expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
        density_1_proxy *= mtf.cast(
            mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    loss = (
        mtf.reduce_mean(density_1_proxy * density_1) *
        float(experts_dim.size * experts_dim.size))

    if num_microbatches and num_microbatches > 1:
        tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
            num_microbatches))
        loss /= num_microbatches

    # Logging
    if train:
        entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                                 reduced_dim=experts_dim)
        batch_entropy = mtf.reduce_mean(entropy)
        mtf.scalar_summary(name + "/entropy", batch_entropy)

        mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
        total_routed = mtf.reduce_sum(mask_count_experts)
        expert_fraction = mtf.to_float(mask_count_experts / total_routed)
        split_fractions = mtf.split(
            expert_fraction,
            split_dim=experts_dim,
            num_or_size_splits=experts_dim.size)
        for fraction in split_fractions:
            mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                               mtf.reduce_mean(fraction))
        mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

    # COMPUTE ASSIGNMENT TO EXPERT
    # Experts have a limited capacity, ensure we do not exceed it. Construct
    # the batch indices, to each expert, with position_in_expert
    position_in_expert = mtf.cumsum(
        expert_mask, group_size_dim, exclusive=True) * expert_mask
    position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)

    # Keep only tokens that fit within expert_capacity.
    expert_capacity_float = float(expert_capacity_dim.size)
    expert_mask *= mtf.cast(
        mtf.less(position_in_expert, expert_capacity_float),
        dtype=raw_gates.dtype)
    expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

    # Mask out the experts that have overflowed expert capacity. Sparsify the
    # expert_gate.
    expert_gate *= expert_mask_flat

    combine_tensor = (
        expert_gate * expert_mask_flat *
        mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
        mtf.one_hot(
            mtf.to_int32(position_in_expert),
            expert_capacity_dim,
            dtype=raw_gates.dtype))

    # Match the inputs dtype.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)
    dispatch_tensor = mtf.cast(
        mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
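# NumPy sketch (added) of the capacity bookkeeping in the top-1 gating above: tokens
# routed to the same expert are numbered by an exclusive cumulative sum, and any
# token whose position exceeds expert_capacity is dropped (its mask row is zeroed).
import numpy as np

expert_capacity = 2
# One-hot expert assignment for 5 tokens over 3 experts.
expert_mask = np.array([[1, 0, 0],
                        [1, 0, 0],
                        [1, 0, 0],   # third token routed to expert 0 -> overflows
                        [0, 1, 0],
                        [0, 0, 1]], dtype=np.float32)
position_in_expert = (np.cumsum(expert_mask, axis=0) - expert_mask) * expert_mask
expert_mask *= (position_in_expert < expert_capacity).astype(np.float32)
# Row 2 is now all zeros: the overflowing token is not dispatched to any expert.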
def attention(self, x, n_state, mask, attention_type="global", name="attn"):
    # x :: [batch, seq, n_embd]
    batch_dim, seq_dim, embd_dim = x_shape = x.shape

    assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"

    with tf.variable_scope(name):
        # Compute attention inputs
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=self.dimensions["embed_dim"],
            kv_dim=self.dimensions["kv_dim"],
            heads_dim=self.dimensions["heads_dim"],
            variable_dtype=self.variable_dtype)
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if self.is_incremental_inference:
            one_hot = mtf.one_hot(self.context.position - 1, seq_dim,
                                  dtype=self.variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = self.context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(self.context):
            self.context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we
                # don't need mask_attn_weights.
                radius = self.params.get("local_attention_radius", 256)
                if self.is_incremental_inference:
                    q *= one_hot
                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=self.dimensions["kv_dim"],
                    value_dim=self.dimensions["kv_dim"],
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=True,
                    attention_kwargs={},
                )
                if self.is_incremental_inference:
                    a = mtf.gather(a, self.context.position - 1, seq_dim)
            elif attention_type == "global":
                if exists(mask):
                    if not self.is_incremental_inference:
                        broadcasted_mask = mtf.broadcast(
                            mask,
                            [batch_dim, self.dimensions["heads_dim"],
                             mask.shape[-2], mask.shape[-1]])  # TODO: not sure this is correct
                    else:
                        # In the incremental case, a custom mask needs to be built
                        # that masks out all key/values that are greater than the
                        # current position
                        mask = mtf.gather(mask, self.context.position - 1, seq_dim)
                        broadcasted_mask = mtf.broadcast(
                            mask,
                            [batch_dim, self.dimensions["heads_dim"],
                             mask.shape[-1]])

                k = mtf.replace_dimensions(
                    k, k.shape[1], self.dimensions["memory_len_dim"])
                v = mtf.replace_dimensions(
                    v, v.shape[1], self.dimensions["memory_len_dim"])

                attn_dropout_rate = self.params.get(
                    "attention_dropout", 0) if self.mode == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=self.dimensions["memory_len_dim"],
                    key_dim=self.dimensions["kv_dim"],
                    value_dim=self.dimensions["kv_dim"],
                    bias=broadcasted_mask,
                    dropout_rate=attn_dropout_rate)
            else:
                raise NotImplementedError(
                    "Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(
                x.mesh, "o_b", [embd_dim],
                initializer=tf.constant_initializer(0),
                master_dtype=self.variable_dtype.master_dtype,
                slice_dtype=self.variable_dtype.slice_dtype,
                activation_dtype=self.variable_dtype.activation_dtype)
            a += b

        residual_dropout = self.params.get("residual_dropout", 0)
        if self.mode == "train" and residual_dropout > 0:
            a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")
        return a
def transformer_moe_layer_v1(
        inputs, output_dim, hparams, train, variable_dtype,
        layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
        num_microbatches=None, token_embeddings=None, context=None):
    """Local heterogeneous mixture of experts.

    See transformer_moe_layer_v1 in moe.py for a more detailed explanation for
    a generic moe layer.

    The heterogeneous mask outputted by generate_heterogeneous_expert_masks has
    dimension [maximum hidden size, maximum # layers, # experts] and its shape
    will overwrite the parameters moe_num_layers and moe_hidden_size in hparams.
    The layer-specific mask slice is applied at each expert layer to the
    activation which is [expert width, # experts]. If the
    heterogeneous_mask_info is None, there is no mask applied and the code is
    equivalent to the homogeneous case.

    The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
    of the representations of all positions in a batch of sequences.

    Each position of each sequence is sent to 0-2 experts.  The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Dimensions cheat sheet:
    B: batch dim(s)
    L: original sequence length
    M: input depth
    N: output depth
    G: number of groups
    S: group size
    E: number of experts
    C: expert capacity

    Args:
      inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
        and the same dtype as inputs, consisting of ones(nonpadding)
        and zeros(padding).
      activation: a function.
      num_microbatches: number of microbatches.
      token_embeddings: a mtf.Tensor with shape
        [batch_dim(s), length_dim, input_dim]. These are the word embeddings
        that correspond to the inputs. These can optionally be used to make
        routing decisions.
      context: a Context.

    Returns:
      outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
      loss: a mtf scalar

    Raises:
      ValueError: on unrecognized hparams.moe_gating
    """
    orig_inputs = inputs

    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    if hparams.moe_heterogeneous_mask_info is not None:
        tf.logging.info("moe_heterogeneous_mask_info: {}".format(
            hparams.moe_heterogeneous_mask_info))
        heterogeneous_mask = generate_heterogeneous_expert_masks(
            hparams.moe_heterogeneous_mask_info,
            hparams.moe_num_experts,
            experts_dim,
            mesh=inputs.mesh,
            expert_width=hparams.moe_hidden_size)
        # overwrite depth and width with the mask maximum dimension
        hparams.moe_num_layers = heterogeneous_mask.shape[1].size
        hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims per replica).
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size
    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = moe._split_into_groups(  # pylint: disable=protected-access
        n, hparams.moe_group_size, mesh_dim_size)
    # TODO(barretzoph): implementation without pylint calls?

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Token embeddings that can be optionally used in the router for determining
    # where to send tokens.
    if hparams.moe_word_embed_mode is not None:
        token_embeddings = mtf.cast(
            mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
    tf.logging.info("expert_capacity: %d" % expert_capacity)
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh, batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])

    if hparams.moe_gating == "top_2":
        # combine_tensor,
        # dispatch_tensor  OG`SEC Tensors
        # (G is generally split along mesh dim)
        dispatch_tensor, combine_tensor, loss = moe._top_2_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "top_n":
        dispatch_tensor, combine_tensor, loss = moe._top_n_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "switch":
        dispatch_tensor, combine_tensor, loss = moe._switch_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "ntlb":
        dispatch_tensor, combine_tensor, loss = moe._ntlb_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "switch_max":
        dispatch_tensor, combine_tensor, loss = moe._switch_max_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "expert_selection":
        dispatch_tensor, combine_tensor, loss = moe._expert_selection_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            group_size_dim=group_size_dim,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            name="expert_selection_gating",
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([outer_batch_dim, experts_dim_unsplit,
                                          num_groups_dim, expert_capacity_dim,
                                          input_dim]))

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([outer_batch_dim, experts_dim, batch_dim_unsplit,
                   expert_capacity_dim, d_model_split_dim]))

    # Split over batch -> split over experts
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([outer_batch_dim, experts_dim, batch_dim_unsplit,
                   expert_capacity_dim, input_dim]))

    # Pretend we have heterogeneous_mask with shape [moe_num_layers, num_experts]
    for layer_idx in range(hparams.moe_num_layers):
        with tf.variable_scope("expert_layer_{}".format(layer_idx)):
            res_h = 0.0
            if layer_idx > 0:
                res_h = expert_inputs
                expert_inputs = transformer.sublayer_rms_norm(
                    expert_inputs, None, context)

            # Now feed the expert inputs through the experts.
            h = mtf.layers.dense_product(
                expert_inputs,
                reduced_dims=expert_inputs.shape.dims[-1:],
                new_dims=[hidden_dim],
                expert_dims=[experts_dim],
                activation_functions=activation,
                use_bias=False,
                variable_dtype=variable_dtype,
                name="wi")

            # apply dropout
            if hparams.moe_dropout_rate != 0.0:
                h = mtf.dropout(h, is_training=train,
                                keep_prob=1.0 - hparams.moe_dropout_rate)
            # only if heterogeneous
            if hparams.moe_heterogeneous_mask_info is not None:
                # Get mask for current layer by slicing heterogeneous mask
                heterogeneous_mask_slice = mtf.slice(
                    heterogeneous_mask, layer_idx, 1, "num_expert_layers")

                # Get rid of the expert layers dimension.
                heterogeneous_mask_slice = mtf.reshape(
                    heterogeneous_mask_slice,
                    [heterogeneous_mask_slice.shape[0],
                     heterogeneous_mask_slice.shape[-1]])
                h *= mtf.cast(heterogeneous_mask_slice, h.dtype)

            expert_output = mtf.layers.dense(
                h, output_dim,
                expert_dims=[experts_dim],
                use_bias=False,
                reduced_dims=h.shape.dims[-1:],
                variable_dtype=variable_dtype,
                name="wo")

            if layer_idx < (hparams.moe_num_layers - 1):
                expert_output = transformer.sublayer_dropout(
                    expert_output, None, context)
            expert_output += res_h
            expert_inputs = expert_output

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([outer_batch_dim, experts_dim_unsplit, num_groups_dim,
                   expert_capacity_dim, d_model_split_dim]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([outer_batch_dim, experts_dim_unsplit, num_groups_dim,
                   expert_capacity_dim, output_dim]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]

    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output, loss * hparams.moe_loss_coef
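# Rough NumPy sketch (added) of how the per-layer heterogeneous mask is applied
# above. The axis layout here is illustrative and does not reproduce the exact mtf
# dimension order: the full mask is [max_hidden, num_expert_layers, num_experts],
# each expert layer takes its own slice, and the slice zeroes hidden units that the
# layer's experts do not use.
import numpy as np

max_hidden, num_layers, num_experts, tokens = 8, 2, 4, 3
heterogeneous_mask = np.random.randint(0, 2, size=(max_hidden, num_layers, num_experts))
layer_idx = 0
mask_slice = heterogeneous_mask[:, layer_idx, :]   # [max_hidden, num_experts]
h = np.random.randn(num_experts, tokens, max_hidden)
h = h * mask_slice.T[:, None, :]                   # broadcast the mask over tokens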
def self_attention(self, x, attention_bias):
    """Performs multi-headed self-attention with output projection.

    Args:
      x: output of previous layer
      attention_bias: optional float32 Tensor broadcastable to shape
        x.shape - self.model_dim + self.memory_seq_dim
        to be added to attention logits.
        This may be used to mask out padding regions of the memory.

    Returns:
      float Tensor with the same shape as x
    """
    queries = mtf.layers.dense(
        x,
        reduced_dims=[self.model_dim],
        new_dims=[self.num_heads_dim, self.size_per_head_dim],
        kernel_initializer=self.dense_initializer,
        name="query",
        use_bias=self.config.use_bias)
    keys = mtf.layers.dense(
        mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
        reduced_dims=[self.model_dim],
        new_dims=[self.num_heads_dim, self.size_per_head_dim],
        kernel_initializer=self.dense_initializer,
        name="key",
        use_bias=self.config.use_bias)
    values = mtf.layers.dense(
        mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
        reduced_dims=[self.model_dim],
        new_dims=[self.num_heads_dim, self.size_per_head_dim],
        kernel_initializer=self.dense_initializer,
        name="value",
        use_bias=self.config.use_bias)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = mtf.einsum(
        [queries, keys], reduced_dims=[self.size_per_head_dim])
    attention_scores *= self.size_per_head_dim.size ** -0.5

    if attention_bias is not None:
        attention_scores += attention_bias

    # Normalize the attention scores to probabilities.
    attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = mtf.dropout(
        attention_probs,
        keep_prob=1.0 - self.config.attention_probs_dropout_prob)

    output = mtf.einsum([attention_probs, values],
                        reduced_dims=[self.memory_seq_dim])

    # linear transformation back to shape of query_antecedent
    output = mtf.layers.dense(
        output,
        reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
        new_dims=[self.model_dim],
        kernel_initializer=self.dense_initializer,
        name="output",
        use_bias=self.config.use_bias)
    output = mtf.transpose(output, x.shape)
    return output
def add_dropout(self, x, dropout_prob=0.0):
    return mtf.dropout(x, keep_prob=1.0 - dropout_prob)
def Alexnet(img, labels, num_nodes, num_gpus, args):
    num_classes = 1000
    keep_prob = 0.5
    learning_rate = 0.01
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)
    RenameFC = lambda x: mt.rename_dimension(x, x.shape[-1].name,
                                             utils.RandName())

    strategy = args.strategy
    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)
    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)
    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)
    elif strategy == 3:
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img, GetFilterShape(mtf_img, (11, 11, 3, 96)),
                          (4, 4), 'VALID', activation=mtf.relu, name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1, GetFilterShape(pool1, (5, 5, 96, 256)),
                          (1, 1), 'SAME', activation=mtf.relu, name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2, GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME', activation=mtf.relu, name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3, GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME', activation=mtf.relu, name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4, GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME', activation=mtf.relu, name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename dims
        if strategy == 1:
            k_dim = mtf.Dimension(utils.RandName(),
                                  utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(
                pool5, meshes[1], (utils.RandName(), 'axis0'))
        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())
        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            #dim_names = pool5.shape.rename_dimension('axis0', utils.RandName())
            #pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1], dim_names)
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout
        fc_activation = lambda x: mtf.dropout(mtf.relu(x), keep_prob)
        fc6 = mtf.layers.dense(pool5, fc6_units, activation=fc_activation,
                               reduced_dims=pool5.shape[1:], name='fc6')
        if strategy == 2:
            fc6 = RenameFC(fc6)
        elif strategy == 3:
            fc6 = RenameFC(fc6)

        fc7 = mtf.layers.dense(fc6, fc7_units, activation=fc_activation,
                               reduced_dims=fc6.shape.dims[-1:], name='fc7')
        if strategy == 2:
            fc7 = RenameFC(fc7)
        elif strategy == 3:
            fc7 = RenameFC(fc7)

        fc8 = mtf.layers.dense(fc7, fc8_units,
                               reduced_dims=fc7.shape.dims[-1:], name='fc8')
        fc8 = mtf.dropout(fc8, keep_prob)

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
def model(self, mesh, x, y, params):
    # x :: [batch, io, vocab]
    if params["precision"] == "bfloat16":
        dtype = tf.bfloat16
        # master has type float32, slice and activation have type bfloat16
        variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16, tf.bfloat16)
    else:
        dtype = tf.float32
        # master, slice and activation all have type float32
        variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    # Build the actual model
    batch_dim = mtf.Dimension("batch", params["batch_size"])
    vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
    io_dim = mtf.Dimension("sequence", params["io"])
    io_chan_dim = mtf.Dimension("io", params["io_channels"])

    # from input to mtf
    x = mtf.import_tf_tensor(mesh, x, mtf.Shape([batch_dim, io_dim, vocab_dim]))

    # Embeddings
    with tf.variable_scope("toy", default_name="seq2seq"):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            embedding_table = mtf.get_variable(
                mesh, "word_embeddings",
                mtf.Shape([vocab_dim, io_chan_dim]),
                initializer=self.embedding_initializer,
            )

            word_embedding_output = mtf.gather(
                embedding_table, x, dim=vocab_dim, output_shape=io_chan_dim)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            embedding_output = word_embedding_output

            pos_embedding = mtf.get_variable(
                mesh, "pos_embeddings",
                mtf.Shape([io_dim, io_chan_dim]),
                initializer=self.embedding_initializer,
            )
            embedding_output = self.normalize(embedding_output)
            embedding_output = mtf.dropout(
                embedding_output,
                keep_prob=1.0 - self.config.layer_output_dropout_prob,
            )

    # shift token by pos embeddings
    x = word_embedding_output + pos_embedding
    x = mtf.cast(x, variable_dtype.activation_dtype)

    h = x
    for lnum in range(1, self.num_hidden_layers + 2):
        if lnum + 1 == self.num_hidden_layers + 2:
            # output layer
            dim = io_dim
        elif lnum % 2 == 0:
            dim = mtf.Dimension("hidden_even", io_chan_dim)
        else:
            dim = mtf.Dimension("hidden_odd", io_chan_dim)

        h = mtf.layers.dense(
            h, dim,
            use_bias=False,
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            name="layer_%d" % lnum,
        )

    prediction = h  # project back to token dimensions

    # compute the mean square loss between the input and the output
    loss = mtf.reduce_mean(mtf.square(y - prediction))
    return prediction, loss