예제 #1
0
def test_step(network, data_loader, device):
    """Evaluate ``network`` over every batch of ``data_loader`` without gradients.

    Puts the network into eval mode (this function does not switch it back)
    and accumulates top-1/top-5 accuracy and loss across the whole loader.

    Args:
        network: model to evaluate; called as ``network(inputs)``.
        data_loader: iterable yielding ``(inputs, targets)`` batches.
        device: target passed to ``.to(device)`` for each batch.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_log)``: the ``.avg`` of two
        ``AverageMeter`` instances (each updated with the batch size as the
        count) and the value of ``LossCalculator.get_loss_log()``
        (presumably a float — callers format it with ``%.4f``).
    """
    network.eval()

    meter_top1, meter_top5 = AverageMeter(), AverageMeter()
    criterion = LossCalculator()

    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            predictions = network(batch_inputs)
            # The loss is accumulated inside the calculator; the return value is not needed here.
            criterion.calc_loss(predictions, batch_targets)

            acc1, acc5 = accuracy(predictions, batch_targets, topk=(1, 5))
            n_samples = batch_inputs.size(0)
            meter_top1.update(acc1.item(), n_samples)
            meter_top5.update(acc5.item(), n_samples)

    return meter_top1.avg, meter_top5.avg, criterion.get_loss_log()
예제 #2
0
def train_step(network, train_data_loader, test_data_loader, optimizer, device,
               epoch):
    """Train ``network`` for one epoch, evaluate it, and print a summary line.

    Runs one optimization pass over ``train_data_loader`` while tracking
    top-1/top-5 accuracy and loss. It then calls ``test_step`` on
    ``test_data_loader`` and prints one line with train metrics, test metrics
    and the wall-clock duration of the training pass. Evaluation time is not
    included in that duration.

    ``test_step`` switches the network to eval mode and this function does not
    switch it back. The next call to ``train_step`` re-enables training mode at
    its start.

    Args:
        network: model to train; called as ``network(inputs)``.
        train_data_loader: iterable yielding ``(inputs, targets)`` training
            batches.
        test_data_loader: iterable yielding ``(inputs, targets)`` evaluation
            batches, forwarded to ``test_step``.
        optimizer: object whose ``zero_grad()``/``step()`` are called once
            per training batch.
        device: target passed to ``.to(device)`` for each batch.
        epoch: epoch number; used only in the printed summary.

    Returns:
        None. Results are reported via ``print``.
    """
    network.train()
    # Let cuDNN benchmark convolution algorithms for faster runtime. Note this
    # is a process-wide setting, not scoped to this call.
    torch.backends.cudnn.benchmark = True

    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_calculator = LossCalculator()

    prev_time = datetime.now()

    for inputs, targets in train_data_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = network(inputs)
        loss = loss_calculator.calc_loss(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
        batch_size = inputs.size(0)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

    # total_seconds() is used because timedelta.seconds drops the days
    # component, which would wrap epochs longer than 24h. The value is clamped
    # at 0 so a backwards wall-clock adjustment cannot yield a negative time.
    elapsed = max(0, int((datetime.now() - prev_time).total_seconds()))
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time: {:0>2d}:{:0>2d}:{:0>2d}".format(h, m, s)

    train_acc_str = '[Train] Top1: %2.4f, Top5: %2.4f, ' % (top1.avg, top5.avg)
    train_loss_str = 'Loss: %.4f. ' % loss_calculator.get_loss_log()

    test_top1, test_top5, test_loss = test_step(network, test_data_loader,
                                                device)

    test_acc_str = '[Test] Top1: %2.4f, Top5: %2.4f, ' % (test_top1, test_top5)
    test_loss_str = 'Loss: %.4f. ' % test_loss

    print('Epoch %d. ' % epoch + train_acc_str + train_loss_str +
          test_acc_str + test_loss_str + time_str)

    return None