Example #1
0
def loss_fuction_with_edge(x, y):
    """Return a squared-error term plus an edge term between ``x`` and ``y``.

    Both criteria are constructed anew on every call: ``sum_squared_error``
    first, then ``EdgeLoss``. Each is applied to ``(x, y)`` and the two
    results are summed.
    """
    squared_error_term = sum_squared_error()(x, y)
    edge_term = EdgeLoss()(x, y)
    return squared_error_term + edge_term
def _log_validation_images(writer, epoch, images):
    """Write each tensor in ``images`` to TensorBoard as a CHW image.

    The leading (batch) dimension is squeezed away first. Each image is
    tagged ``images/{epoch}_{name}`` and logged at global step 1, in the
    dict's insertion order.
    """
    for name, tensor in images.items():
        writer.add_image(f"images/{epoch}_{name}",
                         torch.squeeze(tensor, 0),
                         1,
                         dataformats='CHW')


def train_model_residual_lowlight_twostage():
    """Train the two-stage residual low-light HSI network and validate it each epoch.

    Reads module-level configuration defined elsewhere in the file:
    ``DEVICE``, ``BATCH_SIZE``, ``K``, ``INIT_LEARNING_RATE``, ``RESUME``
    and ``display_step``.

    The network returns two residuals. Both restored stages
    (``noisy + residual``) are supervised with Charbonnier loss plus
    ``0.05 *`` edge loss. Validation uses only the stage-2 output, assembles
    the full cube band by band, and scores it with PSNR/SSIM/SAM against the
    ``label`` array of the reference ``.mat`` file.

    Side effects:
      * ``checkpoints/two_stage_hsid_dense_{epoch}.pth`` after every epoch;
      * ``checkpoints/two_stage_hsid_dense_rdn_best.pth`` whenever the
        validation PSNR improves;
      * ``./checkpoints/model_latest.pth`` at the end of every epoch;
      * TensorBoard scalars/images under ``logs`` via the module-level
        ``tb_writer`` (declared ``global`` and closed on return).
    """
    start_epoch = 1
    device = DEVICE

    checkpoint_dir = './checkpoints'
    # torch.save does not create missing directories; make sure the target
    # exists before the first checkpoint is written.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Ground-truth cube for the validation scene.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Validation data. With batch_size 1 and no shuffling, the validation
    # loop below treats batch index i as band i of the output cube.
    test_batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=test_batch_size,
                                 shuffle=False)

    # Peek at one sample to size the output cube from its last two dims.
    _, _, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Model.
    net = HSIDDenseNetTwoStageRDN(K)
    init_params(net)
    net = net.to(device)

    num_epoch = 100
    print('epoch count == ', num_epoch)

    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)

    # The scheduler is stepped at the START of each epoch (see the loop), so
    # the learning rate is multiplied by 0.1 from epochs 40, 60 and 80 on.
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # Resume training from the latest checkpoint.
    if RESUME:
        model_dir = './checkpoints'
        path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
        utils.load_checkpoint(net, path_chk_rest)
        start_epoch = utils.load_start_epoch(path_chk_rest) + 1
        utils.load_optim(hsid_optimizer, path_chk_rest)

        # Fast-forward the scheduler so its epoch counter matches start_epoch.
        # NOTE(review): MultiStepLR decays the optimizer's *current* lr at each
        # milestone. If utils.load_optim restores an already-decayed lr,
        # replaying milestones here decays it a second time. Confirm against
        # utils.load_optim.
        for _ in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_last_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    cur_step = 0

    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    criterion_char = CharbonnierLoss()
    criterion_edge = EdgeLoss()

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # Stepped before training, which PyTorch warns about. The ordering is
        # kept so the milestone epochs and the resume fast-forward above stay
        # consistent.
        scheduler.step()
        # get_last_lr() reports the lr actually held by the optimizer.
        # get_lr() called outside step() re-applies gamma at milestone epochs
        # and would print a value 10x too small there.
        current_lrs = scheduler.get_last_lr()
        print('epoch = ', epoch, 'lr={:.6f}'.format(current_lrs[0]))
        print(current_lrs)
        gen_epoch_loss = 0

        net.train()
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()

            # The network predicts residuals; both stages are supervised on
            # the restored image noisy + residual.
            residual, residual_stage2 = net(noisy, cubic)
            restored_stage1 = noisy + residual
            restored_stage2 = noisy + residual_stage2

            # Tile the channel dimension 3x before applying the criteria
            # (presumably because they expect 3-channel input; TODO confirm).
            # Each tiled tensor is built once and shared by both criteria.
            restored1_tiled = restored_stage1.repeat(1, 3, 1, 1)
            restored2_tiled = restored_stage2.repeat(1, 3, 1, 1)
            label_tiled = label.repeat(1, 3, 1, 1)

            loss_char = (criterion_char(restored1_tiled, label_tiled) +
                         criterion_char(restored2_tiled, label_tiled))
            loss_edge = (criterion_edge(restored1_tiled, label_tiled) +
                         criterion_edge(restored2_tiled, label_tiled))
            loss = loss_char + (0.05 * loss_edge)

            loss.backward()
            hsid_optimizer.step()

            loss_value = loss.item()
            gen_epoch_loss += loss_value

            # The "MSE loss" labels below are kept for log continuity,
            # although the loss is Charbonnier + edge.
            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss_value}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss_value, cur_step)

            # One step per processed batch.
            cur_step += 1

        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/two_stage_hsid_dense_{epoch}.pth")

        # Validation: denoise band by band and assemble the full cube.
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor).to(device)
            label_test = label_test.type(torch.FloatTensor).to(device)
            cubic_test = cubic_test.type(torch.FloatTensor).to(device)

            with torch.no_grad():
                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = np.squeeze(
                    denoised_band.cpu().numpy().astype(np.float32))
                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                # Log one fixed band (index 49) to TensorBoard each epoch.
                if batch_idx == 49:
                    _log_validation_images(tb_writer, epoch, {
                        'restored': denoised_band,
                        'residual': residual,
                        'residual_stage2': residual_stage2,
                        'label': label_test,
                        'noisy': noisy_test,
                    })

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        # Per-epoch metrics show which epoch performed best. The
        # 'avarage SAM' tag spelling is kept for continuity with existing logs.
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)

        # Keep the checkpoint with the best validation PSNR so far.
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/two_stage_hsid_dense_rdn_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_last_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join(checkpoint_dir, "model_latest.pth"))

    tb_writer.close()
