def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    writer = SummaryWriter(comment="_equal_CZAR")
    # ------------- Part for tensorboard --------------

    # -------------- Some preparation ------------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)
    # torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # -------------- Some preparation ------------------

    BATCH_SIZE = batch_size
    EPOCH = epoch
    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = mydataset(train_set, transform_train)
    valset = mydataset(valid_set)
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=False)

    opter = Opter(128, 128, batch_size)
    SepConvNet = Network(opter).cuda()
    SepConvNet.apply(weights_init)
    # SepConvNet.load_state_dict(torch.load('/mnt/hdd/xiasifeng/sepconv/sepconv_mutiscale_LD/SepConv_iter33-ltype_fSATD_fs-lr_0.001-trainloss_0.1497-evalloss_0.1357-evalpsnr_29.6497.pkl'))

    # SepConvNet_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer, factor=0.1, patience=3, verbose=True, min_lr=1e-5)

    # ---------------- Time part -------------------
    start_time = time.time()
    global_step = 0
    # ---------------- Time part -------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0   # sumloss accumulates over the whole training set
        tsumloss = 0.0  # tsumloss accumulates over one print interval
        printinterval = 300
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for imgL, imgR, label in trainLoader:
            global_step = global_step + 1
            cnt = cnt + 1
            SepConvNet_optimizer.zero_grad()
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            with torch.no_grad():
                # Note: the backward (R-to-L) flow is needed here.
                diff = opter.calcOpt(imgR, imgL)
            warped, output = SepConvNet(diff, imgL, imgR)
            loss_out = SepConvNet_cost(output, label)
            loss_warp = SepConvNet_cost(warped, label)
            loss = 0.5 * loss_out + 0.5 * loss_warp
            loss.backward()
            SepConvNet_optimizer.step()
            sumloss = sumloss + loss_out.data.item()
            tsumloss = tsumloss + loss_out.data.item()
            if cnt % printinterval == 0:
                writer.add_image("Target image", label[0], cnt)
                writer.add_image("Warped image", warped[0], cnt)
                writer.add_image("Final image", output[0], cnt)
                writer.add_scalar('Train Batch SATD loss', loss_out.data.item(),
                                  int(global_step / printinterval))
                writer.add_scalar('Train Interval SATD loss', tsumloss / printinterval,
                                  int(global_step / printinterval))
                print('Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]' %
                      (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss_out.data.item(), tsumloss / printinterval))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for imgL, imgR, label in valLoader:
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            with torch.no_grad():
                # Note: the backward (R-to-L) flow is needed here.
                diff = opter.calcOpt(imgR, imgL)
                warped, output = SepConvNet(diff, imgL, imgR)
                loss_out = SepConvNet_cost(output, label)
                loss_warp = SepConvNet_cost(warped, label)
                loss = 0.5 * loss_out + 0.5 * loss_warp
            sumloss = sumloss + loss_out.data.item()
            psnr = psnr + calcPSNR.calcPSNR(output.cpu().data.numpy(), label.cpu().data.numpy())
            evalcnt = evalcnt + 1
        # ------------- Tensorboard part -------------
        writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        writer.add_scalar("Valid PSNR", psnr / len(valset), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f], Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / len(valset)))
        SepConvNet_schedule.step(psnr / len(valset))
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.',
                'equal_CZAR_iter' + str(epoch + 1)
                + '-ltype_fSATD_fs'
                + '-lr_' + str(LEARNING_RATE)
                + '-trainloss_' + str(round(trainloss, 4))
                + '-evalloss_' + str(round(sumloss / evalcnt, 4))
                + '-evalpsnr_' + str(round(psnr / len(valset), 4)) + '.pkl'))
    writer.close()
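# All of these scripts score validation frames with calcPSNR.calcPSNR, whose
# module is not part of this excerpt. A minimal sketch of what that helper
# could look like, assuming outputs and labels are float arrays scaled to
# [0, 1]; this would live in its own calcPSNR.py so it does not shadow the
# module import used above (the peak value and averaging are assumptions):
import numpy as np

def calcPSNR(pred, target, peak=1.0):
    # Mean squared error over every pixel and channel in the batch.
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(peak ** 2 / mse)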
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    writer = SummaryWriter(log_dir='tb/ft1_baseline_mask')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch
    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = mydataset(train_set, transform_train)
    valset = mydataset(valid_set)
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=False)

    SepConvNet = Network().cuda()
    # SepConvNet.apply(weights_init)
    SepConvNet.load_state_dict(
        torch.load(
            '/mnt/hdd/iku/ISCAS/train/mask_baseline_iter52-ltype_fSATD_fs-lr_0.001-trainloss_0.1279-evalloss_0.1181-evalpsnr_29.6526.pkl'
        ))

    # SepConvNet_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer, factor=0.1, patience=3, verbose=True, min_lr=1e-7)

    # ---------------- Time part -------------------
    start_time = time.time()
    global_step = 0
    # ---------------- Time part -------------------

    # ---------------- Opt part -----------------------
    opter = Opter(gpu)
    # -------------------------------------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0   # sumloss accumulates over the whole training set
        tsumloss = 0.0  # tsumloss accumulates over one print interval
        printinterval = 300
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for imgL, imgR, label in trainLoader:
            global_step = global_step + 1
            cnt = cnt + 1
            SepConvNet_optimizer.zero_grad()
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            output = SepConvNet(imgL, imgR)
            loss = SepConvNet_cost(output, label)
            loss.backward()
            SepConvNet_optimizer.step()
            sumloss = sumloss + loss.data.item()
            tsumloss = tsumloss + loss.data.item()
            if cnt % printinterval == 0:
                writer.add_image("Ref image", imgR[0], cnt)
                writer.add_image("Pred image", output[0], cnt)
                writer.add_image("Target image", label[0], cnt)
                writer.add_scalar('Train Batch SATD loss', loss.data.item(),
                                  int(global_step / printinterval))
                writer.add_scalar('Train Interval SATD loss', tsumloss / printinterval,
                                  int(global_step / printinterval))
                print('Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]' %
                      (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss.data.item(), tsumloss / printinterval))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for imgL, imgR, label in valLoader:
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            with torch.no_grad():
                output = SepConvNet(imgL, imgR)
                loss = SepConvNet_cost(output, label)
            sumloss = sumloss + loss.data.item()
            psnr = psnr + calcPSNR.calcPSNR(output.cpu().data.numpy(), label.cpu().data.numpy())
            evalcnt = evalcnt + 1
        # ------------- Tensorboard part -------------
        writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        writer.add_scalar("Valid PSNR", psnr / len(valset), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f], Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / len(valset)))
        SepConvNet_schedule.step(psnr / len(valset))
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.',
                'ft1_mask_baseline_iter' + str(epoch + 1)
                + '-ltype_fSATD_fs'
                + '-lr_' + str(LEARNING_RATE)
                + '-trainloss_' + str(round(trainloss, 4))
                + '-evalloss_' + str(round(sumloss / evalcnt, 4))
                + '-evalpsnr_' + str(round(psnr / len(valset), 4)) + '.pkl'))
    writer.close()
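# No command-line entry point appears in this excerpt. A minimal sketch of
# how these main(lr, batch_size, epoch, gpu, train_set, valid_set) functions
# could be driven with argparse; every default below is a placeholder
# assumption, not a value taken from the source:
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='SepConvNet training')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--train_set', type=str, required=True)
    parser.add_argument('--valid_set', type=str, required=True)
    args = parser.parse_args()
    main(args.lr, args.batch_size, args.epoch, args.gpu, args.train_set, args.valid_set)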
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    # writer = SummaryWriter(log_dir='tb/LSTM_ft1')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch
    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = vimeodataset(train_set, 'filelist.txt', transform_train)
    valset = vimeodataset(valid_set, 'test.txt')
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)
    assert len(valset) % BATCH_SIZE == 0

    SepConvNet = Network().cuda()
    # SepConvNet.apply(weights_init)
    SepConvNet.load_my_state_dict(
        torch.load(
            'tail_LSTM_iter15-ltype_fSATD_fs-lr_0.001-trainloss_0.6045-evalloss_0.1127-evalpsnr_30.2671.pkl',
            map_location='cuda:%d' % (gpu)))
    # SepConvNet.load_state_dict(torch.load('beta_LSTM_iter8-ltype_fSATD_fs-lr_0.001-trainloss_0.557-evalloss_0.1165-evalpsnr_29.8361.pkl'))

    # @@@ Test result: children 0-27 form the raw model.
    # Freeze every child except the ones listed in grad_list.
    grad_list = [18, 19, 20, 21, 25, 26, 27]
    for child_cnt, child in enumerate(SepConvNet.children()):
        if child_cnt in grad_list:
            continue
        for param in child.parameters():
            param.requires_grad = False
    # cs = list(SepConvNet.children())
    # ps = list(cs[17].parameters())
    # IPython.embed()
    # exit()

    # SepConvNet_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    # SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_optimizer = optim.Adamax(filter(lambda p: p.requires_grad, SepConvNet.parameters()),
                                        lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer, factor=0.1, patience=3, verbose=True)

    # ---------------- Time part -------------------
    start_time = time.time()
    global_step = 0
    # ---------------- Time part -------------------

    # ---------------- Opt part -----------------------
    # opter = Opter(gpu)
    # -------------------------------------------------

    # print('[!] Ready to train!')
    # IPython.embed()

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0    # forward-prediction loss over the whole training set
        tsumloss = 0.0   # forward-prediction loss over one print interval
        sumloss_b = 0.0  # backward-prediction loss over the whole training set
        tsumloss_b = 0.0 # backward-prediction loss over one print interval
        printinterval = 500
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for label_list in trainLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            # IPython.embed()
            # exit()
            global_step = global_step + 1
            cnt = cnt + 1
            for i in range(5):
                imgL = var(bad_list[i]).cuda()
                imgR = var(bad_list[i + 1]).cuda()
                poor_label = var(bad_list[i + 2]).cuda()
                label = var(label_list[i + 2]).cuda()
                label_L = var(label_list[i]).cuda()
                # ----------- Forward prediction -----------
                SepConvNet_optimizer.zero_grad()
                if i == 0:
                    output_f, stat = SepConvNet(imgL, imgR, 0)
                else:
                    output_f, stat = SepConvNet(imgL, imgR, 0, res_c, stat)
                loss = SepConvNet_cost(output_f, label)
                loss.backward(retain_graph=True)
                SepConvNet_optimizer.step()
                sumloss = sumloss + loss.data.item()
                tsumloss = tsumloss + loss.data.item()
                # ----------- Backward prediction -----------
                SepConvNet_optimizer.zero_grad()
                output_b, stat = SepConvNet(output_f, imgR, 1, tensorHidden=stat)
                loss = SepConvNet_cost(output_b, label_L)
                # Keep the graph alive until the last step of the sequence.
                loss.backward(retain_graph=(i < 4))
                SepConvNet_optimizer.step()
                sumloss_b = sumloss_b + loss.data.item()
                tsumloss_b = tsumloss_b + loss.data.item()
                # Residuals against the degraded frames feed the next step.
                res_f = poor_label - output_f
                res_b = imgL - output_b
                res_c = torch.cat([res_f, res_b], 1)
            if cnt % printinterval == 0:
                print('Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Back loss [%.6f], Interval loss [%.6f]' %
                      (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time,
                       tsumloss_b / printinterval / 5, tsumloss / printinterval / 5))
                tsumloss = 0.0
                tsumloss_b = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt / 5))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        sumloss_b = 0.0
        psnr = 0.0
        psnr_b = 0.0
        for label_list in valLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            with torch.no_grad():
                for i in range(5):
                    imgL = var(bad_list[i]).cuda()
                    imgR = var(bad_list[i + 1]).cuda()
                    poor_label = var(bad_list[i + 2]).cuda()
                    label = var(label_list[i + 2]).cuda()
                    label_L = var(label_list[i]).cuda()
                    # ----------- Forward prediction -----------
                    if i == 0:
                        output_f, stat = SepConvNet(imgL, imgR, 0)
                    else:
                        output_f, stat = SepConvNet(imgL, imgR, 0, res_c, stat)
                    loss = SepConvNet_cost(output_f, label)
                    psnr = psnr + calcPSNR.calcPSNR(output_f.cpu().data.numpy(),
                                                    label.cpu().data.numpy())
                    sumloss = sumloss + loss.data.item()
                    # ----------- Backward prediction -----------
                    output_b, stat = SepConvNet(output_f, imgR, 1, tensorHidden=stat)
                    loss_b = SepConvNet_cost(output_b, label_L)
                    psnr_b = psnr_b + calcPSNR.calcPSNR(output_b.cpu().data.numpy(),
                                                        label_L.cpu().data.numpy())
                    sumloss_b = sumloss_b + loss_b.data.item()
                    res_f = poor_label - output_f
                    res_b = imgL - output_b
                    res_c = torch.cat([res_f, res_b], 1)
            evalcnt = evalcnt + 5
        # ------------- Tensorboard part -------------
        # writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        # writer.add_scalar("Valid PSNR", psnr / len(valset), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f], Average PSNR [%.4f], [!] Backward loss [%.6f] PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / evalcnt, sumloss_b / evalcnt, psnr_b / evalcnt))
        SepConvNet_schedule.step(psnr / evalcnt)
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.',
                'test_share_dual_LSTM_iter' + str(epoch + 1)
                + '-ltype_fSATD_fs'
                + '-lr_' + str(LEARNING_RATE)
                + '-trainloss_' + str(round(trainloss, 4))
                + '-evalloss_' + str(round(sumloss / evalcnt, 4))
                + '-evalpsnr_' + str(round(psnr / evalcnt, 4)) + '.pkl'))
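# Network.load_my_state_dict, used above and in the variants below, is a
# custom method rather than a stock nn.Module API, and its definition is not
# part of this excerpt. A plausible minimal sketch of such a method as it
# might appear inside the Network class, assuming it copies only the
# checkpoint entries whose names and shapes match (so a baseline checkpoint
# can seed a model with extra LSTM children):
def load_my_state_dict(self, pretrained_dict):
    own_state = self.state_dict()
    for name, param in pretrained_dict.items():
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)  # in-place copy of matching weights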
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    # writer = SummaryWriter(log_dir='tb/LSTM_ft1')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch
    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = vimeodataset(train_set, 'filelist.txt', transform_train)
    valset = vimeodataset(valid_set, 'test.txt')
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)
    assert len(valset) % BATCH_SIZE == 0

    # SepConvNet.apply(weights_init)
    # SepConvNet.load_state_dict(torch.load('beta_LSTM_iter8-ltype_fSATD_fs-lr_0.001-trainloss_0.557-evalloss_0.1165-evalpsnr_29.8361.pkl'))
    SepConvNet = Network().cuda()
    SepConvNet.load_my_state_dict(
        torch.load(
            'ft2_baseline_iter86-ltype_fSATD_fs-lr_0.001-trainloss_0.1249-evalloss_0.1155-evalpsnr_29.9327.pkl',
            map_location='cuda:%d' % (gpu)))

    # SepConvNet_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer, factor=0.1, patience=3, verbose=True, min_lr=1e-6)

    # ---------------- Time part -------------------
    start_time = time.time()
    global_step = 0
    # ---------------- Time part -------------------

    # ---------------- Opt part -----------------------
    # opter = Opter(gpu)
    # -------------------------------------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0   # sumloss accumulates over the whole training set
        tsumloss = 0.0  # tsumloss accumulates over one print interval
        printinterval = 100
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for label_list in trainLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            # IPython.embed()
            # exit()
            global_step = global_step + 1
            cnt = cnt + 1
            for i in range(5):
                imgL = var(bad_list[i]).cuda()
                imgR = var(bad_list[i + 1]).cuda()
                label = var(label_list[i + 2]).cuda()
                poor_label = var(bad_list[i + 2]).cuda()
                SepConvNet_optimizer.zero_grad()
                if i == 0:
                    # First step of the sequence: no residual or hidden state yet.
                    output, stat = SepConvNet(imgL, imgR)
                else:
                    output, stat = SepConvNet(imgL, imgR, res, stat)
                res = poor_label - output
                loss = SepConvNet_cost(output, label)
                # Keep the graph alive while the recurrent state is still reused.
                loss.backward(retain_graph=(i < 4))
                SepConvNet_optimizer.step()
                sumloss = sumloss + loss.data.item()
                tsumloss = tsumloss + loss.data.item()
            if cnt % printinterval == 0:
                print('Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]' %
                      (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss.data.item(), tsumloss / printinterval / 5))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt / 5))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for label_list in valLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            with torch.no_grad():
                for i in range(5):
                    imgL = var(bad_list[i]).cuda()
                    imgR = var(bad_list[i + 1]).cuda()
                    label = var(label_list[i + 2]).cuda()
                    poor_label = var(bad_list[i + 2]).cuda()
                    if i == 0:
                        output, stat = SepConvNet(imgL, imgR)
                    else:
                        output, stat = SepConvNet(imgL, imgR, res, stat)
                    psnr = psnr + calcPSNR.calcPSNR(output.cpu().data.numpy(),
                                                    label.cpu().data.numpy())
                    res = poor_label - output
                    loss = SepConvNet_cost(output, label)
                    sumloss = sumloss + loss.data.item()
            evalcnt = evalcnt + 5
        # ------------- Tensorboard part -------------
        # writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        # writer.add_scalar("Valid PSNR", psnr / len(valset), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f], Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / evalcnt))
        SepConvNet_schedule.step(psnr / evalcnt)
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.',
                'tail2_LSTM_iter' + str(epoch + 1)
                + '-ltype_fSATD_fs'
                + '-lr_' + str(LEARNING_RATE)
                + '-trainloss_' + str(round(trainloss, 4))
                + '-evalloss_' + str(round(sumloss / evalcnt, 4))
                + '-evalpsnr_' + str(round(psnr / evalcnt, 4)) + '.pkl'))
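# The recurrent variants above and below share the same rollout pattern: the
# previous output is compared against the degraded frame ("poor_label") and
# the residual is fed back into the next step together with the hidden state.
# An illustrative extraction of that inner step; rollout_step and its
# argument names are hypothetical, not part of the original code:
def rollout_step(net, cost, optimizer, imgL, imgR, label, poor_label,
                 res=None, stat=None, keep_graph=True):
    optimizer.zero_grad()
    if res is None:
        output, stat = net(imgL, imgR)           # first step: no feedback yet
    else:
        output, stat = net(imgL, imgR, res, stat)
    loss = cost(output, label)
    loss.backward(retain_graph=keep_graph)       # keep graph while state is reused
    optimizer.step()
    res = poor_label - output                    # residual fed to the next step
    return loss.data.item(), res, stat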
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    # writer = SummaryWriter(log_dir='tb/LSTM_ft1')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch
    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = vimeodataset(train_set, 'filelist.txt', transform_train)
    valset = vimeodataset(valid_set, 'test.txt')
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE * 2, shuffle=False)
    assert len(valset) % BATCH_SIZE == 0

    SepConvNet = Network().cuda()
    # SepConvNet.apply(weights_init)
    SepConvNet.load_my_state_dict(
        torch.load(
            'SepConv_iter95-ltype_fSATD_fs-lr_0.0001-trainloss_0.1441-evalloss_0.1324-evalpsnr_29.9585.pkl',
            map_location="cuda:%d" % (gpu)))

    # Freeze the children listed below; all other children stay trainable.
    frozen_childs = [14, 15, 16, 17, 20, 21, 22, 23, 26, 27, 28, 29]
    for child_cnt, child in enumerate(SepConvNet.children()):
        # print('----------- Children:%d ----------------' % (child_cnt))
        # print(child)
        if child_cnt in frozen_childs:
            for param in child.parameters():
                param.requires_grad = False

    # SepConvNet_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    # SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_optimizer = optim.Adamax(filter(lambda p: p.requires_grad, SepConvNet.parameters()),
                                        lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer, factor=0.1, patience=3, verbose=True, min_lr=1e-6)
    # IPython.embed()
    # exit()

    # ---------------- Time part -------------------
    start_time = time.time()
    global_step = 0
    # ---------------- Time part -------------------

    # ---------------- Opt part -----------------------
    # opter = Opter(gpu)
    # -------------------------------------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0   # sumloss accumulates over the whole training set
        tsumloss = 0.0  # tsumloss accumulates over one print interval
        printinterval = 300
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for label_list in trainLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            # IPython.embed()
            # exit()
            global_step = global_step + 1
            cnt = cnt + 1
            for i in range(5):
                imgL = var(bad_list[i]).cuda()
                imgR = var(bad_list[i + 1]).cuda()
                label = var(label_list[i + 2]).cuda()
                poor_label = var(bad_list[i + 2]).cuda()
                SepConvNet_optimizer.zero_grad()
                if i == 0:
                    output, output_a, output_b, stat = SepConvNet(imgL, imgR)
                else:
                    output, output_a, output_b, stat = SepConvNet(imgL, imgR, res, stat)
                res = poor_label - output
                # Multiscale SATD loss: the full-resolution output plus the
                # quarter- and half-resolution side outputs. (func.upsample is
                # deprecated in newer PyTorch in favor of func.interpolate.)
                loss = 0.5 * SepConvNet_cost(output, label) + \
                       0.2 * SepConvNet_cost(output_a, func.upsample(label, size=(label.shape[2] // 4, label.shape[3] // 4), mode='bilinear', align_corners=True)) + \
                       0.3 * SepConvNet_cost(output_b, func.upsample(label, size=(label.shape[2] // 2, label.shape[3] // 2), mode='bilinear', align_corners=True))
                # Keep the graph alive while the recurrent state is still reused.
                loss.backward(retain_graph=(i < 4))
                SepConvNet_optimizer.step()
                sumloss = sumloss + loss.data.item()
                tsumloss = tsumloss + loss.data.item()
            if cnt % printinterval == 0:
                print('Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]' %
                      (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss.data.item(), tsumloss / printinterval / 5))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt / 5))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for label_list in valLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            with torch.no_grad():
                for i in range(5):
                    imgL = var(bad_list[i]).cuda()
                    imgR = var(bad_list[i + 1]).cuda()
                    label = var(label_list[i + 2]).cuda()
                    poor_label = var(bad_list[i + 2]).cuda()
                    if i == 0:
                        output, output_a, output_b, stat = SepConvNet(imgL, imgR)
                    else:
                        output, output_a, output_b, stat = SepConvNet(imgL, imgR, res, stat)
                    psnr = psnr + calcPSNR.calcPSNR(output.cpu().data.numpy(),
                                                    label.cpu().data.numpy())
                    res = poor_label - output
                    loss = SepConvNet_cost(output, label)
                    sumloss = sumloss + loss.data.item()
            evalcnt = evalcnt + 5
        # ------------- Tensorboard part -------------
        # writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        # writer.add_scalar("Valid PSNR", psnr / len(valset), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f], Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / evalcnt))
        SepConvNet_schedule.step(psnr / evalcnt)
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.',
                'multiscale_test_LSTM_iter' + str(epoch + 1)
                + '-ltype_fSATD_fs'
                + '-lr_' + str(LEARNING_RATE)
                + '-trainloss_' + str(round(trainloss, 4))
                + '-evalloss_' + str(round(sumloss / evalcnt, 4))
                + '-evalpsnr_' + str(round(psnr / evalcnt, 4)) + '.pkl'))
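# The 0.5/0.2/0.3 weighted pyramid loss above is repeated verbatim at every
# rollout step. An illustrative extraction into a helper; multiscale_satd is
# a hypothetical name, and "func" is assumed to be torch.nn.functional as in
# the rest of the file:
def multiscale_satd(cost, output, output_a, output_b, label):
    h, w = label.shape[2], label.shape[3]
    label_a = func.upsample(label, size=(h // 4, w // 4), mode='bilinear', align_corners=True)
    label_b = func.upsample(label, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
    # The full-resolution term dominates; the half- and quarter-scale side
    # outputs act as deep supervision on the decoder.
    return 0.5 * cost(output, label) + 0.3 * cost(output_b, label_b) + 0.2 * cost(output_a, label_a)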