Example #1
import logging
import math

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration

# TrainConfig, _change_device, and _validate are project-specific helpers defined outside this snippet.


def train(
    config: TrainConfig,
    model: BartForConditionalGeneration,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    optimizer: Adam,
    logger: logging.Logger,
    device: torch.device,
):
    """ 지정된 Epoch만큼 모델을 학습시키는 함수입니다. """
    model.to(device)
    global_step = 0
    for epoch in range(1, config.num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for data in train_dataloader:
            global_step += 1
            data = _change_device(data, device)
            optimizer.zero_grad()
            output = model(
                input_ids=data[0],
                attention_mask=data[1],
                decoder_input_ids=data[2],
                labels=data[3],
                decoder_attention_mask=data[4],
                return_dict=True,
            )
            loss = output["loss"]
            loss.backward()
            loss_sum += loss.item()

            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if global_step % config.train_log_interval == 0:
                mean_loss = loss_sum / config.train_log_interval
                logger.info(
                    f"Epoch {epoch} Step {global_step} " f"Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}"
                )
                loss_sum = 0.0
            if global_step % config.dev_log_interval == 0:
                _validate(model, dev_dataloader, logger, device)
            if global_step % config.save_interval == 0:
                model.save_pretrained(f"{config.save_model_file_prefix}_{global_step}")
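A minimal usage sketch for train() above. The TrainConfig constructor arguments, the learning rate, and the pre-built train_dataloader / dev_dataloader are assumptions for illustration, not part of the original example:

# Hypothetical setup; the TrainConfig fields and the dataloaders below are assumed, not taken from the original.
config = TrainConfig(
    num_epochs=3,
    train_log_interval=100,
    dev_log_interval=500,
    save_interval=1000,
    save_model_file_prefix="checkpoints/bart",
)
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
optimizer = Adam(model.parameters(), lr=5e-5)
logger = logging.getLogger("train")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# train_dataloader / dev_dataloader: DataLoaders yielding 5-tuples of
# (input_ids, attention_mask, decoder_input_ids, labels, decoder_attention_mask).
train(config, model, train_dataloader, dev_dataloader, optimizer, logger, device=device)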
Example #2
            n_val += len(inputs['input_ids'])

    print("n_val", n_val)
    avg_val_loss = tot_val_loss.item() / n_val
    return n_val, avg_val_loss
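
For context, the fragment above is only the tail of the evaluation helper called later as eval_model(args, model, dev_dataset, tokenizer, eval_batchsize). A hedged sketch of the kind of loop that would produce tot_val_loss and n_val, simplified to take an already-batched dataloader of tokenized dicts (a reconstruction under assumptions, not the original code):

# Hypothetical reconstruction: sum the per-batch loss weighted by batch size with
# gradients disabled, then report the per-example average. The real eval_model may differ.
def eval_model_sketch(model, dev_dataloader, device):
    model.eval()
    tot_val_loss = torch.zeros(1, device=device)
    n_val = 0
    with torch.no_grad():
        for inputs in dev_dataloader:
            # Assumed batch format: dict with input_ids / attention_mask / labels tensors.
            inputs = {k: v.to(device) for k, v in inputs.items()}
            output = model(**inputs)
            tot_val_loss += output.loss * len(inputs['input_ids'])
            n_val += len(inputs['input_ids'])
    print("n_val", n_val)
    avg_val_loss = tot_val_loss.item() / n_val
    return n_val, avg_val_loss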


tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
if pretrained:
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base', dropout=args.dropout)
else:
    config = BartConfig.from_pretrained('facebook/bart-base')
    config.dropout = args.dropout
    model = BartForConditionalGeneration(config)
model.to(DEVICE)
optimizer = AdamW(model.parameters(), lr=args.lr)
print("Loaded model")

# Load training and dev data
dataset = load_data(args.data, ["walkthrough0"] + [f"randcmd{i}" for i in range(100)], tokenizer, max_seq_len, max_data_size=4000)
print("Loaded train data")
dev_dataset = load_data(args.data, [f"randcmd{i}" for i in range(100,200)], tokenizer, max_seq_len, max_data_size=500)
print("Loaded dev data")

# initial eval
print("Initial eval")
n_val, avg_val_loss = eval_model(args, model, dev_dataset, tokenizer, eval_batchsize)
print(f"INIT, avg val loss: {avg_val_loss}")
best_val_loss = avg_val_loss

if args.eval_only: