Example #1
0
 def forward(self, hidden_states, word_embeddings_weight):
     """Transform hidden states and score them against the word embeddings.

     The input passes through ``self.dense``, ``self.gelu`` and
     ``self.layernorm`` (in that order). The result is then handed to
     ``parallel_lm_logits`` together with ``word_embeddings_weight``, this
     module's ``parallel_output`` flag and ``self.bias``.

     Args:
         hidden_states: activations fed to ``self.dense`` (shape not visible
             here — TODO confirm against the caller).
         word_embeddings_weight: weight passed straight through to
             ``parallel_lm_logits``; presumably the word-embedding matrix
             used as a tied output projection — verify.

     Returns:
         Whatever ``parallel_lm_logits`` returns for these arguments.
     """
     transformed = self.layernorm(self.gelu(self.dense(hidden_states)))
     logits = parallel_lm_logits(
         transformed,
         word_embeddings_weight,
         self.parallel_output,
         bias=self.bias,
     )
     return logits
Example #2
0
    def forward(self,
                input_ids,
                position_ids,
                attention_mask,
                labels=None,
                tokentype_ids=None,
                layer_past=None,
                get_key_value=False,
                forward_method_parallel_output=None):
        """Run the language model and return logits or a cross-entropy loss.

        Args:
            input_ids: passed positionally to ``self.language_model``.
            position_ids: passed positionally to ``self.language_model``.
            attention_mask: passed positionally to ``self.language_model``.
            labels: optional targets. When given, the result of
                ``mpu.vocab_parallel_cross_entropy`` is returned instead of
                the logits.
            tokentype_ids: forwarded to ``self.language_model``.
            layer_past: forwarded to ``self.language_model``.
            get_key_value: when True, ``self.language_model`` returns an
                ``(lm_output, presents)`` pair and ``presents`` is returned
                alongside the main result.
            forward_method_parallel_output: per-call override for
                ``self.parallel_output`` on the weight-tied logits path.

        Returns:
            * ``labels is None``: the logits, or ``[logits, presents]`` when
              ``get_key_value`` is True.
            * otherwise: the loss, or ``[loss, presents]`` when
              ``get_key_value`` is True.
        """
        # Language model.
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        attention_mask,
                                        tokentype_ids=tokentype_ids,
                                        layer_past=layer_past,
                                        get_key_value=get_key_value)

        presents = None
        if get_key_value:
            lm_output, presents = lm_output

        # A per-call override takes precedence over the module default.
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output

        if self.weight_tying:
            logits = parallel_lm_logits(
                lm_output,
                self.language_model.embedding.word_embeddings.weight,
                parallel_output)
        else:
            # final_linear yields a pair; only the first element is used.
            logits, _ = self.final_linear(lm_output)

        if labels is None:
            result = logits
        elif self.fp16_lm_cross_entropy:
            assert logits.dtype == torch.half
            result = mpu.vocab_parallel_cross_entropy(logits, labels)
        else:
            result = mpu.vocab_parallel_cross_entropy(logits.float(), labels)

        # The loss has to be computed on the bare logits tensor, so the
        # key/value cache is attached only after that; wrapping first would
        # make ``.dtype`` / ``.float()`` hit a list when labels are given.
        if get_key_value:
            return [result, presents]
        return result
Example #3
0
 def _logits_helper(embedding, lm_output):
     """Adapt pipeline-stage inputs/outputs into a ``parallel_lm_logits`` call.

     ``self`` is not a parameter: it is resolved from the enclosing scope
     (presumably the owning model's constructor — verify). Its
     ``parallel_output`` flag is passed through unchanged.
     """
     embedding_weight = embedding.word_embeddings_weight
     return parallel_lm_logits(lm_output, embedding_weight, self.parallel_output)
Example #4
0
 def forward(self, hidden_states, word_embeddings_weight):
     """Score ``hidden_states`` against ``word_embeddings_weight``.

     Thin wrapper over ``parallel_lm_logits`` that supplies this module's
     ``parallel_output`` flag and ``self.bias``.

     Returns:
         Whatever ``parallel_lm_logits`` returns for these arguments.
     """
     return parallel_lm_logits(
         hidden_states,
         word_embeddings_weight,
         self.parallel_output,
         bias=self.bias,
     )