Example #1
import logging
import math

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration

# TrainConfig and _change_device are project-local helpers; a sketch of both follows this example.


def train(
    config: TrainConfig,
    model: BartForConditionalGeneration,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    optimizer: Adam,
    logger: logging.Logger,
    device: torch.device,
):
    """ 지정된 Epoch만큼 모델을 학습시키는 함수입니다. """
    model.to(device)
    global_step = 0
    for epoch in range(1, config.num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for data in train_dataloader:
            global_step += 1
            # move the batch tensors onto the target device
            data = _change_device(data, device)
            optimizer.zero_grad()
            output = model(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss.backward()
            loss_sum += loss.item()

            # clip gradients before the optimizer step to stabilise training
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if global_step % config.train_log_interval == 0:
                mean_loss = loss_sum / config.train_log_interval
                logger.info(
                    f"Epoch {epoch} Step {global_step} " f"Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}"
                )
                loss_sum = 0.0
            if global_step % config.dev_log_interval == 0:
                _validate(model, dev_dataloader, logger, device)
            if global_step % config.save_interval == 0:
                model.save_pretrained(f"{config.save_model_file_prefix}_{global_step}")
Example #2
def _validate(
    model: BartForConditionalGeneration,
    dev_dataloader: DataLoader,
    logger: logging.Logger,
    device: torch.device,
):
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for data in tqdm(dev_dataloader):
            data = _change_device(data, device)
            output = model(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss_sum += loss.item()
    mean_loss = loss_sum / len(dev_dataloader)
    logger.info(f"[Validation] Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}")
    model.train()  # restore training mode after evaluation
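
For context, a driver that wires the two functions together might look like the sketch below. The dataset construction is omitted, and everything except the imported classes is an assumption (model name, hyperparameters, and the pre-built `train_dataset`/`dev_dataset`).

import logging

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration

# hypothetical setup: train_dataset / dev_dataset are assumed to yield tuples of
# (input_ids, attention_mask, decoder_input_ids, labels, decoder_attention_mask)
config = TrainConfig()
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
optimizer = Adam(model.parameters(), lr=3e-5)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=16)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train(config, model, train_dataloader, dev_dataloader, optimizer, logger, device)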
Example #3
dev_dataset = load_data(args.data, [f"randcmd{i}" for i in range(100, 200)], tokenizer, max_seq_len, max_data_size=500)
print("Loaded dev data")

# initial eval
print("Initial eval")
n_val, avg_val_loss = eval_model(args, model, dev_dataset, tokenizer, eval_batchsize)
print(f"INIT, avg val loss: {avg_val_loss}")
best_val_loss = avg_val_loss

if args.eval_only:
    exit()

# training loop
print("Start training")
for i in range(args.epochs):
    model.train()
    lang_train_losses = []
    for j, (inputs, lang_tgts, init_state, tgt_state) in enumerate(
        convert_to_transformer_batches(args, dataset, tokenizer, batchsize)
    ):
        optimizer.zero_grad()
        return_dict = model(
            input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'],
            # let the model derive decoder_input_ids by shifting `labels` right internally;
            # feeding the unshifted labels as decoder_input_ids would leak each target token
            labels=lang_tgts['input_ids'], return_dict=True,
        )
        lang_loss, dec_output, encoder_hidden = return_dict.loss, return_dict.logits, return_dict.encoder_last_hidden_state
        # encoder_outputs = (encoder_hidden,)
        lang_train_losses.append(lang_loss.item())
        lang_loss.backward()
        optimizer.step()
        if j % 100 == 0:
            print(f"epoch {i}, batch {j}, loss: {lang_loss.item()}", flush=True)
            # break
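    # NOTE (assumption): the snippet is truncated here. The `best_val_loss`
    # initialised after the initial eval suggests an end-of-epoch evaluation and
    # best-checkpoint save roughly along these lines; the eval_model call mirrors
    # the one above, and args.save_dir is a hypothetical save location.
    avg_train_loss = sum(lang_train_losses) / len(lang_train_losses)
    print(f"epoch {i}, avg train loss: {avg_train_loss}", flush=True)
    n_val, avg_val_loss = eval_model(args, model, dev_dataset, tokenizer, eval_batchsize)
    print(f"epoch {i}, avg val loss: {avg_val_loss}", flush=True)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        model.save_pretrained(args.save_dir)  # assumed save location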