Example #1
# Imports assumed from Megatron-LM's layout (exact module paths vary
# between releases).
import torch

from megatron.core import tensor_parallel
from megatron.model.language_model import parallel_lm_logits


def post_language_model_processing(
    lm_output,
    labels,
    logit_weights,
    get_key_value,
    parallel_output,
    forward_method_parallel_output,
    fp16_lm_cross_entropy,
    return_logits=False,
):
    # When KV-cache inference is enabled, lm_output is a
    # (hidden_states, presents) tuple; unpack the cached key/values.
    if get_key_value:
        lm_output, presents = lm_output

    # Project hidden states to vocabulary logits, letting the caller
    # override whether the logits stay sharded across tensor-parallel ranks.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if get_key_value:
        output = [output, presents]

    # Inference path: no labels, so return the logits (and any presents).
    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            # Compute the loss directly in half precision; this requires
            # the logits to already be fp16.
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            # Default path: upcast logits to fp32 for a numerically
            # stable cross entropy.
            loss = tensor_parallel.vocab_parallel_cross_entropy(
                output.float(), labels)

        if return_logits:
            return loss, output
        else:
            return loss
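
For orientation, here is a minimal, hypothetical call site. The names `model`, `lm_output`, and `labels`, and the `word_embeddings_weight()` accessor, are illustrative assumptions rather than code from the examples; in Megatron-LM this helper is typically invoked at the end of a GPT model's forward pass.

# Hypothetical call site (names are illustrative assumptions): a
# GPT-style forward pass that turns transformer output into a loss.
loss = post_language_model_processing(
    lm_output,                        # [s, b, h] hidden states from the transformer
    labels,                           # target token ids, or None at inference time
    model.word_embeddings_weight(),   # tied embedding matrix (assumed accessor)
    get_key_value=False,              # True only for KV-cache inference
    parallel_output=True,             # keep logits sharded across tensor-parallel ranks
    forward_method_parallel_output=None,
    fp16_lm_cross_entropy=False,      # upcast to fp32 for the loss
)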
Example #2
def forward(self, hidden_states, word_embeddings_weight):
    # BERT-style LM head: transform the hidden states before the
    # vocabulary projection.
    hidden_states = self.dense(hidden_states)
    hidden_states = self.gelu(hidden_states)
    hidden_states = self.layernorm(hidden_states)
    # Project onto the tied word embedding matrix to get vocab logits.
    output = parallel_lm_logits(hidden_states,
                                word_embeddings_weight,
                                self.parallel_output,
                                bias=self.bias)
    return output
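
This is the BERT-style language-model head: the dense + GELU + LayerNorm transform lets the head specialize for token prediction while still projecting onto the shared (tied) word embedding matrix. Example #3 below omits that transform and projects directly, which is the GPT-style head.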
Example #3
def forward(self, hidden_states, word_embeddings_weight):
    # GPT-style head: project hidden states straight onto the tied
    # word embedding matrix, with no intermediate transform.
    output = parallel_lm_logits(hidden_states,
                                word_embeddings_weight,
                                self.parallel_output,
                                bias=self.bias)
    return output
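
All three examples delegate to `parallel_lm_logits`. For readers without the Megatron-LM source at hand, the sketch below mirrors the classic form of that helper (newer releases route the matmul through a fused autograd function with asynchronous gradient communication); treat it as an illustrative sketch, not the canonical code.

import torch.nn.functional as F
from megatron.core import tensor_parallel


def parallel_lm_logits_sketch(input_, word_embeddings_weight, parallel_output,
                              bias=None):
    # Copy the input to every tensor-parallel rank (identity in the
    # forward pass, all-reduce of gradients in the backward pass).
    input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
    # Each rank multiplies by its shard of the tied embedding matrix,
    # producing logits for its slice of the vocabulary.
    logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Either keep the vocab-sharded logits (cheaper, and what
    # vocab_parallel_cross_entropy expects) or gather the full vocab dim.
    if parallel_output:
        return logits_parallel
    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)

Keeping `parallel_output=True` avoids materializing the full vocabulary dimension on every rank, which is why Example #1 pairs it with `vocab_parallel_cross_entropy` rather than gathering the logits first.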