def loss_fuction_with_edge(x, y):
    # Combined objective: sum-of-squared-error term plus an edge-preserving term.
    MSEloss = sum_squared_error()
    loss1 = MSEloss(x, y)
    edgeloss = EdgeLoss()
    loss2 = edgeloss(x, y)
    return loss1 + loss2
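# Hedged usage sketch (not part of the original training flow): illustrates the expected
# call pattern for loss_fuction_with_edge, assuming 4-D (N, C, H, W) float tensors and that
# sum_squared_error / EdgeLoss are the project-defined loss modules available in this module.
#
#   pred = torch.rand(4, 3, 64, 64)
#   target = torch.rand(4, 3, 64, 64)
#   total_loss = loss_fuction_with_edge(pred, target)  # scalar tensor: SSE term + edge term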
def train_model_residual_lowlight_twostage():
    start_epoch = 1
    device = DEVICE

    # Prepare training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))
    train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)

    # Load test label data
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load test data
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Create the model
    net = HSIDDenseNetTwoStageRDN(K)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    num_epoch = 100
    print('epoch count == ', num_epoch)

    # Create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)

    # Scheduler
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)
    warmup_epochs = 3
    #scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(hsid_optimizer, num_epoch-warmup_epochs+40, eta_min=1e-7)
    #scheduler = GradualWarmupScheduler(hsid_optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    #scheduler.step()

    # Resume training from the latest checkpoint
    if RESUME:
        model_dir = './checkpoints'
        path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
        utils.load_checkpoint(net, path_chk_rest)
        start_epoch = utils.load_start_epoch(path_chk_rest) + 1
        utils.load_optim(hsid_optimizer, path_chk_rest)

        for i in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_lr()[0]
        print('------------------------------------------------------------------------------')
        print("==> Resuming Training with learning rate:", new_lr)
        print('------------------------------------------------------------------------------')

    # Define the loss functions
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    criterion_char = CharbonnierLoss()
    criterion_edge = EdgeLoss()

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        scheduler.step()
        #print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())

        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual, residual_stage2 = net(noisy, cubic)
            #loss = loss_fuction(residual, label-noisy) + loss_fuction(residual_stage2, label-noisy)
            restored_stage1 = noisy + residual
            restored_stage2 = noisy + residual_stage2
            #print(residual_stage2.shape)

            # The Charbonnier and edge losses expect 3-channel input, so the single-band
            # tensors are repeated along the channel axis before computing the losses.
            loss_char1 = criterion_char(restored_stage1.repeat(1, 3, 1, 1), label.repeat(1, 3, 1, 1))
            loss_char2 = criterion_char(restored_stage2.repeat(1, 3, 1, 1), label.repeat(1, 3, 1, 1))
            loss_char = loss_char1 + loss_char2

            loss_edge1 = criterion_edge(restored_stage1.repeat(1, 3, 1, 1), label.repeat(1, 3, 1, 1))
            loss_edge2 = criterion_edge(restored_stage2.repeat(1, 3, 1, 1), label.repeat(1, 3, 1, 1))
            loss_edge = loss_edge1 + loss_edge2

            loss = loss_char + (0.05 * loss_edge)

            loss.backward()        # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: loss: {loss.item()}")
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step++: processing one batch counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/two_stage_hsid_dense_{epoch}.pth")

        # Evaluation
        net.eval()
        for batch_idx, (noisy_test, cubic_test, label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():
                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, dim=0)
                    residual_stage2_squeezed = torch.squeeze(residual_stage2, dim=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band, dim=0)
                    label_test_squeezed = torch.squeeze(label_test, dim=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, dim=0)
                    tb_writer.add_image(f"images/{epoch}_restored", denoised_band_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual", residual_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual_stage2", residual_stage2_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label", label_test_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy", noisy_test_squeezed, 1, dataformats='CHW')

        # Compute PSNR, SSIM and SAM over the whole reconstructed cube
        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(psnr, ssim, sam))
        # Logging per-epoch metrics makes it easy to see which epoch performs best.
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'average SAM': sam
        }, epoch)

        # Save the best model
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/two_stage_hsid_dense_rdn_best.pth")

        print("[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]" % (
            epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print("------------------------------------------------------------------")
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(
            epoch, time.time() - epoch_start_time, gen_epoch_loss, scheduler.get_lr()[0]))
        print("------------------------------------------------------------------")

        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))

    tb_writer.close()
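# Hedged invocation sketch: train_model_residual_lowlight_twostage() relies on module-level
# configuration (DEVICE, BATCH_SIZE, K, INIT_LEARNING_RATE, RESUME, display_step) and on the
# dataset/utility helpers defined elsewhere in this repo; with those in place it is called
# with no arguments, e.g.
#
#   train_model_residual_lowlight_twostage()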
def main():
    args = parseargs()
    torch.manual_seed(42)
    model = SINet(train_encoder_only=True)
    configs = {
        0: {
            'batch_size': 36,
            'edge_size': 5,
            'mask_scale': 8,
            'image_size': (224, 224),
        },
        300: {
            'batch_size': 32,
            'edge_size': 15,
            'image_size': (224, 224),
        },
    }
    data_loader = {
        epoch: create_data_loader(cfg)
        for epoch, cfg in configs.items()
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=2e-4)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [150, 250, 450, 550], gamma=0.5)
    loss_fn = {
        epoch: EdgeLoss(cfg['edge_size'])
        for epoch, cfg in configs.items()
    }
    if args.use_cuda and torch.cuda.is_available():
        torch.cuda.init()
    trainer = Trainer(data_loader,
                      model,
                      optimizer,
                      loss_fn,
                      args.debug,
                      args.use_cuda,
                      best_model_filename='best_encoder_only_model.pt')
    if not osp.exists(trainer.checkpoint_dir):
        os.makedirs(trainer.checkpoint_dir)

    initial_epoch = 0
    if args.skip_encoder:
        assert osp.exists(trainer.best_model_checkpoint_filepath), 'Checkpoint file does not exist'
        initial_epoch = 300

    for epoch in range(initial_epoch, 600):
        print(f'Epoch\t{epoch}')
        lr = 0
        for param_group in trainer.optimizer.param_groups:
            lr = param_group['lr']
        print(f'Learning rate: {str(lr)}')
        if epoch == 300:
            print('Enabling Information Blocking and loading best model')
            trainer.model.train_encoder_only = False
            trainer.load_previous_best_model()
            trainer.best_iou = 0
            trainer.best_model_filename = 'best_model.pt'
            for param_group in trainer.optimizer.param_groups:
                param_group['lr'] = 5e-4
        trainer.train_one_epoch(epoch)
        trainer.validate(epoch)
        lr_scheduler.step()

    print(f'Final best model @{trainer.best_iou:.04f}')
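# Standard entry-point guard (an assumption: add only if this module is meant to be run directly).
if __name__ == '__main__':
    main()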