Example #1
    def training_step(self, batch, batch_idx):
        tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = self.process_batch(
            batch)
        if not self.cfg.bert_binary_head:
            types = None
        output_tensor = self(tokens,
                             padding_mask,
                             token_type_ids=types,
                             lm_labels=lm_labels)
        loss_dict = self.loss_func(loss_mask, sentence_order, output_tensor)
        if 'sop loss' in loss_dict:
            lm_loss = loss_dict['lm loss']
            sop_loss = loss_dict['sop loss']
            loss = lm_loss + sop_loss
            reduced_loss = average_losses_across_data_parallel_group(
                [loss, lm_loss, sop_loss])
            self._reduced_loss_buffer.append(reduced_loss[0])
            self._reduced_lm_loss_buffer.append(reduced_loss[1])
            self._reduced_sop_loss_buffer.append(reduced_loss[2])
        else:
            lm_loss = loss_dict['lm loss']
            loss = lm_loss
            reduced_loss = average_losses_across_data_parallel_group(
                [loss, lm_loss])
            self._reduced_loss_buffer.append(reduced_loss[0])
            self._reduced_lm_loss_buffer.append(reduced_loss[1])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(
                self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            if len(self._reduced_sop_loss_buffer) > 0:
                average_reduced_lm_loss = sum(
                    self._reduced_lm_loss_buffer) / len(
                        self._reduced_lm_loss_buffer)
                average_reduced_sop_loss = sum(
                    self._reduced_sop_loss_buffer) / len(
                        self._reduced_sop_loss_buffer)
                self.log('reduced_lm_train_loss',
                         average_reduced_lm_loss,
                         prog_bar=True)
                self.log('reduced_sop_train_loss',
                         average_reduced_sop_loss,
                         prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self.log(
                'consumed_samples',
                self.compute_consumed_samples(self.trainer.global_step -
                                              self.init_global_step),
                prog_bar=True,
            )
            self._reduced_loss_buffer = []
            self._reduced_lm_loss_buffer = []
            self._reduced_sop_loss_buffer = []
        return loss
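Every example on this page routes its losses through average_losses_across_data_parallel_group. As a minimal sketch of what a Megatron-style implementation of this helper does (the parallel_state import path is an assumption, not necessarily NeMo's exact module layout):

import torch
from apex.transformer import parallel_state  # assumed Megatron-style parallel state

def average_losses_across_data_parallel_group(losses):
    # Stack the scalar losses into one tensor, all-reduce it across the
    # data-parallel group, and divide by the group size so every rank
    # ends up holding the group's mean loss.
    averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged, group=parallel_state.get_data_parallel_group())
    averaged /= torch.distributed.get_world_size(group=parallel_state.get_data_parallel_group())
    return averaged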
Example #2
    def test_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)
        logging.info(f'test_loss: {averaged_loss[0]}')
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
        )
        return averaged_loss
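Several examples also log consumed_samples. As a hedged sketch of what compute_consumed_samples typically returns in these NeMo-style models (the attribute names below are illustrative assumptions, not the exact NeMo API): one optimizer step consumes one global batch.

def compute_consumed_samples(self, steps_since_resume):
    # Hypothetical sketch: a global batch is micro_batch_size *
    # gradient-accumulation steps * data-parallel world size.
    global_batch_size = self.micro_batch_size * self.grad_acc_steps * self.data_parallel_size
    return int(self.init_consumed_samples + steps_since_resume * global_batch_size)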
Example #3
    def inference_epoch_end(self, outputs):
        losses = [x['loss'] for x in outputs]
        averaged_loss = average_losses_across_data_parallel_group(losses)
        all_preds = []
        all_labels = []
        for item in outputs:
            preds = item['predicted_token_ids'].cpu().numpy().tolist()
            labels = item['labels'].cpu().numpy().tolist()
            for pred, label in zip(preds, labels):
                if self.tokenizer.eos_id in pred:
                    idx = pred.index(self.tokenizer.eos_id)
                    pred = pred[:idx]
                pred = [
                    id for id in pred
                    if id not in self.tokenizer.special_token_to_id.values()
                ]
                label = [
                    id for id in label
                    if id not in self.tokenizer.special_token_to_id.values()
                ]
                pred = self.tokenizer.ids_to_text(pred)
                label = self.tokenizer.ids_to_text(label)
                all_preds.append(pred)
                all_labels.append(label)

        correct = 0
        for pred, label in zip(all_preds, all_labels):
            if pred == label:
                correct += 1
        acc = correct / len(all_preds)
        return averaged_loss[0], acc
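The truncate-at-EOS and strip-special-tokens steps above are self-contained enough to factor out. A small equivalent helper (the names are illustrative, not NeMo API):

def postprocess_ids(token_ids, eos_id, special_ids):
    # Truncate at the first EOS, then drop remaining special tokens,
    # mirroring the per-example loop in inference_epoch_end above.
    if eos_id in token_ids:
        token_ids = token_ids[:token_ids.index(eos_id)]
    return [t for t in token_ids if t not in special_ids]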
Example #4
    def training_step(self, batch, batch_idx):
        input_tokens_id = batch['tokens']
        input_attn_mask = batch['tokens_mask']
        loss_mask = batch['loss_mask']
        retrieved_ids = batch['retrieved_ids']
        retrieved_attn_mask = batch['retrieved_emb_mask']
        labels = batch['labels']

        loss = self(input_tokens_id, input_attn_mask, retrieved_ids, retrieved_attn_mask, labels=labels)
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        reduced_loss = average_losses_across_data_parallel_group([lm_loss])
        self._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self.log(
                'consumed_samples',
                self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
                prog_bar=True,
            )
            self._reduced_loss_buffer = []
        return lm_loss
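The lm_loss line above is a masked mean: per-token losses are zeroed wherever loss_mask is 0, and the sum is normalized by the number of unmasked tokens. A toy check of that identity:

import torch

loss = torch.tensor([1.0, 2.0, 3.0, 4.0])  # per-token losses
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # 1 = token counts toward the loss
masked_mean = torch.sum(loss * mask) / mask.sum()
assert torch.isclose(masked_mean, torch.tensor((1.0 + 2.0 + 4.0) / 3))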
Example #5
    def training_step(self, batch, batch_idx):
        tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(
            batch)

        tokens_loss = itemgetter("tokens_loss")(self.model(
            tokens_enc,
            tokens_dec,
            enc_mask,
            dec_mask,
            tokentype_ids=None,
            lm_labels=labels,
        ))

        loss = self.model.loss_func(loss_mask, tokens_loss)
        self.log('train_loss', loss)
        # Reduced loss for logging.
        reduced_loss = average_losses_across_data_parallel_group([loss])
        # cache reduced loss while accumulating gradients
        self.model._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self.model._reduced_loss_buffer) / len(
                self.model._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self.model._reduced_loss_buffer = []

        return loss
Example #6
    def training_step(self, batch, batch_idx):
        input_ids, labels, loss_mask, position_ids, attention_mask, taskname_ids = batch
        output = self.forward(input_ids,
                              position_ids,
                              attention_mask,
                              taskname_ids,
                              labels,
                              inference=False)
        output_tensor, encoder_hidden_states = output
        loss = self.frozen_model.loss_func(loss_mask, output_tensor)
        self.log('train_loss', loss)

        # Reduced loss for logging.
        reduced_loss = average_losses_across_data_parallel_group([loss])

        # Cache reduced loss while accumulating gradients
        self._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(
                self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self._reduced_loss_buffer = []

        return loss
Example #7
    def training_step(self, batch, batch_idx):
        tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(batch)

        tokens_loss = itemgetter("tokens_loss")(
            self(tokens_enc, tokens_dec, enc_mask, dec_mask, tokentype_ids=None, lm_labels=labels,)
        )

        loss = self.loss_func(loss_mask, tokens_loss)
        self.log('train_loss', loss)
        # Reduced loss for logging. This averages the loss across all workers unlike "loss" above which is specific to a DDP rank.
        reduced_loss = average_losses_across_data_parallel_group([loss])
        # cache reduced loss while accumulating gradients
        self._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self.log(
                'consumed_samples',
                self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
                prog_bar=True,
            )
            self._reduced_loss_buffer = []

        return loss
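Nearly every training_step on this page repeats the same buffer-then-flush pattern: cache the reduced loss for each micro-batch, then log the running mean only once per optimizer step, i.e. every accumulate_grad_batches micro-batches. Stripped of model specifics, the pattern reduces to the following sketch (not NeMo code):

class LossLoggingBuffer:
    def __init__(self, accumulate_grad_batches, log_fn):
        self.accumulate_grad_batches = accumulate_grad_batches
        self.log_fn = log_fn
        self._buffer = []

    def record(self, reduced_loss, batch_idx):
        # Cache while gradients accumulate; flush the mean once per
        # optimizer step, then reset for the next accumulation window.
        self._buffer.append(reduced_loss)
        if (batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.log_fn('reduced_train_loss', sum(self._buffer) / len(self._buffer))
            self._buffer = []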
Example #8
    def inference_epoch_end(self, outputs):
        losses = [x['loss'] for x in outputs]
        averaged_loss = average_losses_across_data_parallel_group(losses)
        val_acc = self.acc_metric.compute()
        self.log('validation_loss', averaged_loss)
        self.log('validation_acc', val_acc['acc'])
        self.acc_metric.reset()
        return averaged_loss[0], val_acc
Example #9
    def validation_step(self, batch, batch_idx):
        tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(batch)

        tokens_loss = itemgetter("tokens_loss")(
            self(tokens_enc, tokens_dec, enc_mask, dec_mask, tokentype_ids=None, lm_labels=labels,)
        )
        loss = self.loss_func(loss_mask, tokens_loss)
        reduced_loss = average_losses_across_data_parallel_group([loss])
        return reduced_loss
Example #10
    def inference_epoch_end(self, outputs):
        losses = [x['loss'] for x in outputs]
        averaged_loss = average_losses_across_data_parallel_group(losses)
        acc_result = self.acc_metrics.compute()
        self.log('validation_loss', averaged_loss)
        self.log('validation_acc', acc_result['acc'])
        for lang in self.cfg.eval_languages:
            self.log(f'{lang}_acc', acc_result[lang])
        self.acc_metrics.reset()
        return averaged_loss[0], acc_result
Example #11
    def validation_step(self, batch, batch_idx):
        input_tokens_id = batch['tokens']
        input_attn_mask = batch['tokens_mask']
        loss_mask = batch['loss_mask']
        retrieved_ids = batch['retrieved_ids']
        retrieved_attn_mask = batch['retrieved_emb_mask']
        labels = batch['labels']
        loss = self(input_tokens_id, input_attn_mask, retrieved_ids, retrieved_attn_mask, labels=labels)
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        reduced_loss = average_losses_across_data_parallel_group([lm_loss])
        return reduced_loss
Example #12
    def validation_step(self, batch, batch_idx):
        tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask = self.process_batch(
            batch)

        output_tensor, encoder_hidden_states = self(tokens_enc,
                                                    tokens_dec,
                                                    enc_mask,
                                                    dec_mask,
                                                    enc_dec_mask,
                                                    tokentype_ids=None,
                                                    lm_labels=labels)
        loss = self.loss_func(loss_mask, output_tensor)
        reduced_loss = average_losses_across_data_parallel_group([loss])
        return reduced_loss
Example #13
    def training_step(self, batch, batch_idx):
        loss, _, _, _, _ = self.get_loss(batch)
        self.log('train_loss', loss)
        # Reduced loss for logging.
        reduced_loss = average_losses_across_data_parallel_group([loss])
        # cache reduced loss while accumulating gradients
        self._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self._reduced_loss_buffer = []

        return loss
Example #14
    def validation_step(self, batch, batch_idx):
        tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = self.process_batch(batch)
        if not self.cfg.bert_binary_head:
            types = None
        output_tensor = self(tokens,
                             padding_mask,
                             tokentype_ids=types,
                             lm_labels=lm_labels)
        loss_dict = self.loss_func(loss_mask, sentence_order, output_tensor)
        if 'sop loss' in loss_dict:
            lm_loss = loss_dict['lm loss']
            sop_loss = loss_dict['sop loss']
            loss = lm_loss + sop_loss
        else:
            lm_loss = loss_dict['lm loss']
            loss = lm_loss
        reduced_loss = average_losses_across_data_parallel_group([loss])
        return reduced_loss
Example #15
    def validation_step(self, batch, batch_idx):
        tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(
            batch)
        output_tensor = self(tokens, position_ids, attention_mask, labels)
        loss = self.loss_func(loss_mask, output_tensor)
        reduced_loss = average_losses_across_data_parallel_group([loss])

        # TODO: add text generation
        # take the first k tokens and then generate text - compare with ground truth (won't look similar in general)
        """
        k = num_context
        n = max_generate_length
        context_tokens = tokens[0:k]
        while k < n:
            output_tensor = self(context_tokens)
            next_token = sample(output_tensor)
            context_tokens.append(next_token)
            k += 1
        """
        return reduced_loss
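The commented-out pseudocode above sketches context-conditioned generation. A runnable version of the same idea, assuming a causal LM whose forward returns logits of shape [batch, seq, vocab] and replacing the sample() placeholder with a greedy argmax:

import torch

@torch.no_grad()
def greedy_generate(model, context_tokens, max_length):
    # Extend the context one token at a time, always taking the argmax
    # of the final position's logits (greedy decoding, no sampling).
    tokens = context_tokens
    while tokens.size(1) < max_length:
        logits = model(tokens)  # [batch, seq, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens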
Example #16
    def validation_step(self, batch, batch_idx):
        if self.use_soft_prompts:
            tokens, labels, prompt_tags, attention_mask, loss_mask, text_position_ids = batch

            tokens = tokens.to(self.device)
            labels = labels.to(self.device)
            attention_mask = attention_mask.to(self.device)
            loss_mask = loss_mask.to(self.device)
            text_position_ids = text_position_ids.to(self.device)

            output_tensor = self(tokens, text_position_ids, attention_mask,
                                 labels, prompt_tags)
        else:
            tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(
                batch)
            output_tensor = self(tokens, position_ids, attention_mask, labels)

        loss = self.loss_func(loss_mask, output_tensor)
        reduced_loss = average_losses_across_data_parallel_group([loss])

        return reduced_loss
Example #17
    def training_step(self, batch, batch_idx):
        tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(
            batch)
        output_tensor = self(tokens, position_ids, attention_mask, labels)
        loss = self.loss_func(loss_mask, output_tensor)
        reduced_loss = average_losses_across_data_parallel_group([loss])

        # cache reduced loss while accumulating gradients
        self._reduced_loss_buffer.append(reduced_loss[0])

        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
            # Reduced loss for logging.
            average_reduced_loss = sum(self._reduced_loss_buffer) / len(
                self._reduced_loss_buffer)
            self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
            lr = self._optimizer.param_groups[0]['lr']
            self.log('lr', lr)
            self.log('global_step', self.trainer.global_step, prog_bar=True)
            self.log('consumed_samples',
                     self.compute_consumed_samples(self.trainer.global_step),
                     prog_bar=True)
            self._reduced_loss_buffer = []
        return loss
Example #18
    def test_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)
        logging.info(f'test_loss: {averaged_loss[0]}')
Example #19
    def loss_func(output_tensor):
        # Nested closure: `self` and `loss_mask` are captured from the
        # enclosing forward-step function that builds this loss callback.
        loss = self.loss_func(loss_mask, output_tensor)
        reduced_loss = average_losses_across_data_parallel_group([loss])
        return loss, {'avg': reduced_loss}
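A hedged sketch of the factory such a closure typically lives in: Megatron-style pipeline schedules call a forward-step function with a batch and the model, and expect back (output_tensor, loss_func). The batch layout below is illustrative, not the exact NeMo signature:

def get_forward_output_and_loss_func(self):
    def fwd_output_and_loss_func(batch, model):
        tokens, labels, loss_mask, attention_mask, position_ids = batch
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

        def loss_func(output_tensor):
            loss = self.loss_func(loss_mask, output_tensor)
            reduced_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'avg': reduced_loss}

        return output_tensor, loss_func

    return fwd_output_and_loss_func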
Example #20
    def eval_epoch_end(self, outputs, mode):
        if isinstance(outputs[0], dict):
            outputs = [outputs]

        loss_list = []
        bleu_score_list = []
        for dataloader_idx, output in enumerate(outputs):
            averaged_loss = average_losses_across_data_parallel_group(
                [x['loss'] for x in output])
            inputs = list(itertools.chain(*[x['inputs'] for x in output]))
            translations = list(
                itertools.chain(*[x['translations'] for x in output]))
            ground_truths = list(
                itertools.chain(*[x['ground_truths'] for x in output]))
            assert len(translations) == len(inputs)
            assert len(translations) == len(ground_truths)

            # Gather translations and ground truths from all workers
            tr_gt_inp = [
                None
                for _ in range(parallel_state.get_data_parallel_world_size())
            ]
            # we also need to drop pairs where ground truth is an empty string
            torch.distributed.all_gather_object(
                tr_gt_inp,
                [(t, g, i)
                 for (t, g, i) in zip(translations, ground_truths, inputs)],
                group=parallel_state.get_data_parallel_group(),
            )
            if parallel_state.get_data_parallel_rank() == 0:
                _translations = []
                _ground_truths = []
                _inputs = []

                # Deduplicate sentences that may have been distributed across multiple data parallel ranks.
                gt_inp_set = set()
                for rank in range(
                        0, parallel_state.get_data_parallel_world_size()):
                    for t, g, i in tr_gt_inp[rank]:
                        if g + i not in gt_inp_set:
                            gt_inp_set.add(g + i)
                            _translations.append(t)
                            _ground_truths.append(g)
                            _inputs.append(i)

                if self.tgt_language in ['ja']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="ja-mecab")
                elif self.tgt_language in ['zh']:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="zh")
                else:
                    sacre_bleu = corpus_bleu(_translations, [_ground_truths],
                                             tokenize="13a")

                bleu_score = sacre_bleu.score * parallel_state.get_data_parallel_world_size()

                dataset_name = "Validation" if mode == 'val' else "Test"
                logging.info(
                    f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}"
                )
                logging.info(
                    f"{dataset_name}, Dataloader index: {dataloader_idx}, SacreBLEU = {bleu_score / parallel_state.get_data_parallel_world_size()}"
                )
                logging.info(
                    f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:"
                )
                logging.info(
                    '============================================================'
                )
                for example_idx in range(0, 3):
                    random_index = random.randint(0, len(_translations) - 1)
                    logging.info("    " +
                                 '\u0332'.join(f"Example {example_idx}:"))
                    logging.info(f"    Input:        {_inputs[random_index]}")
                    logging.info(
                        f"    Prediction:   {_translations[random_index]}")
                    logging.info(
                        f"    Ground Truth: {_ground_truths[random_index]}")
                    logging.info(
                        '============================================================'
                    )

            else:
                bleu_score = 0.0

            loss_list.append(averaged_loss[0].cpu().numpy())
            bleu_score_list.append(bleu_score)
            if dataloader_idx == 0:
                self.log(f'{mode}_sacreBLEU', bleu_score, sync_dist=True)
                self.log(f'{mode}_loss', averaged_loss[0], prog_bar=True)
                if self.multilingual:
                    self._log_multilingual_bleu_and_loss(
                        dataloader_idx, bleu_score, averaged_loss[0], mode)
            else:
                if self.multilingual:
                    self._log_multilingual_bleu_and_loss(
                        dataloader_idx, bleu_score, averaged_loss[0], mode)
                else:
                    self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}',
                             bleu_score,
                             sync_dist=True)
                    self.log(f'{mode}_loss_dl_index_{dataloader_idx}',
                             averaged_loss[0],
                             prog_bar=False)

        if len(loss_list) > 1:
            self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
            self.log(f"{mode}_sacreBLEU_avg",
                     np.mean(bleu_score_list),
                     sync_dist=True)
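Note the scaling trick in the rank-0 branch above: rank 0 multiplies the BLEU score by the data-parallel world size while every other rank reports 0.0, so that self.log(..., sync_dist=True), which averages the value across ranks, recovers the true score. A toy check (assuming sync_dist takes the mean):

import torch

world_size = 4
true_score = 27.3
# rank 0 logs score * world_size, all other ranks log 0.0
per_rank = torch.tensor([true_score * world_size] + [0.0] * (world_size - 1))
assert torch.isclose(per_rank.mean(), torch.tensor(true_score))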
Example #21
    def validation_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)
        self.log('val_loss', averaged_loss[0], prog_bar=True)
        self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step - self.init_global_step))
Example #22
    def validation_epoch_end(self, outputs):
        averaged_loss = average_losses_across_data_parallel_group(outputs)

        # rank_zero_only logging happens on rank 0, but the loss lives on the
        # last pipeline rank, so broadcast it to all ranks first
        torch.distributed.broadcast(averaged_loss, get_last_rank())
        self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True)