Example #1
def post_language_model_processing(
    lm_output,
    labels,
    logit_weights,
    get_key_value,
    parallel_output,
    forward_method_parallel_output,
    fp16_lm_cross_entropy,
    return_logits=False,
):
    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(
                output.float(), labels)

        if return_logits:
            return loss, output
        else:
            return loss
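
The loss branch above follows a pattern that recurs in every example on this page: cross-entropy is run directly on fp16 logits only when fp16_lm_cross_entropy is enabled, otherwise the logits are upcast to fp32 first, and the result is left un-reduced (one loss value per token). Below is a minimal sketch of that pattern using plain torch.nn.functional.cross_entropy as a stand-in; it is not the tensor-parallel tensor_parallel.vocab_parallel_cross_entropy (which shards the vocabulary dimension across ranks), and the helper name per_token_cross_entropy is illustrative only.

# Stand-in sketch for the fp16/fp32 cross-entropy branch above.
# Uses plain PyTorch; the real code uses a vocab-parallel primitive.
import torch
import torch.nn.functional as F

def per_token_cross_entropy(logits, labels, fp16_lm_cross_entropy=False):
    # logits: [batch, seq_len, vocab_size]; labels: [batch, seq_len]
    if fp16_lm_cross_entropy:
        # The fp16 path assumes a device/dtype where half-precision softmax is supported.
        assert logits.dtype == torch.half
        loss_input = logits
    else:
        loss_input = logits.float()  # upcast for numerical stability
    # reduction='none' keeps one loss value per token, matching the
    # "non collapsed" return value described in the snippets below.
    return F.cross_entropy(
        loss_input.transpose(1, 2),  # [batch, vocab_size, seq_len]
        labels,
        reduction="none",
    )                                # [batch, seq_len]

# Usage with random data:
logits = torch.randn(2, 8, 100)
labels = torch.randint(0, 100, (2, 8))
loss = per_token_cross_entropy(logits, labels)  # shape [2, 8]
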
Example #2
    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
        tokentype_ids=None,
        lm_labels=None,
        enc_hidden_states=None,
        output_enc_hidden_only=False,
    ):

        # Converting the attention masks to proper parameter settings
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

        encoder_position_ids = t5_position_ids(encoder_input_ids)

        # Handle case when decoder_input_ids is None to get just the encoder hidden states.
        decoder_position_ids = t5_position_ids(
            decoder_input_ids) if decoder_input_ids is not None else None

        lm_output = self.language_model(
            enc_input_ids=encoder_input_ids,
            enc_position_ids=encoder_position_ids,
            enc_attn_mask=encoder_attn_mask,
            dec_input_ids=decoder_input_ids,
            dec_position_ids=decoder_position_ids,
            dec_attn_mask=decoder_attn_mask,
            enc_dec_attn_mask=encoder_decoder_attn_mask,
            tokentype_ids=tokentype_ids,
            enc_hidden_states=enc_hidden_states,
            output_enc_hidden_only=output_enc_hidden_only,
        )

        if output_enc_hidden_only:
            return lm_output

        decoder_output, encoder_output = lm_output

        # Output.
        lm_logits = self.lm_head(
            decoder_output,
            self.language_model.embedding.word_embeddings.weight)

        if lm_labels is None:
            return lm_logits, encoder_output
        else:
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(
                    lm_logits, lm_labels)
            else:
                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(
                    lm_logits.float(), lm_labels)
            return lm_loss, encoder_output
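
The implementation of t5_position_ids is not shown on this page. A typical version builds the index sequence [0 .. seq_len-1] and broadcasts it over the batch, which is the same torch.arange / expand pattern used for retrieved_position_ids in Example #5 below. The sketch here (with the hypothetical name build_position_ids_like) mirrors that pattern and is an assumption, not the exact helper.

# Hedged sketch of a position-id helper in the style of t5_position_ids /
# build_position_ids, mirroring the arange/expand pattern from Example #5.
import torch

def build_position_ids_like(token_ids):
    # token_ids: [batch, seq_len] -> position ids of the same shape
    seq_length = token_ids.size(1)
    position_ids = torch.arange(
        seq_length, dtype=torch.long, device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids).clone()

token_ids = torch.zeros(2, 5, dtype=torch.long)
print(build_position_ids_like(token_ids))
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4]])
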
Example #3
    def forward(
        self,
        input_ids,
        input_attn_mask,
        retrieved_emb,
        retrieved_attn_mask,
        token_type_ids=None,
        labels=None,
        input_emb=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)
        """
        if input_emb is None:
            if self.pre_process and self.add_encoder:
                # encoder embeddings
                input_position_ids = build_position_ids(input_ids)
                input_emb = self.encoder_embedding(
                    input_ids,
                    input_position_ids,
                    token_type_ids=token_type_ids)
            else:
                input_emb = None

        if self.add_decoder:
            hidden = self.pre_decoder(input_emb, input_attn_mask)

        if self.add_encoder:
            retrieved_emb = self.encoder(retrieved_emb,
                                         retrieved_attn_mask,
                                         context_attn_mask=input_attn_mask,
                                         encoder_output=hidden)

        if self.add_decoder:
            dec_output = self.post_decoder(
                hidden,
                input_attn_mask,
                retrieved_attn_mask=retrieved_attn_mask,
                retrieved_emb=retrieved_emb)
            token_logits = self.tokens_head(dec_output,
                                            self.word_embeddings_weight())

            if labels is not None:
                # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and returns log p(x_i|z) per token i
                if self.fp16_cross_entropy:
                    assert token_logits.dtype == torch.half
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(
                        token_logits, labels)
                else:
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(
                        token_logits.float(), labels)
                return tokens_loss
            else:
                return token_logits
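
In this example (and the following ones) the output head is called as tokens_head(dec_output, self.word_embeddings_weight()), i.e. the word-embedding matrix is passed in as the projection weight, which suggests a weight-tied projection onto the vocabulary. The sketch below shows that projection with plain F.linear; the real tensor-parallel head additionally splits the vocabulary dimension across ranks, which is not reproduced here, and tied_logits is an illustrative name.

# Sketch of a weight-tied output projection: hidden states times the
# transposed word-embedding matrix, giving per-token vocabulary logits.
import torch
import torch.nn.functional as F

def tied_logits(hidden, embedding_weight):
    # hidden: [batch, seq_len, hidden_size]
    # embedding_weight: [vocab_size, hidden_size]
    return F.linear(hidden, embedding_weight)  # [batch, seq_len, vocab_size]

hidden = torch.randn(2, 8, 16)
embedding_weight = torch.randn(100, 16)
print(tied_logits(hidden, embedding_weight).shape)  # torch.Size([2, 8, 100])
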
Example #4
def post_language_model_processing(
    lm_output, pooled_output, lm_head, binary_head, lm_labels, logit_weights, fp16_lm_cross_entropy
):
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
        return lm_loss, binary_logits
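
binary_head is passed into this function and is not defined in the snippet. In BERT-style models it is typically a small linear classifier applied to the pooled sequence representation (e.g. for next-sentence prediction). The minimal version below is an assumption about its shape, not the actual head used here.

# Hypothetical minimal binary head: a linear classifier over pooled_output.
import torch
from torch import nn

hidden_size = 16
binary_head = nn.Linear(hidden_size, 2)      # two classes, e.g. NSP yes/no

pooled_output = torch.randn(4, hidden_size)  # [batch, hidden_size]
binary_logits = binary_head(pooled_output)   # [batch, 2]
print(binary_logits.shape)                   # torch.Size([4, 2])
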
Example #5
    def forward(
        self,
        input_ids,
        input_attn_mask,
        retrieved_ids,
        retrieved_attn_mask,
        token_type_ids=None,
        labels=None,
        input_emb=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        neighbors=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)
        """
        eod_positions = None
        retrieved_emb = None
        if input_ids is not None and self.eod_id is not None:
            eod_positions = torch.where(input_ids == self.eod_id)

        if input_emb is None:
            if self.pre_process and self.add_encoder:
                # encoder embeddings
                if self.add_abs_position_embedding:
                    input_position_ids = build_position_ids(input_ids)
                else:
                    input_position_ids = None
                input_emb = self.encoder_embedding(input_ids, input_position_ids, token_type_ids=token_type_ids)
            else:
                input_emb = None

        if retrieved_ids is not None:
            if self.add_abs_position_embedding:
                seq_length = retrieved_ids.size(-1)
                retrieved_position_ids = torch.arange(seq_length, dtype=torch.long, device=retrieved_ids.device)
                retrieved_position_ids = retrieved_position_ids.unsqueeze(0).expand_as(retrieved_ids).clone()
            else:
                retrieved_position_ids = None
            retrieved_emb = self.encoder_embedding(retrieved_ids, retrieved_position_ids)

        if self.add_decoder:
            hidden = self.pre_decoder(
                input_emb,
                input_attn_mask,
                eod_positions=eod_positions,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )
            # hidden is a tuple, (layernorm_input, layernorm_output)
            self.post_decoder.set_input_tensor(hidden)
            encoder_input = hidden[1].transpose(0, 1).contiguous()

        if self.add_encoder:
            if retrieved_emb is not None and neighbors is None:
                neighbors = retrieved_emb.shape[2]
            retrieved_emb = self.encoder(
                retrieved_emb,
                retrieved_attn_mask,
                context_attn_mask=input_attn_mask,
                encoder_output=encoder_input,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
                neighbors=neighbors,
            )

        if self.add_decoder:
            dec_output = self.post_decoder(
                hidden,
                input_attn_mask,
                retrieved_attn_mask=retrieved_attn_mask,
                retrieved_emb=retrieved_emb,
                eod_positions=eod_positions,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )
            token_logits = self.tokens_head(dec_output, self.word_embeddings_weight())

            if labels is not None:
                # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and returns log p(x_i|z) per token i
                if self.fp16_cross_entropy:
                    assert token_logits.dtype == torch.half
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(token_logits, labels)
                else:
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(token_logits.float(), labels)
                return tokens_loss
            else:
                return token_logits
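
The line encoder_input = hidden[1].transpose(0, 1).contiguous() above takes the pre-decoder's layernorm output and swaps its first two dimensions, which is the usual conversion between a sequence-first [seq_len, batch, hidden] layout and a batch-first [batch, seq_len, hidden] layout. Which direction applies depends on the surrounding model code, so the snippet below is only a shape demonstration with illustrative sizes.

# Shape-only illustration of the transpose/contiguous step above.
import torch

seq_len, batch, hidden_size = 8, 2, 16
layernorm_output = torch.randn(seq_len, batch, hidden_size)
encoder_input = layernorm_output.transpose(0, 1).contiguous()
print(encoder_input.shape)  # torch.Size([2, 8, 16])
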
Example #6
    def forward(
        self,
        enc_input_ids,
        enc_attn_mask,
        dec_input_ids,
        dec_attn_mask,
        token_type_ids=None,
        labels=None,
        enc_hidden_states=None,
        enc_output_mask=None,
        output_enc_hidden_only=False,
        enc_input=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)
        """
        if enc_input is None:
            if self.pre_process and self.add_encoder:
                # encoder embeddings
                enc_position_ids = build_position_ids(enc_input_ids)
                enc_input = self.encoder_embedding(
                    enc_input_ids,
                    enc_position_ids,
                    token_type_ids=token_type_ids)
            else:
                enc_input = None

        if output_enc_hidden_only:
            enc_output = self.enc_dec_model.encode(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
            )
            return enc_output
        else:
            if self.pre_process and self.add_decoder:
                dec_position_ids = build_position_ids(dec_input_ids)
                dec_input = self.decoder_embedding(
                    dec_input_ids,
                    dec_position_ids,
                    token_type_ids=token_type_ids)
            else:
                # Note: This is when the decoder itself is split across PP ranks.
                dec_input = None

            output = self.enc_dec_model(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                dec_input=dec_input,
                dec_attn_mask=dec_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
                enc_output=None,
                dec_layer_past=None,
                dec_get_key_value=False,
            )

            if self.post_process and self.add_decoder:
                dec_output, enc_output = output
                # project decoder output to vocabulary-size dimensions
                token_logits = self.tokens_head(dec_output,
                                                self.word_embeddings_weight())

                if labels is not None:
                    # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and returns log p(x_i|z) per token i
                    if self.fp16_cross_entropy:
                        assert token_logits.dtype == torch.half
                        tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(
                            token_logits, labels)
                    else:
                        tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(
                            token_logits.float(), labels)
                    return tokens_loss
                else:
                    return token_logits

            elif self.add_decoder and not self.add_encoder:
                decoder_output, _ = output
                return decoder_output
            else:
                encoder_output = output
                return encoder_output
Example #7
    def forward(
        self,
        enc_input_ids,
        enc_attn_mask,
        dec_input_ids,
        dec_attn_mask,
        tokentype_ids=None,
        labels=None,
        enc_hidden_states=None,
        enc_output_mask=None,
        output_enc_hidden_only=False,
        enc_input=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)
        """
        ret_dict = {}

        # encoder embeddings
        if enc_input is None:
            enc_position_ids = build_position_ids(enc_input_ids)
            enc_input = self.encoder_embedding(enc_input_ids,
                                               enc_position_ids,
                                               tokentype_ids=tokentype_ids)

        if output_enc_hidden_only:
            enc_output, enc_output_mask = self.enc_dec_model.encode(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
            )
            ret_dict["enc_output"] = enc_output
            ret_dict["enc_output_mask"] = enc_output_mask
        else:
            dec_position_ids = build_position_ids(dec_input_ids)
            dec_input = self.decoder_embedding(dec_input_ids,
                                               dec_position_ids,
                                               tokentype_ids=tokentype_ids)

            ret_dict.update(
                self.enc_dec_model(
                    enc_input=enc_input,
                    enc_attn_mask=enc_attn_mask,
                    dec_input=dec_input,
                    dec_attn_mask=dec_attn_mask,
                    enc_layer_past=None,
                    enc_get_key_value=False,
                    enc_output=enc_hidden_states,
                    enc_output_mask=enc_output_mask,
                    dec_layer_past=None,
                    dec_get_key_value=False,
                ))

            # project decoder output to vocabulary-size dimensions
            token_logits = self.tokens_head(
                ret_dict["dec_output"],
                self.decoder_embedding.word_embeddings.weight)
            # token_logits [batch, length, vocab_size]
            ret_dict["token_logits"] = token_logits

            if labels is not None:
                # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and returns log p(x_i|z) per token i
                if self.fp16_cross_entropy:
                    assert token_logits.dtype == torch.half
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(
                        token_logits, labels)
                else:
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(
                        token_logits.float(), labels)

                # tokens_loss [batch, length]
                ret_dict["tokens_loss"] = tokens_loss

        return ret_dict
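
As the docstrings note, these forward methods return a per-token loss ("tokens_loss" has shape [batch, length]) rather than a scalar. A common caller-side pattern is to collapse it with a loss mask so that padding tokens do not contribute; that reduction is an assumption about how the return value is consumed and is not part of the snippets on this page.

# Hedged sketch: masked-mean reduction of a per-token loss to a scalar.
import torch

tokens_loss = torch.rand(2, 8)   # [batch, length], per-token loss values
loss_mask = torch.ones(2, 8)     # 1 for tokens that count, 0 for padding
loss_mask[:, -2:] = 0            # e.g. ignore trailing padding positions

scalar_loss = (tokens_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
print(scalar_loss)
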