shuffle=False) model = Net(upscale_factor=UPSCALE_FACTOR) criterion = AdjacentFrameLoss() if torch.cuda.is_available(): model = model.cuda() criterion = criterion.cuda() print('# parameters:', sum(param.numel() for param in model.parameters())) optimizer = optim.Adam(model.parameters(), lr=1e-2) scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) engine = Engine() meter_loss = tnt.meter.AverageValueMeter() meter_psnr = PSNRMeter() train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'}) train_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Train PSNR'}) val_loss_logger = VisdomPlotLogger('line', opts={'title': 'Val Loss'}) val_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Val PSNR'}) engine.hooks['on_sample'] = on_sample engine.hooks['on_forward'] = on_forward engine.hooks['on_start_epoch'] = on_start_epoch engine.hooks['on_end_epoch'] = on_end_epoch engine.train(processor, train_loader, maxepoch=NUM_EPOCHS, optimizer=optimizer)
def main(factor, plot_dir=r'D:\大三上\数字图像处理\SR_Project\plots'):
    """Train an SPCNNet super-resolution model and save its PSNR/loss curves.

    The upscale factor comes from ``--upscale_factor`` on the command line
    (default 3). If ``factor`` is anything other than 3, ``factor`` takes
    precedence over the CLI value. Training state is published through
    module-level globals because the engine hooks registered below
    (``on_sample``, ``on_forward``, ``on_start_epoch``, ``on_end_epoch``,
    defined elsewhere in this file) are plain functions with no other way
    to reach it.

    Args:
        factor: Upscale-factor override; the value 3 means "use the CLI value".
        plot_dir: Directory that receives ``PSNRx<factor>.png`` and
            ``LOSSx<factor>.png``. It is created if it does not exist.
            Defaults to the original hard-coded location.
    """
    import os

    global meter_loss, meter_psnr, scheduler, engine
    global epoch_num, psnr_value, loss_value
    global train_loader, val_loader, model, criterion
    global UPSCALE_FACTOR

    parser = argparse.ArgumentParser(description='Super Resolution Training')
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')
    parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs
    # A caller-supplied factor other than the default (3) overrides the CLI value.
    if factor != 3:
        UPSCALE_FACTOR = factor

    train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR,
                                  input_transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor())
    val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR,
                                input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=64, shuffle=False)

    model = SPCNNet(upscale_factor=UPSCALE_FACTOR)
    criterion = nn.MSELoss()
    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()

    print('# upscale factor:', UPSCALE_FACTOR)
    print('# parameters:', sum(param.numel() for param in model.parameters()))

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Multiplies the LR by 0.1 at scheduler steps 30 and 80. This presumably
    # means epochs 30 and 80, assuming a hook calls scheduler.step() once
    # per epoch (the hooks are not visible here).
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    meter_psnr = PSNRMeter()
    # Per-epoch history plotted below; presumably filled by on_end_epoch.
    epoch_num = []
    psnr_value = []
    loss_value = []

    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.train(processor, train_loader, maxepoch=NUM_EPOCHS, optimizer=optimizer)

    # Make sure the output directory exists so a long training run is not
    # lost to a FileNotFoundError at save time.
    os.makedirs(plot_dir, exist_ok=True)
    _save_curve(epoch_num, psnr_value,
                label="PSNR--x" + str(UPSCALE_FACTOR),
                ylabel="PSNR value",
                path=os.path.join(plot_dir, 'PSNRx' + str(UPSCALE_FACTOR) + '.png'))
    _save_curve(epoch_num, loss_value,
                label="Loss--x" + str(UPSCALE_FACTOR),
                ylabel="Loss value",
                path=os.path.join(plot_dir, 'LOSSx' + str(UPSCALE_FACTOR) + '.png'))


def _save_curve(x, y, label, ylabel, path):
    """Plot ``y`` against ``x`` on a fresh figure, save it to ``path``, show it, then close it.

    A new figure is created for each curve. Without this, on non-interactive
    backends ``plt.show()`` keeps the current figure, and a second call
    would draw on top of the first plot.
    """
    fig = plt.figure()
    plt.plot(x, y, lw=2, ls='-', label=label, color="r", marker="+")
    plt.xlabel("epoch time(s)", fontsize=16, horizontalalignment="right")
    plt.ylabel(ylabel, fontsize=16, horizontalalignment="right")
    plt.legend()
    plt.savefig(path)
    plt.show()
    plt.close(fig)