예제 #1
0
def train_epoch(model, device, train_loader, criterion, optimizer, k, warm_up,
                lr, writer, epoch, dataset=None, log_interval=10):
    """Run one training epoch and return its aggregated metrics.

    Args:
        model: Network to train; it is switched to training mode here.
        device: Device that each batch and its labels are moved to.
        train_loader: Iterable of ``(batch, labels)`` pairs. ``len()`` is
            used to compute the global iteration index.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer zeroed and stepped once per batch.
        k: Current warm-up counter. While ``k <= warm_up`` it is passed to
            ``learning_rate_scheduler`` and replaced by its return value.
        warm_up: Upper bound (inclusive) on ``k`` for running the warm-up
            scheduler.
        lr: Learning rate forwarded to ``learning_rate_scheduler``.
        writer: TensorBoard writer handed to
            ``Metrics.write_to_tensorboard``.
        epoch: Epoch index; combined with the batch index to form the
            global iteration number used for logging.
        dataset: Dataset identifier passed to ``Metrics``. When ``None``
            (the default) the module-level ``args.dataset`` is used, which
            is what this function always did before the parameter existed.
        log_interval: Batch metrics are written to TensorBoard whenever the
            global iteration is a multiple of this value (default 10, the
            previously hard-coded interval).

    Returns:
        Tuple ``(final_metrics, k)``: the result of
        ``Metrics.get_epoch_metrics()`` and the (possibly advanced)
        warm-up counter.
    """
    print("Training Progress:")
    if dataset is None:
        # Backward-compatible fallback for callers that rely on a global
        # ``args`` namespace.
        dataset = args.dataset
    metrics = Metrics(dataset, train=True)
    model.train()

    num_batches = len(train_loader)
    for batch_idx, (batch, labels) in enumerate(tqdm(train_loader)):
        iteration = epoch * num_batches + batch_idx
        optimizer.zero_grad()
        # Single move+cast instead of a CPU-typed cast followed by a move.
        batch = batch.to(device=device, dtype=torch.float32)
        labels = labels.to(device)

        outputs = model(batch)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Learning-rate warm-up phase.
        if k <= warm_up:
            k = learning_rate_scheduler(optimizer, k, warm_up, lr)

        # Detach before handing tensors to Metrics: if it keeps references
        # across batches (not visible here), the attached autograd graphs
        # would otherwise stay alive and grow memory every iteration.
        metrics.update_metrics(outputs.detach(), labels, loss.detach())
        if iteration % log_interval == 0:
            metrics.write_to_tensorboard(writer, iteration)

    # Aggregate over the whole epoch.
    final_metrics = metrics.get_epoch_metrics()

    return final_metrics, k
예제 #2
0
def validate_epoch(model, device, validation_loader, criterion, scheduler,
                   writer, epoch, dataset=None):
    """Evaluate the model for one epoch and step the LR scheduler.

    Everything runs under ``torch.no_grad()`` with the model in eval mode.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        device: Device that each batch and its labels are moved to.
        validation_loader: Iterable of ``(batch, labels)`` pairs.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        scheduler: Stepped once with the epoch's ``"Loss"`` metric. This
            call shape presumably targets a plateau-style scheduler such as
            ``ReduceLROnPlateau``; confirm against the caller.
        writer: TensorBoard writer handed to
            ``Metrics.write_to_tensorboard``.
        epoch: Epoch index used as the TensorBoard step.
        dataset: Dataset identifier passed to ``Metrics``. When ``None``
            (the default) the module-level ``args.dataset`` is used, which
            is what this function always did before the parameter existed.

    Returns:
        The result of ``Metrics.get_epoch_metrics()``. It is indexed with
        ``"Loss"``, so presumably a dict-like mapping.
    """
    with torch.no_grad():
        print("Validation Progress:")
        if dataset is None:
            # Backward-compatible fallback for callers that rely on a global
            # ``args`` namespace.
            dataset = args.dataset
        metrics = Metrics(dataset, train=False)
        model.eval()

        for batch, labels in tqdm(validation_loader):
            # Single move+cast instead of a CPU-typed cast followed by a move.
            batch = batch.to(device=device, dtype=torch.float32)
            labels = labels.to(device)

            outputs = model(batch)
            loss = criterion(outputs, labels)

            metrics.update_metrics(outputs, labels, loss)

        # Aggregate over the epoch, log once per epoch, then let the
        # scheduler react to the validation loss.
        final_metrics = metrics.get_epoch_metrics()
        metrics.write_to_tensorboard(writer, epoch)
        scheduler.step(final_metrics["Loss"])

    return final_metrics