# Imports assumed from the surrounding module (Megatron-LM style):
import torch.nn.functional as F

from megatron import get_args, mpu


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None, weight_tying=True):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = mpu.copy_to_model_parallel_region(input_)

    if weight_tying:
        # Matrix multiply against the (vocab-parallel) word embedding weights.
        if bias is None:
            logits_parallel = F.linear(input_parallel, word_embeddings_weight)
        else:
            logits_parallel = F.linear(input_parallel, word_embeddings_weight,
                                       bias)
    else:
        # Untied output projection: use a separate parallel linear layer.
        args = get_args()
        logits_fn = mpu.RowParallelLinear(args.hidden_size,
                                          args.padded_vocab_size,
                                          bias=False,
                                          input_is_parallel=True,
                                          skip_bias_add=False)
        # RowParallelLinear returns (output, bias); the bias is None here.
        logits_parallel, _ = logits_fn(input_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel
    return mpu.gather_from_model_parallel_region(logits_parallel)
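# Illustrative usage sketch (an assumption, not code from the repo): producing
# logits from the final transformer hidden states with tied embedding weights.
# `lm_output`, `word_embeddings`, and `parallel_output` are hypothetical names.
logits = parallel_lm_logits(
    lm_output,                  # [b, s, h] hidden states from the last layer
    word_embeddings.weight,     # vocab-parallel embedding matrix (tied weights)
    parallel_output,            # True: keep logits sharded for a parallel loss
    bias=None,
    weight_tying=True)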
def forward(self, layernorm_input, attention_mask, presents=None,
            layer_past=None, get_key_value=False):
    # hidden_states: [b, s, h]

    # Collect the scattered result from the fused dropout.
    if self.scattered_attn_output:
        layernorm_input = mpu.gather_from_model_parallel_region(
            layernorm_input)
        # Attention output/bias are not used again, so no need to gather them.

    # Layer norm post the self attention.
    layernorm_output = self.post_attention_layernorm(layernorm_input)

    # MLP.
    mlp_output, mlp_bias = self.mlp(layernorm_output)

    # Second residual connection.
    if self.apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = layernorm_input

    # JIT scripting for a nn.Module (with dropout) is not triggering the
    # fusion kernel. For now, we use two different nn.functional routines to
    # account for varying dropout semantics during training and inference.
    if self.bias_dropout_fusion:
        if self.training:
            bias_dropout_add_func = bias_dropout_add_fused_train
        else:
            bias_dropout_add_func = bias_dropout_add_fused_inference
    else:
        bias_dropout_add_func = get_bias_dropout_add(self.training)

    # Re-enable torch grad to enable fused optimization.
    with torch.enable_grad():
        output = bias_dropout_add_func(
            mlp_output,
            mlp_bias.expand_as(residual),
            residual,
            self.hidden_dropout)

    if get_key_value:
        output = [output, presents]
    return output
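# For context, a minimal sketch of the bias-dropout-add helpers referenced above,
# following the usual Megatron-LM pattern; the repo's own definitions may differ
# in details (e.g. type-comment style JIT annotations).
import torch


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor,
                     prob: float, training: bool) -> torch.Tensor:
    # dropout(x + bias) followed by the residual add.
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    # Non-fused fallback: closes over the training flag.
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)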
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = mpu.copy_to_model_parallel_region(input_)

    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)

    # Gather if needed.
    if parallel_output:
        return logits_parallel
    return mpu.gather_from_model_parallel_region(logits_parallel)
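# For reference, a rough single-node sketch of the two mpu collectives used above,
# assuming the usual Megatron-LM semantics (identity/all-reduce for copy,
# all-gather/split for gather); the real implementations operate over the
# model-parallel process group rather than the default group.
import torch
import torch.distributed as dist


class _CopyToModelParallelRegion(torch.autograd.Function):
    """Identity in the forward pass; all-reduce of the gradient in backward."""

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output)
        return grad_output


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """All-gather along the last (vocab-parallel) dim forward; split it backward."""

    @staticmethod
    def forward(ctx, input_):
        world_size = dist.get_world_size()
        tensors = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(tensors, input_)
        return torch.cat(tensors, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        return grad_output.chunk(world_size, dim=-1)[rank].contiguous()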