def train(args, train_loader, model, criterion, optimizer, epoch, progress, train_time):
    """Train ``model`` for one epoch and record timing/accuracy statistics.

    Args:
        args: Options object; ``args.log_interval`` is the number of batches
            between log lines.
        train_loader: Iterable of ``(data, target)`` batches exposing
            ``.dataset`` and ``len()`` (used only for the log line).
        model: Module to train; switched to training mode here.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer zeroed and stepped once per batch.
        epoch: Epoch number, used in the log line and the progress record.
        progress: Dict whose ``'train'`` list receives one tuple per call:
            ``(epoch, last_batch_loss, train_acc, batch_time_sum,
            batch_time_avg, data_time_sum, data_time_avg)``.
        train_time: Meter updated once with this epoch's summed batch time.

    Batches are moved with a module-level ``device`` defined elsewhere.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()  # sets the module in training mode
    correct = 0
    # Samples actually processed; can differ from len(dataset) with
    # drop_last or a sampler, so accuracy is computed against this count.
    seen = 0
    # Stays NaN if the loader yields no batches (previously a NameError).
    last_loss = float('nan')
    end = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()  # zeroes the gradient buffers of all parameters
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        pred = output.detach().argmax(dim=1)
        correct += pred.eq(target).sum().item()
        seen += target.size(0)
        last_loss = loss.item()

        # Measure elapsed time (includes the data-loading time above)
        batch_time.update(time.time() - end)
        end = time.time()

        # Print log; batch_idx + 1 batches (``seen`` samples) are done here.
        if (batch_idx + 1) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), last_loss))
    train_time.update(batch_time.get_sum())

    # Save progress (loss recorded is that of the final batch).
    train_acc = 100. * correct / seen if seen else 0.0
    progress['train'].append(
        (epoch, last_loss, train_acc, batch_time.get_sum(), batch_time.get_avg(),
         data_time.get_sum(), data_time.get_avg()))
# Train and record progress
progress = {'train': [], 'test': []}
train_time = AverageMeter()  # updated by train() once per epoch with that epoch's total batch time
test_time = AverageMeter()
print('==> Start training..')
for epoch in range(start_epoch, start_epoch + args.epochs):
    adjust_learning_rate(optimizer, lr, epoch, milestones)
    train(args, train_loader, model, criterion, optimizer, epoch, progress, train_time)
    # NOTE(review): best_acc is passed by value; unless test() updates a
    # module-level best_acc itself, improvements are not carried across epochs — confirm.
    test(args, test_loader, model, criterion, epoch, progress, best_acc, test_time)

# (average epoch time, total training time)
progress['train_time'] = (train_time.get_avg(), train_time.get_sum())
# (average test-pass time / number of test images, average test-pass time);
# the first entry is a per-image time assuming test() records one full-pass
# duration per call — TODO confirm.
progress['test_time'] = (test_time.get_avg() / len(test_loader.dataset), test_time.get_avg())

# Save progress
import pickle
current_time = get_current_time()
progress_path = ('./' + args.model + ('-resume' if args.resume else '')
                 + '_progress_' + current_time + '.pkl')
# Use a context manager so the file is flushed and closed even on error.
with open(progress_path, 'wb') as progress_file:
    pickle.dump(progress, progress_file)