Example #1
0
    def forward(self,
                bert_model_input,
                attention_mask,
                tokentype_ids=None,
                lm_labels=None):
        """Run the BERT language model and, optionally, its output heads.

        Args:
            bert_model_input: token ids fed to ``self.language_model``.
            attention_mask: mask that is passed through
                ``bert_extended_attention_mask`` before use.
            tokentype_ids: optional token-type ids forwarded to the language
                model.
            lm_labels: optional labels forwarded to
                ``post_language_model_processing``.

        Returns:
            The result of ``post_language_model_processing`` when
            ``self.post_process`` is set; otherwise the raw language-model
            output, unchanged.
        """
        input_ids = bert_model_input
        extended_mask = bert_extended_attention_mask(attention_mask)
        position_ids = build_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_mask,
                                        tokentype_ids=tokentype_ids)

        # Without post-processing the language-model output is returned as-is.
        if not self.post_process:
            return lm_output

        # With a binary head the language model yields a (sequence, pooled) pair.
        pooled_output = None
        if self.add_binary_head:
            lm_output, pooled_output = lm_output

        return post_language_model_processing(
            lm_output,
            pooled_output,
            self.lm_head,
            self.binary_head,
            lm_labels,
            self.word_embeddings_weight(),
            self.fp16_lm_cross_entropy,
        )