Example #3
0
def main():
    """Run the two-phase SINet training schedule for 600 epochs.

    Phase 1 (epochs 0-299) starts from ``SINet(train_encoder_only=True)``.
    At epoch 300 the trainer does four things: it disables
    ``train_encoder_only``, reloads its previous best model, resets
    ``best_iou``, and switches ``best_model_filename`` to ``best_model.pt``.
    It also resets every param group's learning rate to 5e-4.

    ``configs`` (and the derived ``data_loader`` / ``loss_fn`` dicts) are
    keyed by epoch, presumably the epoch from which ``Trainer`` applies each
    setting. ``Trainer`` is not visible here, so confirm this.

    ``lr_scheduler`` is stepped once per epoch, after validation.

    When ``args.skip_encoder`` is set, training starts directly at epoch 300.
    This requires the best-model checkpoint to exist.

    Raises:
        FileNotFoundError: ``args.skip_encoder`` is set but
            ``trainer.best_model_checkpoint_filepath`` does not exist.
    """
    phase2_start_epoch = 300
    total_epochs = 600

    args = parseargs()

    torch.manual_seed(42)

    model = SINet(train_encoder_only=True)

    configs = {
        0: {
            'batch_size': 36,
            'edge_size': 5,
            'mask_scale': 8,
            'image_size': (224, 224),
        },
        phase2_start_epoch: {
            'batch_size': 32,
            'edge_size': 15,
            'image_size': (224, 224),
        },
    }

    data_loader = {
        epoch: create_data_loader(cfg)
        for epoch, cfg in configs.items()
    }

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=5e-4,
                                 weight_decay=2e-4)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        [150, 250, 450, 550],
                                                        gamma=0.5)

    loss_fn = {
        epoch: EdgeLoss(cfg['edge_size'])
        for epoch, cfg in configs.items()
    }

    if args.use_cuda and torch.cuda.is_available():
        torch.cuda.init()

    trainer = Trainer(data_loader,
                      model,
                      optimizer,
                      loss_fn,
                      args.debug,
                      args.use_cuda,
                      best_model_filename='best_encoder_only_model.pt')

    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(trainer.checkpoint_dir, exist_ok=True)

    initial_epoch = 0
    if args.skip_encoder:
        # An explicit check rather than `assert`, which `python -O` strips out.
        if not osp.exists(trainer.best_model_checkpoint_filepath):
            raise FileNotFoundError('Checkpoint file does not exist')
        initial_epoch = phase2_start_epoch

    for epoch in range(initial_epoch, total_epochs):
        print(f'Epoch\t{epoch}')

        # Report the last param group's lr, or 0 if there are no groups.
        lr = 0
        for param_group in trainer.optimizer.param_groups:
            lr = param_group['lr']
        print(f'Learning rate: {str(lr)}')

        if epoch == phase2_start_epoch:
            print('Enabling Information Blocking and loading best model')
            trainer.model.train_encoder_only = False

            trainer.load_previous_best_model()

            trainer.best_iou = 0
            trainer.best_model_filename = 'best_model.pt'

            for param_group in trainer.optimizer.param_groups:
                param_group['lr'] = 5e-4

        trainer.train_one_epoch(epoch)

        trainer.validate(epoch)

        lr_scheduler.step()
    print(f'Final best model @{trainer.best_iou:.04f}')