class LightningRoberta(LightningModule):
    """RoBERTa masked-language model wrapped as a PyTorch Lightning module.

    Logs train/val loss and masked-token accuracy per step and per epoch,
    keeps per-epoch averaged metrics in ``train_results`` / ``val_results``,
    and saves a HuggingFace-format checkpoint after every training epoch.

    NOTE(review): ``learning_rate``, ``mask_token_id``, ``ckpt_folder`` and
    ``start_epoch`` are injected via setters, and ``init_metrics()`` must be
    called before training starts -- confirm the driver script does all of
    this, otherwise the epoch-end hooks raise ``AttributeError``.
    """

    def __init__(self, config, pretrained_model_path=None):
        """Build a fresh model from ``config``, or load pretrained weights.

        Args:
            config: HuggingFace ``RobertaConfig`` for the model.
            pretrained_model_path: Optional checkpoint directory; when falsy
                (``None`` or ``''``) the model is initialized from scratch.
        """
        super().__init__()
        self.config = config
        # Truthiness check replaces the original `== None or == ''` pair.
        if not pretrained_model_path:
            self.model = RobertaForMaskedLM(self.config)
        else:
            self.model = RobertaForMaskedLM.from_pretrained(
                pretrained_model_path, config=self.config)

    def set_learning_rate(self, lr):
        """Set the AdamW learning rate used by ``configure_optimizers``."""
        self.learning_rate = lr

    def set_mask_token_id(self, mask_id):
        """Set the tokenizer's mask-token id used for accuracy computation."""
        self.mask_token_id = mask_id

    def forward(self, input_ids, attention_mask, labels=None):
        """Delegate to the wrapped ``RobertaForMaskedLM``."""
        return self.model(input_ids,
                          attention_mask=attention_mask,
                          labels=labels)

    def _step_metrics(self, batch):
        """Run one forward pass; return (model outputs, masked-token acc)."""
        outputs = self(batch['input_ids'],
                       attention_mask=batch['attention_mask'],
                       labels=batch['labels'])
        masked_token_acc = masked_token_accuracy(outputs.logits,
                                                 batch['input_ids'],
                                                 batch['labels'],
                                                 self.mask_token_id)
        return outputs, masked_token_acc

    def training_step(self, batch, batch_idx):
        """One optimization step; logs loss and masked-token accuracy."""
        outputs, masked_token_acc = self._step_metrics(batch)
        self.log("train_loss", outputs.loss,
                 on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_mask_acc", masked_token_acc,
                 on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': outputs.loss, 'masked_token_acc': masked_token_acc}

    def validation_step(self, batch, batch_idx):
        """One validation step; logs loss and masked-token accuracy."""
        outputs, masked_token_acc = self._step_metrics(batch)
        self.log_dict(
            {
                'val_loss': outputs.loss,
                'val_mask_acc': masked_token_acc
            },
            on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': outputs.loss, 'masked_token_acc': masked_token_acc}

    def configure_optimizers(self):
        """AdamW over the wrapped model's parameters (lr set via setter)."""
        return torch.optim.AdamW(self.model.parameters(),
                                 lr=self.learning_rate)

    def init_metrics(self):
        """Reset the per-epoch results dictionaries."""
        self.train_results = {}
        self.val_results = {}

    def set_ckpt_folder(self, ckpt_folder):
        """Set the directory under which epoch checkpoints are written."""
        self.ckpt_folder = ckpt_folder

    def set_start_epoch(self, epoch_num):
        """Offset epoch numbering (e.g. when resuming a previous run)."""
        self.start_epoch = epoch_num

    @staticmethod
    def _mean_epoch_metrics(step_outputs):
        """Average per-step loss / masked-token accuracy over one epoch.

        Equivalent to the original running sum divided by the number of
        steps; ``torch.stack(...).mean()`` preserves dtype and device.
        """
        loss = torch.stack([o['loss'] for o in step_outputs]).mean()
        acc = torch.stack(
            [o['masked_token_acc'] for o in step_outputs]).mean()
        return {'loss': loss, 'mask_acc': acc}

    def training_epoch_end(self, training_step_outputs):
        """Store averaged training metrics under the completed-epoch key."""
        key = f'epoch_{self.start_epoch + self.current_epoch + 1}'
        self.train_results[key] = self._mean_epoch_metrics(
            training_step_outputs)

    def validation_epoch_end(self, validation_step_outputs):
        """Store averaged validation metrics under the completed-epoch key."""
        key = f'epoch_{self.start_epoch + self.current_epoch + 1}'
        self.val_results[key] = self._mean_epoch_metrics(
            validation_step_outputs)

    def on_train_epoch_end(self):
        """Save a checkpoint for the finished epoch and print its metrics."""
        completed_epoch = self.start_epoch + self.current_epoch + 1
        train = self.train_results[f'epoch_{completed_epoch}']
        val = self.val_results[f'epoch_{completed_epoch}']

        # Same output text as the original concatenation-based builder.
        print_str = (
            ">>>>"
            f" train_loss: {round(train['loss'].item(), 4)}"
            f" | train_mask_acc: {round(train['mask_acc'].item(), 4)}"
            f" | val_loss: {round(val['loss'].item(), 4)}"
            f" | val_mask_acc: {round(val['mask_acc'].item(), 4)}"
        )

        # HuggingFace-format checkpoint: <ckpt_folder>/epoch-<N>
        ckpt_path = os.path.join(self.ckpt_folder, f"epoch-{completed_epoch}")
        self.model.save_pretrained(ckpt_path)

        print(print_str)
# ---------------------------------------------------------------------------
# Example 2: fragment of a separate plain-PyTorch training script (its
# training loop begins above this excerpt and is not shown here).
# ---------------------------------------------------------------------------
        val_loop.set_postfix({
            'val_loss': val_loss.item(),
            'val_mask_acc': mask_acc.item()
        })

    avg_val_loss = total_val_loss / len(val_dataloader)
    avg_val_mask_acc = total_val_mask_acc / len(val_dataloader)
    print(
        f'\nloss: {avg_train_loss} | mask_acc: {avg_train_mask_acc} | val_loss: {avg_val_loss} | val_mask_acc: {avg_val_mask_acc}'
    )

    results[f'epoch_{epoch+1}'] = {
        'train_loss': avg_train_loss,
        'train_mask_acc': avg_train_mask_acc,
        'val_loss': avg_val_loss,
        'val_mask_acc': avg_val_mask_acc
    }
    ckpt_path = os.path.join(model_folder, f"epoch-{epoch+1}")
    model.save_pretrained(ckpt_path)

# Persist the run configuration and the per-epoch metrics as JSON files.
for out_path, payload in ((config_path, Config), (results_path, results)):
    with open(out_path, 'w') as json_f:
        json.dump(payload, json_f)

# Report total wall-clock training time.
end_time = datetime.now()
print(
    f"\n============= Total Training Time: {end_time - train_start_datetime} ============"
)