Example #2
0
    def forward(
        self,
        input_ids,
        input_attn_mask,
        retrieved_emb,
        retrieved_attn_mask,
        token_type_ids=None,
        labels=None,
        input_emb=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)

        Flow: ``pre_decoder`` over the input embeddings, then ``encoder`` over
        ``retrieved_emb`` (given the pre-decoder output as ``encoder_output``),
        then ``post_decoder`` and ``tokens_head``.

        Returns:
            Vocab-parallel cross-entropy when ``labels`` is given, token
            logits otherwise, and ``None`` if ``self.add_decoder`` is false
            (see the NOTE in the body for the ``add_encoder``-only case).
        """
        # Compute input embeddings only when none were supplied and both
        # self.pre_process and self.add_encoder are set.
        if input_emb is None and self.pre_process and self.add_encoder:
            input_position_ids = build_position_ids(input_ids)
            input_emb = self.encoder_embedding(
                input_ids,
                input_position_ids,
                token_type_ids=token_type_ids)

        if self.add_decoder:
            hidden = self.pre_decoder(input_emb, input_attn_mask)

        # NOTE(review): if add_encoder is set but add_decoder is not, `hidden`
        # is never bound and this call raises UnboundLocalError — confirm that
        # configuration is unsupported.
        if self.add_encoder:
            retrieved_emb = self.encoder(retrieved_emb,
                                         retrieved_attn_mask,
                                         context_attn_mask=input_attn_mask,
                                         encoder_output=hidden)

        if not self.add_decoder:
            return None

        dec_output = self.post_decoder(
            hidden,
            input_attn_mask,
            retrieved_attn_mask=retrieved_attn_mask,
            retrieved_emb=retrieved_emb)
        token_logits = self.tokens_head(dec_output,
                                        self.word_embeddings_weight())
        if labels is None:
            return token_logits

        # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i
        if self.fp16_cross_entropy:
            assert token_logits.dtype == torch.half
            loss_input = token_logits
        else:
            loss_input = token_logits.float()
        return tensor_parallel.vocab_parallel_cross_entropy(loss_input, labels)
Example #3
0
    def get_loss(self, batch):
        """Compute and log the training loss for one p-tuning batch.

        The encoder input is ``self.embed_input(...)`` plus position
        embeddings. It is passed to ``self.model.enc_dec_model`` as
        ``enc_input`` (``enc_input_ids`` is ``None``). When
        ``self.float_type`` is not ``torch.float32`` the model call runs
        under CUDA autocast with that dtype.

        Returns:
            Tuple ``(loss, tokens_enc, labels, enc_mask, encoder_input)``.
        """
        (tokens_enc, tokens_dec, loss_mask, labels,
         enc_mask, dec_mask, enc_taskname) = self.process_batch(batch)

        input_embeds = self.embed_input(tokens_enc, enc_taskname)
        position_embeddings = self.position_embeddings(
            build_position_ids(tokens_enc))
        encoder_input = input_embeds + position_embeddings

        # Identical keyword arguments for both precision paths.
        model_kwargs = dict(
            enc_input_ids=None,
            enc_attn_mask=enc_mask,
            dec_input_ids=tokens_dec,
            dec_attn_mask=dec_mask,
            token_type_ids=None,
            labels=labels,
            enc_hidden_states=None,
            output_enc_hidden_only=False,
            enc_input=encoder_input,
        )
        if self.float_type == torch.float32:
            tokens_loss = self.model.enc_dec_model(**model_kwargs)
        else:
            with torch.autocast(device_type="cuda", dtype=self.float_type):
                tokens_loss = self.model.enc_dec_model(**model_kwargs)

        loss = self.model.loss_func(loss_mask, tokens_loss)
        self.log('train_loss', loss)
        return loss, tokens_enc, labels, enc_mask, encoder_input
Example #4
0
    def collate_fn(self, batch):
        """Prepares input_ids, labels, loss mask, attention_mask, and position ids for global batch

        Args:
            batch: sequence of ``(taskname_ids, input_ids, answer_starts)``
                samples.

        Returns:
            ``(input_ids, labels, loss_mask, position_ids, attention_mask,
            taskname_ids)``. ``labels`` is the tensor returned by
            ``pad_batch_and_build_loss_mask`` shifted left by one token.
            ``input_ids`` is the same tensor with its last column dropped.
            ``attention_mask`` is a boolean ``(batch, 1, seq, seq)`` tensor
            that is True strictly above the diagonal.
        """
        taskname_ids, input_ids, answer_starts = zip(*batch)

        prompt_source = self.virtual_prompt_source
        if prompt_source == "prompt-encoder":
            # The prompt encoder needs equal-length taskname ids: right-pad
            # each entry with pad_token_id up to the longest one.
            longest_taskname = max(map(len, taskname_ids))
            taskname_ids = torch.tensor([
                ids + [self.pad_token_id] * (longest_taskname - len(ids))
                for ids in taskname_ids
            ])
        elif prompt_source == "prompt-table":
            # Task ids are only used to look up embeddings in the prompt table.
            taskname_ids = torch.tensor(taskname_ids)

        # Longest sequence in the batch determines the padded length.
        padded_len = max(map(len, input_ids))
        padded_ids, padded_loss_mask = self.pad_batch_and_build_loss_mask(
            input_ids, padded_len, answer_starts)

        # Next-token targets: the label at position t is the input token at
        # t + 1; the loss mask is shifted the same way to line up with labels.
        labels = padded_ids[:, 1:].contiguous()
        input_ids = padded_ids[:, :-1].contiguous()
        loss_mask = padded_loss_mask[:, 1:].contiguous()
        seq_len = padded_len - 1

        # Causal (lower-triangular) mask over the whole input with a broadcast
        # dimension at index 1; `< 0.5` converts it to bool, True where the
        # float mask is 0.
        num_rows = len(input_ids)
        causal = torch.tril(torch.ones((num_rows, seq_len, seq_len)))
        attention_mask = causal.unsqueeze(1) < 0.5

        position_ids = build_position_ids(input_ids)
        return input_ids, labels, loss_mask, position_ids, attention_mask, taskname_ids
    def get_loss(self, batch):
        """Compute the loss for one GPT p-tuning batch.

        Reads ``enc_input``, ``enc_taskname``, ``labels``, ``loss_mask``,
        ``enc_query`` and ``input_attn_mask`` from ``batch``, in that order.
        ``enc_query`` is fetched but not used. The model is called with
        ``encoder_input`` (prompt embeddings plus position embeddings) rather
        than token ids. It runs under CUDA autocast when ``self.float_type``
        is not ``torch.float32``.

        Returns:
            ``self.loss_func(loss_mask, output_tensor)``, where
            ``output_tensor`` is the first element of the model output.
        """
        (enc_input, enc_taskname, labels, loss_mask, _enc_query,
         input_attn_mask) = (batch[key] for key in (
             'enc_input', 'enc_taskname', 'labels', 'loss_mask',
             'enc_query', 'input_attn_mask'))

        # Add a dimension at index 1 and turn the mask into bools
        # (True where the mask value is below 0.5).
        input_attn_mask = input_attn_mask.unsqueeze(1) < 0.5

        input_embeds = self.embed_input(enc_input, enc_taskname)
        position_ids = build_position_ids(enc_input)
        position_embeddings = self.model.model.language_model.embedding.position_embeddings(
            position_ids)
        encoder_input = input_embeds + position_embeddings

        def run_model():
            # Token ids / position ids are None: the precomputed embeddings go in via encoder_input.
            return self.model.model(
                None,
                None,
                encoder_input=encoder_input,
                attention_mask=input_attn_mask,
                labels=labels,
            )

        if self.float_type == torch.float32:
            output = run_model()
        else:
            with torch.autocast(device_type="cuda", dtype=self.float_type):
                output = run_model()

        output_tensor, _encoder_hidden_states = output
        return self.loss_func(loss_mask, output_tensor)
Example #6
0
    def ptune_inference(self,
                        queries: List[Dict],
                        batch_size: int = 1,
                        decode_token_len: int = None) -> List[str]:
        """
        Get prediction for the queries
        Args:
            queries: List of data samples without labels
            batch_size: batch size to use during inference
            decode_token_len: max number of tokens to generate during inference;
                falls back to ``self.decoder_seq_length`` when ``None``
        Returns:
            all_preds: model predictions, one decoded string per generated
                sequence, in dataloader order. Each sequence is cut at the
                first ``tokenizer.eos_id``, stripped of special-token ids and
                converted with ``tokenizer.ids_to_text``.

        The model's training mode and the logging verbosity are restored on
        exit, even if inference raises.
        """
        if decode_token_len is None:
            decode_token_len = self.decoder_seq_length
        # store predictions for all queries in a single list
        all_preds = []
        # Capture the state to restore *before* entering `try`. If it were
        # captured inside, an early failure would make `finally` raise
        # UnboundLocalError and hide the real exception.
        mode = self.training
        logging_level = logging.get_verbosity()
        try:
            # Switch model to evaluation mode
            self.eval()
            logging.set_verbosity(logging.WARNING)
            dataloader_cfg = {
                "batch_size": batch_size,
                "num_workers": 3,
                "pin_memory": False
            }
            infer_datalayer = self._setup_infer_dataloader(
                dataloader_cfg, queries, decode_token_len)
            for batch in infer_datalayer:
                tokens_enc = batch['text_enc'].to(self.device)
                enc_taskname = batch['enc_taskname'].to(self.device)
                enc_mask = batch['enc_mask'].to(self.device)

                input_embeds = self.embed_input(tokens_enc, enc_taskname)

                encoder_position_ids = build_position_ids(tokens_enc)

                position_embeddings = self.position_embeddings(
                    encoder_position_ids)

                encoder_input = input_embeds + position_embeddings

                if self.float_type == torch.float32:
                    predicted_token_ids, _ = self.model.decode(
                        tokens_enc=tokens_enc,
                        enc_mask=enc_mask,
                        num_tokens_to_generate=decode_token_len,
                        enc_input=encoder_input,
                    )
                else:
                    with torch.autocast(device_type="cuda",
                                        dtype=self.float_type):
                        predicted_token_ids, _ = self.model.decode(
                            tokens_enc=tokens_enc,
                            enc_mask=enc_mask,
                            num_tokens_to_generate=decode_token_len,
                            enc_input=encoder_input,
                        )

                preds = predicted_token_ids.cpu().numpy().tolist()
                # Build the special-id lookup once per batch instead of
                # re-scanning `.values()` for every generated token.
                special_ids = set(self.tokenizer.special_token_to_id.values())
                eos_id = self.tokenizer.eos_id
                for pred in preds:
                    if eos_id in pred:
                        pred = pred[:pred.index(eos_id)]
                    pred = [tok for tok in pred if tok not in special_ids]
                    all_preds.append(self.tokenizer.ids_to_text(pred))
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            logging.set_verbosity(logging_level)
        return all_preds
    def decode(self, enc_query, enc_taskname, label_position,
               num_tokens_to_generate):
        """Greedily generate ``num_tokens_to_generate`` tokens without gradients.

        Each step re-runs the full model over the current sequence, takes the
        argmax of the log-softmax at every position, and writes the
        prediction from each row's current label position into the next slot.

        Args:
            enc_query: starting token ids, indexed as (row, position) below.
                A new tensor is built on every step with ``torch.cat``, so
                ``enc_query`` itself is not modified.
            enc_taskname: task ids forwarded to ``self.embed_input``.
            label_position: tensor whose column 0 gives, per row, the position
                whose next-token prediction is read first.
            num_tokens_to_generate: number of decoding steps.

        Returns:
            ``(predicted_tokens_dec, log_probs)``. ``log_probs`` holds the
            per-position max log-probabilities from the last step. It is
            ``None`` when no step ran (``num_tokens_to_generate <= 0``); the
            original code raised UnboundLocalError in that case.
        """
        with torch.no_grad():
            predicted_tokens_dec = enc_query

            label_start = label_position[:, 0].clone()
            log_probs = None

            for _ in range(num_tokens_to_generate):
                # Pad-based mask multiplied by the history mask, then turned
                # into bools (True where the float mask is 0) with a dimension
                # added at index 1.
                attn_mask = make_attention_mask_3d(predicted_tokens_dec,
                                                   predicted_tokens_dec,
                                                   self.pad_token_id)
                attn_mask = attn_mask * make_history_mask_3d(
                    predicted_tokens_dec)
                attn_mask = (attn_mask < 0.5).unsqueeze(1)

                input_embeds = self.embed_input(predicted_tokens_dec,
                                                enc_taskname)

                encoder_position_ids = build_position_ids(predicted_tokens_dec)
                position_embeddings = self.model.model.language_model.embedding.position_embeddings(
                    encoder_position_ids)

                encoder_input = input_embeds + position_embeddings

                if self.float_type == torch.float32:
                    output_tensor = self.model.model(
                        None,
                        None,
                        encoder_input=encoder_input,
                        attention_mask=attn_mask,
                    )
                else:
                    with torch.autocast(device_type="cuda",
                                        dtype=self.float_type):
                        output_tensor = self.model.model(
                            None,
                            None,
                            encoder_input=encoder_input,
                            attention_mask=attn_mask,
                        )

                output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(
                    output_tensor)

                # TODO, add logic to use the allowed labels if it is defined
                log_probs, token_ids = torch.max(
                    nn.functional.log_softmax(output_tensor, dim=-1), dim=-1)

                # Append a pad column so every row has room to grow, then
                # write each row's prediction at its next label position.
                new_pred = torch.full_like(token_ids[:, 0:1],
                                           self.pad_token_id)
                predicted_tokens_dec = torch.cat(
                    [predicted_tokens_dec, new_pred], 1)

                predicted = torch.gather(token_ids, 1, label_start.view(-1, 1))

                # need to scatter the token id at the right position
                label_start += 1
                predicted_tokens_dec.scatter_(1, label_start.view(-1, 1),
                                              predicted)

        return predicted_tokens_dec, log_probs
Example #8
0
    def forward(
        self,
        input_ids,
        input_attn_mask,
        retrieved_ids,
        retrieved_attn_mask,
        token_type_ids=None,
        labels=None,
        input_emb=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        neighbors=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)

        Runs ``pre_decoder`` on the input embeddings, then ``encoder`` on the
        retrieved-neighbor embeddings (with the pre-decoder output as
        ``encoder_output``), then ``post_decoder`` and ``tokens_head``.

        Args:
            input_ids: input token ids. When ``self.eod_id`` is set they are
                also searched for ``self.eod_id`` to build ``eod_positions``.
            input_attn_mask: attention mask for the input tokens.
            retrieved_ids: token ids of retrieved neighbors. When not
                ``None`` they are embedded with ``self.encoder_embedding``.
            retrieved_attn_mask: attention mask for the retrieved tokens.
            token_type_ids: optional token-type ids for the input embedding.
            labels: when given, vocab-parallel cross-entropy is returned
                instead of logits.
            input_emb: precomputed input embeddings; when given, the input
                embedding lookup is skipped.
            set_inference_key_value_memory, inference_max_sequence_len:
                forwarded unchanged to ``pre_decoder``, ``encoder`` and
                ``post_decoder`` (``encoder`` only receives them when
                ``self.add_encoder`` is set).
            neighbors: number of neighbors passed to ``encoder``. Defaults to
                ``retrieved_emb.shape[2]`` when retrieved embeddings exist.

        Returns:
            Per-token loss when ``labels`` is given, otherwise token logits.
            Falls through and returns ``None`` when ``self.add_decoder`` is
            false. NOTE(review): if ``self.add_encoder`` is true while
            ``self.add_decoder`` is false, ``encoder_input`` is never assigned
            and the encoder call raises UnboundLocalError. Confirm that
            configuration is unsupported.
        """
        eod_positions = None
        retrieved_emb = None
        if input_ids is not None and self.eod_id is not None:
            # Indices of end-of-document tokens; forwarded to pre/post decoder.
            eod_positions = torch.where(input_ids == self.eod_id)

        if input_emb is None:
            if self.pre_process and self.add_encoder:
                # encoder embeddings
                if self.add_abs_position_embedding:
                    input_position_ids = build_position_ids(input_ids)
                else:
                    input_position_ids = None
                input_emb = self.encoder_embedding(input_ids, input_position_ids, token_type_ids=token_type_ids)
            else:
                input_emb = None

        if retrieved_ids is not None:
            if self.add_abs_position_embedding:
                # Positions 0..seq_length-1 along the last dim, expanded to the
                # shape of retrieved_ids; clone() materializes the expanded view.
                seq_length = retrieved_ids.size(-1)
                retrieved_position_ids = torch.arange(seq_length, dtype=torch.long, device=retrieved_ids.device)
                retrieved_position_ids = retrieved_position_ids.unsqueeze(0).expand_as(retrieved_ids).clone()
            else:
                retrieved_position_ids = None
            retrieved_emb = self.encoder_embedding(retrieved_ids, retrieved_position_ids)

        if self.add_decoder:
            hidden = self.pre_decoder(
                input_emb,
                input_attn_mask,
                eod_positions=eod_positions,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )
            # hidden is a tuple, (layernorm_input, layernorm_output)
            # Hand `hidden` to post_decoder via set_input_tensor (presumably
            # for pipeline-parallel input plumbing — TODO confirm).
            self.post_decoder.set_input_tensor(hidden)
            # The encoder conditions on layernorm_output with its first two
            # dims swapped (made contiguous for downstream ops).
            encoder_input = hidden[1].transpose(0, 1).contiguous()

        if self.add_encoder:
            if retrieved_emb is not None and neighbors is None:
                # Infer the neighbor count from dim 2 of the retrieved embeddings.
                neighbors = retrieved_emb.shape[2]
            retrieved_emb = self.encoder(
                retrieved_emb,
                retrieved_attn_mask,
                context_attn_mask=input_attn_mask,
                encoder_output=encoder_input,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
                neighbors=neighbors,
            )

        if self.add_decoder:
            dec_output = self.post_decoder(
                hidden,
                input_attn_mask,
                retrieved_attn_mask=retrieved_attn_mask,
                retrieved_emb=retrieved_emb,
                eod_positions=eod_positions,
                set_inference_key_value_memory=set_inference_key_value_memory,
                inference_max_sequence_len=inference_max_sequence_len,
            )
            token_logits = self.tokens_head(dec_output, self.word_embeddings_weight())

            if labels is not None:
                # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i
                if self.fp16_cross_entropy:
                    # fp16 path requires half-precision logits.
                    assert token_logits.dtype == torch.half
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(token_logits, labels)
                else:
                    tokens_loss = tensor_parallel.vocab_parallel_cross_entropy(token_logits.float(), labels)
                return tokens_loss
            else:
                return token_logits
    def forward(
        self,
        enc_input_ids,
        enc_attn_mask,
        dec_input_ids,
        dec_attn_mask,
        token_type_ids=None,
        labels=None,
        enc_hidden_states=None,
        enc_output_mask=None,
        output_enc_hidden_only=False,
        enc_input=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)

        Returns:
            - with ``output_enc_hidden_only``: the result of
              ``self.enc_dec_model.encode``;
            - with ``self.post_process`` and ``self.add_decoder``: per-token
              loss when ``labels`` is given, else token logits;
            - with ``self.add_decoder`` but not ``self.add_encoder``: the
              decoder output;
            - otherwise: the model output unchanged (the encoder output).
        """
        # Embed encoder tokens only if no encoder input was supplied and both
        # self.pre_process and self.add_encoder are set.
        if enc_input is None and self.pre_process and self.add_encoder:
            enc_input = self.encoder_embedding(
                enc_input_ids,
                build_position_ids(enc_input_ids),
                token_type_ids=token_type_ids)

        if output_enc_hidden_only:
            return self.enc_dec_model.encode(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
            )

        # dec_input stays None when the decoder itself is split across PP ranks.
        dec_input = None
        if self.pre_process and self.add_decoder:
            dec_input = self.decoder_embedding(
                dec_input_ids,
                build_position_ids(dec_input_ids),
                token_type_ids=token_type_ids)

        # NOTE(review): enc_hidden_states and enc_output_mask are accepted but
        # never forwarded (enc_output is always None) — confirm intended.
        output = self.enc_dec_model(
            enc_input=enc_input,
            enc_attn_mask=enc_attn_mask,
            dec_input=dec_input,
            dec_attn_mask=dec_attn_mask,
            enc_layer_past=None,
            enc_get_key_value=False,
            enc_output=None,
            dec_layer_past=None,
            dec_get_key_value=False,
        )

        if self.post_process and self.add_decoder:
            # Only the decoder half of the (decoder, encoder) pair is used here.
            dec_output, _ = output
            # project decoder output to vocabulary-size dimensions
            token_logits = self.tokens_head(dec_output,
                                            self.word_embeddings_weight())
            if labels is None:
                return token_logits
            # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i
            if self.fp16_cross_entropy:
                assert token_logits.dtype == torch.half
                return tensor_parallel.vocab_parallel_cross_entropy(
                    token_logits, labels)
            return tensor_parallel.vocab_parallel_cross_entropy(
                token_logits.float(), labels)

        if self.add_decoder and not self.add_encoder:
            decoder_output, _ = output
            return decoder_output

        return output
Example #10
0
    def forward(
        self,
        enc_input_ids,
        enc_attn_mask,
        dec_input_ids,
        dec_attn_mask,
        tokentype_ids=None,
        labels=None,
        enc_hidden_states=None,
        enc_output_mask=None,
        output_enc_hidden_only=False,
        enc_input=None,
    ):
        """
        Return value is per token / per dimension (i.e., non collapsed loss value)

        Returns:
            A new dict. With ``output_enc_hidden_only`` it holds
            ``enc_output`` and ``enc_output_mask`` from
            ``self.enc_dec_model.encode``. Otherwise it holds the entries
            returned by ``self.enc_dec_model`` plus ``token_logits`` and, when
            ``labels`` is given, ``tokens_loss``.
        """
        # encoder embeddings (skipped when a precomputed enc_input is given)
        if enc_input is None:
            enc_input = self.encoder_embedding(enc_input_ids,
                                               build_position_ids(enc_input_ids),
                                               tokentype_ids=tokentype_ids)

        if output_enc_hidden_only:
            encoded, encoded_mask = self.enc_dec_model.encode(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
            )
            return {"enc_output": encoded, "enc_output_mask": encoded_mask}

        dec_input = self.decoder_embedding(dec_input_ids,
                                           build_position_ids(dec_input_ids),
                                           tokentype_ids=tokentype_ids)

        result = dict(
            self.enc_dec_model(
                enc_input=enc_input,
                enc_attn_mask=enc_attn_mask,
                dec_input=dec_input,
                dec_attn_mask=dec_attn_mask,
                enc_layer_past=None,
                enc_get_key_value=False,
                enc_output=enc_hidden_states,
                enc_output_mask=enc_output_mask,
                dec_layer_past=None,
                dec_get_key_value=False,
            ))

        # project decoder output to vocabulary-size dimensions
        # token_logits [batch, length, vocab_size]
        token_logits = self.tokens_head(
            result["dec_output"],
            self.decoder_embedding.word_embeddings.weight)
        result["token_logits"] = token_logits

        if labels is not None:
            # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i
            logits_for_loss = token_logits
            if self.fp16_cross_entropy:
                assert token_logits.dtype == torch.half
            else:
                logits_for_loss = token_logits.float()
            # tokens_loss [batch, length]
            result["tokens_loss"] = tensor_parallel.vocab_parallel_cross_entropy(
                logits_for_loss, labels)

        return result