def test_step(network, data_loader, device):
    """Run one evaluation pass of ``network`` over ``data_loader``.

    The network is switched to eval mode and every batch is processed
    under ``torch.no_grad()``. The network is left in eval mode when
    this function returns.

    Args:
        network: Model to evaluate; called as ``network(inputs)``.
        data_loader: Iterable yielding ``(inputs, targets)`` pairs.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        A tuple ``(top1_avg, top5_avg, loss_log)``: the running averages of
        the two values returned by ``accuracy(..., topk=(1, 5))`` and
        whatever ``LossCalculator.get_loss_log()`` reports for the pass.
    """
    network.eval()

    top1, top5 = AverageMeter(), AverageMeter()
    loss_calculator = LossCalculator()

    with torch.no_grad():
        for batch, labels in data_loader:
            batch = batch.to(device)
            labels = labels.to(device)

            logits = network(batch)
            # The returned loss is not needed here; the call is made so the
            # calculator records it for get_loss_log().
            loss_calculator.calc_loss(logits, labels)

            batch_size = batch.size(0)
            prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
            # Weight each batch by its size so a smaller batch does not
            # count as much as a full one in the average.
            for meter, prec in ((top1, prec1), (top5, prec5)):
                meter.update(prec.item(), batch_size)

    return top1.avg, top5.avg, loss_calculator.get_loss_log()
def train_step(network, train_data_loader, test_data_loader, optimizer, device, epoch):
    """Train ``network`` for one epoch, evaluate it, and print a summary line.

    Args:
        network: Model to train; called as ``network(inputs)``.
        train_data_loader: Iterable yielding ``(inputs, targets)`` training
            batches.
        test_data_loader: Iterable passed to ``test_step`` after the training
            loop finishes.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are invoked
            once per batch.
        device: Device that each batch is moved to before the forward pass.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        None. Results are reported via ``print``.

    Side effects:
        Sets ``torch.backends.cudnn.benchmark = True`` globally and leaves
        the network in whatever mode ``test_step`` puts it in.
    """
    network.train()

    # Set benchmark flag for faster runtime (lets cuDNN pick algorithms by
    # benchmarking them).
    torch.backends.cudnn.benchmark = True

    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_calculator = LossCalculator()

    prev_time = datetime.now()

    for inputs, targets in train_data_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = network(inputs)
        loss = loss_calculator.calc_loss(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

    # Elapsed time covers the training loop only; evaluation runs afterwards.
    # total_seconds() is used instead of .seconds because the latter discards
    # the timedelta's `days` component and would wrap around after 24 hours.
    cur_time = datetime.now()
    elapsed = int((cur_time - prev_time).total_seconds())
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time: {:0>2d}:{:0>2d}:{:0>2d}".format(h, m, s)

    train_acc_str = '[Train] Top1: %2.4f, Top5: %2.4f, ' % (top1.avg, top5.avg)
    train_loss_str = 'Loss: %.4f. ' % loss_calculator.get_loss_log()

    test_top1, test_top5, test_loss = test_step(network, test_data_loader, device)

    test_acc_str = '[Test] Top1: %2.4f, Top5: %2.4f, ' % (test_top1, test_top5)
    test_loss_str = 'Loss: %.4f. ' % test_loss

    print('Epoch %d. ' % epoch + train_acc_str + train_loss_str + test_acc_str + test_loss_str + time_str)

    return None