def _validate_and_checkpoint(args, valid_loader, model, device, best_valid_acc):
    """Validate ``model`` and save it when it beats ``best_valid_acc``.

    The whole model object is written with ``torch.save`` to
    ``<args.model_dir>/fine-tuned_bert_<args.seed>.pkl``. It is saved right after
    ``testing`` returns, i.e. in whatever mode ``testing`` left it (eval mode).

    Returns:
        The updated best validation accuracy.
    """
    valid_acc = testing(valid_loader, model, device, valid=True)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(
            model,
            os.path.join(args.model_dir,
                         'fine-tuned_bert_{}.pkl'.format(args.seed)))
    return best_valid_acc


def training(args, train_loader, valid_loader, model, optimizer, device):
    """Train ``model`` and checkpoint it whenever validation accuracy improves.

    Args:
        args: Configuration object. Reads ``epochs``, ``eval_steps`` (validate
            every N optimizer steps; a falsy or non-positive value disables
            periodic validation), ``model_dir`` and ``seed``.
        train_loader: Iterable of batches. Each batch is a mapping with a
            ``'label'`` entry.
        valid_loader: Validation batches, passed to ``testing``.
        model: The model being trained. Switched to train mode before every step.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the labels are moved to.

    Side effects:
        Saves the best model (by validation accuracy) with ``torch.save`` and
        prints the best validation accuracy at the end.
    """
    train_metrics = Accuracy()
    best_valid_acc = 0
    total_iter = 0
    # Step count at which the model was last validated. Used to avoid
    # re-validating unchanged weights after the final epoch.
    last_eval_iter = None
    eval_steps = args.eval_steps
    # Guard the modulo below against a zero/None interval.
    periodic_eval = bool(eval_steps) and eval_steps > 0
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(args.epochs):
        train_trange = tqdm(enumerate(train_loader),
                            total=len(train_loader),
                            desc='training')
        train_loss = 0
        train_metrics.reset()
        for i, batch in train_trange:
            # testing() puts the model in eval mode, so restore train mode on
            # every step (a validation may have happened on the previous one).
            model.train()
            # NOTE(review): CrossEntropyLoss expects unnormalised logits;
            # confirm run_iter returns logits despite the ``prob`` name.
            prob = run_iter(batch, model, device, training=True)
            answer = batch['label'].to(device)
            loss = criterion(prob, answer)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_iter += 1
            train_loss += loss.item()
            train_metrics.update(prob, answer)
            train_trange.set_postfix(
                loss=train_loss / (i + 1),
                **{train_metrics.name: train_metrics.print_score()})

            if periodic_eval and total_iter % eval_steps == 0:
                best_valid_acc = _validate_and_checkpoint(
                    args, valid_loader, model, device, best_valid_acc)
                last_eval_iter = total_iter

    # Final validation. Skip it when the very last step was already validated:
    # the weights are unchanged, so the score could not exceed the best.
    if last_eval_iter != total_iter:
        best_valid_acc = _validate_and_checkpoint(
            args, valid_loader, model, device, best_valid_acc)
    print('Best Valid Accuracy:{}'.format(best_valid_acc))
def testing(dataloader, model, device, valid):
    """Evaluate ``model`` on ``dataloader`` and return its accuracy.

    Args:
        dataloader: Iterable of batches. Each batch is a mapping with a
            ``'label'`` entry.
        model: Model to evaluate. It is put in eval mode and left in eval mode.
        device: Device the labels are moved to.
        valid: Only selects the progress-bar label ('validation' vs 'testing').

    Returns:
        ``metrics.match / metrics.n`` from the ``Accuracy`` tracker, or 0.0 if
        the dataloader yielded no samples (``metrics.n == 0``).
    """
    metrics = Accuracy()
    criterion = torch.nn.CrossEntropyLoss()
    trange = tqdm(enumerate(dataloader),
                  total=len(dataloader),
                  desc='validation' if valid else 'testing')
    model.eval()
    total_loss = 0
    metrics.reset()
    # Evaluation never back-propagates. Skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        for k, batch in trange:
            prob = run_iter(batch, model, device, training=False)
            answer = batch['label'].to(device)
            loss = criterion(prob, answer)
            total_loss += loss.item()
            metrics.update(prob, answer)
            trange.set_postfix(loss=total_loss / (k + 1),
                               **{metrics.name: metrics.print_score()})
    # Avoid ZeroDivisionError on an empty dataloader.
    if metrics.n == 0:
        return 0.0
    return metrics.match / metrics.n
Ejemplo n.º 3
0
        optimizer.zero_grad()
        logits = net(train_batch)
        train_loss = ce_loss(logits, train_batch['labels'])
        # output, train_loss = utils.run_step(train_batch, net, tokenizer, ce_loss, device)

        train_loss.backward()
        optimizer.step()

        running_loss_train += train_loss.item()

        _, predicted = torch.max(logits, 1)
        metric_acc.update_batch(predicted, train_batch['labels'])

        if i % PRINT_EVERY == 0:
            train_accuracy = metric_acc.get_metrics_summary()
            metric_acc.reset()

            print(
                f'Epoch: {epoch+1}, Step: {i}/{n_iteration}, Accuracy: {train_accuracy}, \
                    Runningloss: {running_loss_train/PRINT_EVERY}')
            running_loss_train = 0

    with torch.no_grad():
        for i, val_batch in enumerate(val_dataloader):
            logits = net(val_batch)
            val_loss = ce_loss(logits, val_batch['labels'])
            # logits, val_loss = utils.run_step(val_batch, net, tokenizer, ce_loss, device)

            running_loss_val += val_loss.item()

            _, predicted = torch.max(logits, 1)