def train(dataloader, validloader, net, nepoch=10):
    start_epoch = 0
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    useGPU = torch.cuda.is_available() and not opt.cpu
    if useGPU:
        loss_function.cuda()

    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            # strip the DataParallel 'module.' prefix before loading
            checkpoint['state_dict'] = remove_dataparallel_wrapper(
                checkpoint['state_dict'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

    if len(opt.scheduler) > 0:
        # opt.scheduler is a 'stepsize,gamma' string
        stepsize, gamma = int(opt.scheduler.split(',')[0]), float(
            opt.scheduler.split(',')[1])
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, stepsize, gamma=gamma, last_epoch=start_epoch - 1)

    count = 0
    opt.t0 = time.perf_counter()

    for epoch in range(start_epoch, nepoch):
        mean_loss = 0

        for i, bat in enumerate(dataloader):
            lr, hr = bat[0], bat[1]
            optimizer.zero_grad()
            if useGPU:
                sr = net(lr.cuda())
                loss = loss_function(sr, hr.cuda())
            else:
                sr = net(lr)
                loss = loss_function(sr, hr)
            loss.backward()
            optimizer.step()

            ######### Status and display #########
            mean_loss += loss.data.item()
            print('\r[%d/%d][%d/%d] Loss: %0.6f' % (epoch + 1, nepoch, i + 1,
                                                    len(dataloader),
                                                    loss.data.item()), end='')

            count += 1
            # log running statistics roughly every 1000 samples
            if opt.log and count * opt.batchSize // 1000 > 0:
                t1 = time.perf_counter() - opt.t0
                mem = torch.cuda.memory_allocated()
                print(epoch, count * opt.batchSize, t1, mem,
                      mean_loss / count, file=opt.train_stats)
                opt.train_stats.flush()
                count = 0

        # ---------------- Scheduler -----------------
        if len(opt.scheduler) > 0:
            scheduler.step()
            for param_group in optimizer.param_groups:
                print('\nLearning rate', param_group['lr'])
                break

        # ---------------- Printing -----------------
        print('\nEpoch %d done, %0.6f' % (epoch, mean_loss / len(dataloader)))
        print('\nEpoch %d done, %0.6f' % (epoch, mean_loss / len(dataloader)),
              file=opt.fid)
        opt.fid.flush()

        if opt.log:
            opt.writer.add_scalar('data/mean_loss',
                                  mean_loss / len(dataloader), epoch)

        # ---------------- TEST -----------------
        if (epoch + 1) % opt.testinterval == 0:
            testAndMakeCombinedPlots(net, validloader, opt, epoch)
            # if opt.scheduler:
            #     scheduler.step(mean_loss / len(dataloader))

        if (epoch + 1) % opt.saveinterval == 0:
            # torch.save(net.state_dict(), opt.out + '/prelim.pth')
            checkpoint = {'epoch': epoch + 1,
                          'state_dict': net.state_dict(),
                          'optimizer': optimizer.state_dict()}
            torch.save(checkpoint, opt.out + '/prelim.pth')

    checkpoint = {'epoch': nepoch,
                  'state_dict': net.state_dict(),
                  'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, opt.out + '/final.pth')
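
# remove_dataparallel_wrapper is referenced above but not defined in this
# listing. A minimal sketch, assuming its job is to strip the 'module.'
# prefix that nn.DataParallel prepends to every parameter name; the actual
# project implementation may differ.
def remove_dataparallel_wrapper(state_dict):
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # drop the leading 'module.' if present, keep the key unchanged otherwise
        new_state_dict[k[7:] if k.startswith('module.') else k] = v
    return new_state_dict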
dataloader, validloader = GetDataloaders(opt)
net = GetModel(opt)

if opt.log:
    opt.writer = SummaryWriter(comment='_%s_%s' % (
        opt.out.replace('\\', '/').split('/')[-1], opt.model))
    opt.train_stats = open(opt.out.replace('\\', '/') + '/train_stats.csv', 'w')
    opt.test_stats = open(opt.out.replace('\\', '/') + '/test_stats.csv', 'w')
    print('iter,nsample,time,memory,meanloss', file=opt.train_stats)
    print('iter,time,memory,psnr,ssim', file=opt.test_stats)

import time
t0 = time.perf_counter()

if not opt.test:
    train(dataloader, validloader, net, nepoch=opt.nepoch)
else:
    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            checkpoint['state_dict'] = remove_dataparallel_wrapper(
                checkpoint['state_dict'])
        net.load_state_dict(checkpoint['state_dict'])
        print('time: ', time.perf_counter() - t0)
    testAndMakeCombinedPlots(net, validloader, opt)

print('time: ', time.perf_counter() - t0)
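
# Usage note (an assumption inferred from the calls above, since the factory
# definitions are not part of this listing): GetDataloaders(opt) is expected
# to return a (train, validation) pair of torch.utils.data.DataLoader
# objects, and GetModel(opt) an nn.Module selected by opt.model.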
def main(opt):
    opt.device = torch.device(
        'cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')

    os.makedirs(opt.out, exist_ok=True)
    shutil.copy2('options.py', opt.out)  # keep a record of the run options

    opt.fid = open(opt.out + '/log.txt', 'w')
    ostr = 'ARGS: ' + ' '.join(sys.argv[:])
    print(opt, '\n')
    print(opt, '\n', file=opt.fid)
    print('\n%s\n' % ostr)
    print('\n%s\n' % ostr, file=opt.fid)

    print('getting dataloader', opt.root)
    dataloader, validloader = GetDataloaders(opt)

    if opt.log:
        opt.writer = SummaryWriter(log_dir=opt.out, comment='_%s_%s' % (
            opt.out.replace('\\', '/').split('/')[-1], opt.model))
        opt.train_stats = open(opt.out.replace('\\', '/') + '/train_stats.csv', 'w')
        opt.test_stats = open(opt.out.replace('\\', '/') + '/test_stats.csv', 'w')
        print('iter,nsample,time,memory,meanloss', file=opt.train_stats)
        print('iter,time,memory,psnr,ssim', file=opt.test_stats)

    t0 = time.perf_counter()
    net = GetModel(opt)
    if not opt.test:
        train(opt, dataloader, validloader, net)
        # torch.save(net.state_dict(), opt.out + '/final.pth')
    else:
        if len(opt.weights) > 0:  # load previous weights?
            checkpoint = torch.load(opt.weights)
            print('loading checkpoint', opt.weights)
            net.load_state_dict(checkpoint['state_dict'])
            print('time: %0.1f' % (time.perf_counter() - t0))
        testAndMakeCombinedPlots(net, validloader, opt)

    opt.fid.close()
    if not opt.test:
        generate_convergence_plots(opt, opt.out + '/log.txt')

    print('time: %0.1f' % (time.perf_counter() - t0))

    # optional clean up
    if opt.disposableTrainingData and not opt.test:
        print('deleting training data')
        # preserve a few samples
        os.makedirs('%s/training_data_subset' % opt.out, exist_ok=True)
        samplecount = 0
        for file in glob.glob('%s/*' % opt.root):
            if os.path.isfile(file):
                basename = os.path.basename(file)
                shutil.copy2(file, '%s/training_data_subset/%s' % (opt.out, basename))
                samplecount += 1
                if samplecount == 10:
                    break
        shutil.rmtree(opt.root)
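
# A minimal sketch of how main() might be invoked. The argparse parser in
# options.py is an assumption (the file is copied above but its contents are
# not shown here); adapt the import to the project's actual layout.
if __name__ == '__main__':
    from options import parser  # hypothetical: options.py exposing an argparse parser
    opt = parser.parse_args()
    main(opt)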
def train(opt, dataloader, validloader, net):
    start_epoch = 0
    if opt.task == 'segment' or opt.task == 'classification':
        loss_function = nn.CrossEntropyLoss()
    else:
        loss_function = nn.MSELoss()

    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    # move the loss to whichever device was selected (an unconditional
    # .cuda() would fail on CPU-only machines)
    loss_function.to(opt.device)

    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        net.load_state_dict(checkpoint['state_dict'])
        if opt.lr == 1:  # lr == 1 acts as a sentinel: continue as it was
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']

    if len(opt.scheduler) > 0:
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer, mode='min', factor=0.1, patience=5, verbose=True,
        #     threshold=0.0001, threshold_mode='rel', cooldown=5, min_lr=0,
        #     eps=1e-08)
        stepsize, gamma = int(opt.scheduler.split(',')[0]), float(
            opt.scheduler.split(',')[1])
        scheduler = optim.lr_scheduler.StepLR(optimizer, stepsize, gamma=gamma)
        if len(opt.weights) > 0 and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])

    opt.t0 = time.perf_counter()

    for epoch in range(start_epoch, opt.nepoch):
        count = 0
        mean_loss = 0
        # for param_group in optimizer.param_groups:
        #     print('\nLearning rate', param_group['lr'])

        for i, bat in enumerate(dataloader):
            lr, hr = bat[0], bat[1]
            optimizer.zero_grad()
            sr = net(lr.to(opt.device))
            loss = loss_function(sr, hr.to(opt.device))
            loss.backward()
            optimizer.step()

            ######### Status and display #########
            mean_loss += loss.data.item()
            print('\r[%d/%d][%d/%d] Loss: %0.6f' % (epoch + 1, opt.nepoch,
                                                    i + 1, len(dataloader),
                                                    loss.data.item()), end='')

            count += 1
            # log running statistics roughly every 1000 samples
            if opt.log and count * opt.batchSize // 1000 > 0:
                t1 = time.perf_counter() - opt.t0
                mem = torch.cuda.memory_allocated()
                opt.writer.add_scalar('data/mean_loss_per_1000',
                                      mean_loss / count, epoch)
                opt.writer.add_scalar('data/time_per_1000', t1, epoch)
                print(epoch, count * opt.batchSize, t1, mem,
                      mean_loss / count, file=opt.train_stats)
                opt.train_stats.flush()
                count = 0

        # ---------------- Scheduler -----------------
        if len(opt.scheduler) > 0:
            scheduler.step()
            for param_group in optimizer.param_groups:
                print('\nLearning rate', param_group['lr'])
                break

        # ---------------- Printing -----------------
        mean_loss = mean_loss / len(dataloader)
        t1 = time.perf_counter() - opt.t0
        eta = (opt.nepoch - (epoch + 1)) * t1 / (epoch + 1)
        ostr = '\nEpoch [%d/%d] done, mean loss: %0.6f, time spent: %0.1fs, ETA: %0.1fs' % (
            epoch + 1, opt.nepoch, mean_loss, t1, eta)
        print(ostr)
        print(ostr, file=opt.fid)
        opt.fid.flush()

        if opt.log:
            # mean_loss has already been averaged over the epoch above, so it
            # must not be divided by len(dataloader) a second time here
            opt.writer.add_scalar('data/mean_loss', mean_loss, epoch)

        # ---------------- TEST -----------------
        if (epoch + 1) % opt.testinterval == 0:
            testAndMakeCombinedPlots(net, validloader, opt, epoch)

        if (epoch + 1) % opt.saveinterval == 0:
            # torch.save(net.state_dict(), opt.out + '/prelim.pth')
            checkpoint = {'epoch': epoch + 1,
                          'state_dict': net.state_dict(),
                          'optimizer': optimizer.state_dict()}
            if len(opt.scheduler) > 0:
                checkpoint['scheduler'] = scheduler.state_dict()
            torch.save(checkpoint, '%s/prelim%d.pth' % (opt.out, epoch + 1))

    checkpoint = {'epoch': opt.nepoch,
                  'state_dict': net.state_dict(),
                  'optimizer': optimizer.state_dict()}
    if len(opt.scheduler) > 0:
        checkpoint['scheduler'] = scheduler.state_dict()
    torch.save(checkpoint, opt.out + '/final.pth')
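
# Example of the scheduler option consumed above (the values are
# illustrative only): opt.scheduler = '20,0.5' halves the learning rate
# every 20 epochs via StepLR(optimizer, 20, gamma=0.5), while an empty
# string disables the scheduler entirely.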
def train(dataloader, validloader, net, nepoch=10):
    start_epoch = 0
    loss_function = nn.CrossEntropyLoss()  # used for the segmentation task
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    loss_function.cuda()

    loss_function_custom = nn.MSELoss()  # used for the regression tasks
    loss_function_custom.cuda()

    if len(opt.weights) > 0:  # load previous weights?
        checkpoint = torch.load(opt.weights)
        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            checkpoint['state_dict'] = remove_dataparallel_wrapper(
                checkpoint['state_dict'])
        if opt.modifyPretrainedModel:
            pretrained_dict = checkpoint['state_dict']
            model_dict = net.state_dict()
            for k in pretrained_dict:
                print(k)
            # 1. filter out unnecessary keys: drop the last two entries,
            #    i.e. the pretrained output layer's weight and bias
            pretrained_dict = {k: v for k, v in list(pretrained_dict.items())[:-2]}
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # 3. load the new state dict
            net.load_state_dict(model_dict)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
        else:
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']

    if len(opt.scheduler) > 0:
        stepsize, gamma = int(opt.scheduler.split(',')[0]), float(
            opt.scheduler.split(',')[1])
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, stepsize, gamma=gamma, last_epoch=start_epoch - 1)

    count = 0
    opt.t0 = time.perf_counter()

    for epoch in range(start_epoch, nepoch):
        mean_loss = 0

        for i, bat in enumerate(dataloader):
            lr, hr = bat[0], bat[1]
            optimizer.zero_grad()

            if opt.model == 'ffdnet':
                # FFDNet also takes a per-image noise level; estimate it as
                # the standard deviation of the input/target residual
                stdvec = torch.zeros(lr.shape[0])
                for j in range(lr.shape[0]):
                    noise = lr[j] - hr[j]
                    stdvec[j] = torch.std(noise)
                noise = net(lr.cuda(), stdvec.cuda())
                sr = torch.clamp(lr.cuda() - noise, 0, 1)
                gt_noise = lr.cuda() - hr.cuda()
                # MSE on the predicted noise (CrossEntropyLoss expects class
                # indices and would reject these float targets)
                loss = loss_function_custom(noise, gt_noise)
            elif opt.task == 'residualdenoising':
                # the network predicts the residual (noise) rather than the image
                noise = net(lr.cuda())
                gt_noise = lr.cuda() - hr.cuda()
                loss = loss_function_custom(noise, gt_noise)
            else:
                sr = net(lr.cuda())
                if opt.task == 'segment':
                    hr_classes = torch.round(2 * hr).long()
                    loss = loss_function(sr.squeeze(),
                                         hr_classes.squeeze().cuda())
                else:
                    loss = loss_function_custom(sr, hr.cuda())

            loss.backward()
            optimizer.step()

            ######### Status and display #########
            mean_loss += loss.data.item()
            print('\r[%d/%d][%d/%d] Loss: %0.6f' % (epoch + 1, nepoch, i + 1,
                                                    len(dataloader),
                                                    loss.data.item()), end='')

            count += 1
            if opt.log and count * opt.batchSize // 1000 > 0:
                t1 = time.perf_counter() - opt.t0
                mem = torch.cuda.memory_allocated()
                print(epoch, count * opt.batchSize, t1, mem,
                      mean_loss / count, file=opt.train_stats)
                opt.train_stats.flush()
                count = 0

        # ---------------- Scheduler -----------------
        if len(opt.scheduler) > 0:
            scheduler.step()
            for param_group in optimizer.param_groups:
                print('\nLearning rate', param_group['lr'])
                break

        # ---------------- Printing -----------------
        print('\nEpoch %d done, %0.6f' % (epoch, mean_loss / len(dataloader)))
        print('\nEpoch %d done, %0.6f' % (epoch, mean_loss / len(dataloader)),
              file=opt.fid)
        opt.fid.flush()

        # ---------------- TEST -----------------
        if (epoch + 1) % opt.testinterval == 0:
            testAndMakeCombinedPlots(net, validloader, opt, epoch)
            # if opt.scheduler:
            #     scheduler.step(mean_loss / len(dataloader))

        if (epoch + 1) % opt.saveinterval == 0:
            # torch.save(net.state_dict(), opt.out + '/prelim.pth')
            checkpoint = {'epoch': epoch + 1,
                          'state_dict': net.state_dict(),
                          'optimizer': optimizer.state_dict()}
            torch.save(checkpoint, opt.out + '/prelim.pth')

    checkpoint = {'epoch': nepoch,
                  'state_dict': net.state_dict(),
                  'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, opt.out + '/final.pth')
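
# A minimal sketch (not part of the original listing) of restoring the final
# checkpoint for inference; it relies only on the dict layout saved above,
# and load_final_checkpoint is a hypothetical helper name.
def load_final_checkpoint(net, path):
    # path would typically be opt.out + '/final.pth'
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()  # 'epoch' and 'optimizer' entries are not needed at test time
    return net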