Esempio n. 1
0
 def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
     """Attach the GAN optimizers, the loss function and image-quality metrics.

     Args:
         d_optimizer: optimizer used for the discriminator.
         g_optimizer: optimizer used for the generator.
         loss_fn: loss callable stored on the model as ``self.loss_fn``.
         **kwargs: forwarded unchanged to the parent ``compile``.
     """
     super(SRGan, self).compile(**kwargs)
     self.d_optimizer, self.g_optimizer = d_optimizer, g_optimizer
     self.loss_fn = loss_fn
     # max_val=1.0: presumably images are scaled to [0, 1] — confirm with the data pipeline.
     self._psnr = PSNR(max_val=1.0)
     self._ssim = SSIM(max_val=1.0)
Esempio n. 2
0
    def __init__(self,
                 model,
                 dataloaders,
                 inc_overhead=False,
                 if_codec=None,
                 standard_epe=False):
        """Prepare a video-codec evaluator.

        Args:
            model: model to evaluate. It is moved to the selected device and
                switched to inference mode with ``train(False)``. Its ``name``
                decides how many reference frames are removed from the GOP count.
            dataloaders: object exposing ``f_s``, ``n_gop`` and
                ``get_data_loaders()``.
            inc_overhead: whether overhead bits are included (stored as-is).
            if_codec: optional I-frame image codec. When given, an
                ``ImageCodec`` is built for it.
            standard_epe: forwarded to ``EPE(standard=...)``. Per the original
                comment, this selects Farneback vs LiteFlowNet.
        """
        # First CUDA device when available, CPU otherwise.
        has_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if has_gpu else "cpu")

        # Move the model and put it in inference mode.
        self.model = model.to(self.device)
        self.model.train(False)

        loaders = dataloaders
        self.f_s = loaders.f_s

        # Drop the reference frame(s) from the GOP count:
        # one for P-frame models, two for B-frame models.
        gop_len = loaders.n_gop
        model_name = self.model.name
        if "PFrame" in model_name:
            gop_len = gop_len - 1
        elif "BFrame" in model_name:
            gop_len = gop_len - 2
        self.n_gop = gop_len

        self.vid_dls = loaders.get_data_loaders()

        # Optional codec for I-frames.
        self.if_codec = if_codec
        if if_codec is not None:
            self.img_codec = ImageCodec(codec=if_codec)

        self.inc_overhead = inc_overhead

        # Quality metrics on data in [0, 1] (data_range=1).
        self.ssim = SSIM(data_range=1, multichannel=True, gaussian_weights=True)
        self.psnr = PSNR(data_range=1)
        self.epe = EPE(standard=standard_epe)
        self.standard_epe = standard_epe
        self.vmaf = VMAF()
Esempio n. 3
0
    def __init__(self,
                 video_dir,
                 vid_ext="mp4",
                 frame_size=(224, 320),
                 num_frames=24,
                 method="DS",
                 mb_size=16,
                 search_dist=7):
        """Set up a video dataset and metrics for evaluating motion estimation.

        Each video is cropped to ``num_frames`` frames. Every frame is converted
        to an RGB PIL image, resized to ``frame_size``, and the frames are
        stacked back into one numpy array.

        Args:
            video_dir: root directory passed to ``VideoDataset``.
            vid_ext: video file extension to load.
            frame_size: target size for ``Resize``.
            num_frames: number of frames kept by ``CropVideoSequence``.
            method: motion-estimation method name (stored as-is).
            mb_size: macroblock size (stored as-is).
            search_dist: search distance (stored as-is).
        """
        resize_frame = tf.Compose([NpFrame2PIL("RGB"), tf.Resize(frame_size)])

        def stack_frames(frames):
            # Apply the per-frame transform and stack the results.
            return np.stack(list(map(resize_frame, frames)))

        clip_transform = tf.Compose([
            CropVideoSequence(num_frames=num_frames),
            tf.Lambda(stack_frames),
        ])

        self.video_dataset = VideoDataset(root_dir=video_dir,
                                          vid_ext=vid_ext,
                                          transform=clip_transform)
        self.num_videos = len(self.video_dataset)

        # Motion-estimation parameters.
        self.frame_size = frame_size
        self.num_frames = num_frames
        self.method = method
        self.mb_size = mb_size
        self.search_dist = search_dist

        # Metrics on 0..255 data (data_range=255). EPE uses standard=False,
        # which per the original comment means LiteFlowNet.
        self.ssim = SSIM(data_range=255, multichannel=True, gaussian_weights=True)
        self.epe = EPE(standard=False)
        self.psnr = PSNR(data_range=255)
        self.vmaf = VMAF()
Esempio n. 4
0
    def __init__(self, config, sess=None):
        """Store the config, copy selected fields from it, and create metrics.

        Args:
            config: object exposing ``channel_dim``, ``batch_size``,
                ``patch_size`` and ``input_channels``. Each is copied onto
                the instance under the same name.
            sess: optional session object, stored as ``self.sess``.
        """
        self.config = config
        for field in ("channel_dim", "batch_size", "patch_size", "input_channels"):
            setattr(self, field, getattr(self.config, field))

        # Image-quality metrics with max_val=1.0.
        self.ssim = SSIM(max_val=1.0)
        self.psnr = PSNR(max_val=1.0)

        self.sess = sess
Esempio n. 5
0
def test(criterion, epoch):
    """Evaluate the global ``model`` on ``testing_data_loader`` and log averages.

    For every batch:
      * computes ``criterion(prediction, target)`` (treated as an MSE),
      * derives a PSNR from it with a peak value of 1,
      * computes ``SSIM(prediction, target)``.
    The values are averaged over the number of batches, printed, and appended
    to ``test_loss_log.txt``. On every 10th epoch each prediction is also saved
    through ``save_images``.

    Args:
        criterion: loss callable returning a scalar tensor. The PSNR formula
            assumes it is an MSE on data in [0, 1] — TODO confirm.
        epoch: current epoch, used in the log line and image file names.
    """
    avg_mse = 0
    avg_psnr = 0
    avg_ssim = 0
    print("===> Testing")
    with torch.no_grad():
        for iteration, batch in enumerate(testing_data_loader, 1):
            inputs, target = batch[0], batch[1]
            if opt.cuda:
                inputs = inputs.cuda()
                target = target.cuda()
            prediction = model(inputs)
            mse = criterion(prediction, target)
            mse_value = mse.item()
            # PSNR is unbounded for a perfect reconstruction; avoid 1 / 0.
            psnr = 10 * log10(1 / mse_value) if mse_value > 0 else float("inf")
            ssim = SSIM(prediction, target)
            avg_psnr += psnr
            avg_ssim += ssim
            avg_mse += mse_value

            if epoch % 10 == 0:
                save_images(epoch, prediction,
                            'epoch_{}_img_{}.jpg'.format(epoch, iteration), 1)

        # Guard against an empty loader (would otherwise divide by zero).
        num_batches = max(len(testing_data_loader), 1)
        test_loss_record = "===>Testing Epoch[{}] Avg. PSNR: {:.4f} dB, SSIM:{:.10f}  MSE:{:.10f}".format(
            epoch,
            avg_psnr / num_batches,
            avg_ssim / num_batches,
            avg_mse / num_batches)
        print(test_loss_record)
        with open("test_loss_log.txt", "a", encoding="utf-8") as test_log_file:
            test_log_file.write(test_loss_record + '\n')
Esempio n. 6
0
    def _calc_ssim(r_vid, c_vid):
        """Return the Gaussian-weighted SSIM between a reference and a compressed video.

        Both inputs are decoded with ``sk.vread``. The metric uses
        ``data_range=255`` (presumably 8-bit frames) and multichannel mode.
        """
        reference, compressed = sk.vread(r_vid), sk.vread(c_vid)
        metric = SSIM(data_range=255, multichannel=True, gaussian_weights=True)
        return metric.calc_video(reference, compressed)
Esempio n. 7
0
    def train(self):
        """Fit ``self.model`` on the training split with per-epoch validation.

        Setup:
          * Adam optimizer (``lr=args.lr``, ``beta_1=0.5``), MSE loss, and
            PSNR/SSIM metrics.
          * A CSV log and a checkpoint per epoch, both written to the
            directory returned by ``self._create_out_dir()``.
          * Early stopping on ``val_PSNR`` (patience 10, maximise), unless
            ``args.no_early_stop`` is set.
        """
        args = self.args
        dataset = self.dataset(args.data, args.gt, args.val_size, crop_div=8)
        out_dir = self._create_out_dir()

        def make_split(kind, batch, crop_mode):
            # Both splits repeat indefinitely and share the same crop size.
            return dataset.get_dataset(data_type=kind,
                                       batch_size=batch,
                                       repeat_count=None,
                                       crop_params=(crop_mode, args.crop))

        train_dataset = make_split('train', args.batch, 'random')
        valid_dataset = make_split('valid', 1, 'center')

        adam = tf.keras.optimizers.Adam(args.lr, beta_1=.5)
        self.model.compile(loss=MSE(), optimizer=adam, metrics=[PSNR(), SSIM()])

        early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_PSNR',
                                                      patience=10,
                                                      mode='max')
        log_path = os.path.join(out_dir, f'train_{args.model}.log')
        csv_log = tf.keras.callbacks.CSVLogger(log_path)

        weights_template = (os.path.join(out_dir, args.model)
                            + '_{epoch:03d}-{val_PSNR:.2f}.h5')
        saver = tf.keras.callbacks.ModelCheckpoint(filepath=weights_template,
                                                   monitor='val_PSNR')

        callbacks = [csv_log, saver] + ([] if args.no_early_stop else [early_stop])

        steps = dataset.numel * args.repeat // args.batch
        self.model.fit(train_dataset,
                       epochs=args.epochs,
                       callbacks=callbacks,
                       validation_data=valid_dataset,
                       verbose=1,
                       validation_steps=1,
                       steps_per_epoch=steps)
def main():
    """Train an SRResNet generator with a pixel-wise MSE loss.

    Steps:
      * Parse command-line arguments.
      * Build the training dataset and a ``Generator`` for
        ``args.downscale_factor``.
      * Optionally restore weights from ``args.weights_path``.
      * Fit for ``args.iterations // args.step_per_epoch`` epochs.

    TensorBoard logs go to ``<args.model_dir>/srresnet/logs``. Weights are
    saved every ``args.step_per_epoch // 2`` batches.
    """
    args = _parse_args()
    model_dir = os.path.join(args.model_dir, 'srresnet')

    train_dataset = get_dataset(args.image_dir, args.batch_size,
                                args.crop_size, args.downscale_factor)

    loss = tf.losses.MeanSquaredError()
    adam = tf.optimizers.Adam(1e-4)
    generator = Generator(upscale_factor=args.downscale_factor)
    # Batch of 1, unspecified spatial size, 3 channels (presumably NHWC).
    generator.build((1, None, None, 3))
    quality_metrics = [PSNR(max_val=1.0), SSIM(max_val=1.0)]
    generator.compile(optimizer=adam, loss=loss, metrics=quality_metrics)

    if args.weights_path:
        generator.load_weights(args.weights_path)
        print(f'Loaded weights from {args.weights_path}')

    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(model_dir, 'logs'), update_freq=500)
    weights_template = (os.path.join(model_dir, 'srresnet_weights_epoch_')
                        + '{epoch}')
    checkpointer = tf.keras.callbacks.ModelCheckpoint(
        weights_template,
        save_weights_only=True,
        save_best_only=False,
        monitor='loss',
        verbose=1,
        save_freq=args.step_per_epoch // 2)

    generator.fit(train_dataset,
                  epochs=args.iterations // args.step_per_epoch,
                  steps_per_epoch=args.step_per_epoch,
                  callbacks=[tensorboard, checkpointer])
Esempio n. 9
0
    def evaluate(self):
        """Run the model over the evaluation data and show running metric values.

        Data selection: the whole dataset is used when ``args.val_size`` is 0
        or 1; otherwise only the validation split.

        Metrics: PSNR and SSIM always. The three CPBD metrics are added when
        ``args.cpbd_mae`` is set, and LPIPS when ``args.lpips`` is set.

        After every sample, the progress bar shows each metric's current result.
        """
        args = self.args
        dataset = self.dataset(args.data, args.gt, args.val_size)

        if args.val_size in (1., 0.):
            data_type, dataset_length = 'all', dataset.numel
        else:
            data_type, dataset_length = 'valid', int(dataset.numel * args.val_size)

        eval_dataset = dataset.get_dataset(data_type=data_type,
                                           batch_size=1,
                                           repeat_count=1,
                                           crop_params=('pad', args.crop))

        prog_bar = ProgressBar(dataset_length, title='Evaluation')
        # TODO: set metrics from config.
        metrics = [PSNR(), SSIM()]
        # NOTE(review): all three CPBD metrics hang off the same ``cpbd_mae``
        # flag; separate flags may have been intended — confirm.
        if args.cpbd_mae:
            metrics += [CPBD_MAE(), CPBD_PRED(), CPBD_TRUE()]
        if args.lpips:
            metrics.append(LPIPS())

        for inputs, result in eval_dataset:
            predict = self.model.predict(inputs)
            pieces = []
            for metric in metrics:
                metric.update_state(result.numpy(), predict)
                pieces.append(f'{metric.name}={metric.result():.4f} - ')
            prog_bar.update(''.join(pieces))
def predict_lowlight_hsid_origin():
    """Denoise the level-25 test set band by band with a pretrained HSIRDNECA model.

    For every test image and every spectral band ``i``:
      * the model receives the noisy band plus its ``K`` adjacent bands
        (from ``get_adjacent_spectral_bands``),
      * it predicts a residual, which is added back to the noisy band.

    For each image, the script prints the mean PSNR over that image's bands,
    plus SSIM and SAM over the whole cube. The denoised cube of the last image
    is saved to ``./data/denoise/testresult/result.mat`` under the key
    ``'denoised'``.

    Relies on the module-level globals ``K`` and ``DEVICE``.
    """
    # Load the model.
    hsid = HSIRDNECA_Denoise(K)
    hsid = nn.DataParallel(hsid).to(DEVICE)

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'

    # Map the checkpoint onto the device the model lives on. A hard-coded
    # 'cuda:0' would fail on CPU-only machines.
    hsid.load_state_dict(
        torch.load(save_model_path +
                   '/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth',
                   map_location=DEVICE)['gen'])

    # Load the test data.
    test_data_dir = './data/denoise/test/level25'
    test_set = HsiTrainDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Output directory for the results.
    test_result_output_path = './data/denoise/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising:
    #   - allocate a numpy array for the denoised cube;
    #   - for each band, gather its K adjacent bands with
    #     get_adjacent_spectral_bands and run hsid on them;
    #   - add the predicted residual to the noisy band to get the output band;
    #   - finally save the denoised cube as a .mat file.
    hsid.eval()
    denoised_hsi = None
    for batch_idx, (noisy, label) in enumerate(test_dataloader):
        noisy = noisy.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)

        # Shape is unpacked as (batch, width, height, bands); the band axis is last.
        batch_size, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)
        label = label.to(DEVICE)

        # Reset per image so the reported PSNR is the mean over this image's
        # bands only, matching the per-image SSIM/SAM below.
        psnr_list = []
        with torch.no_grad():
            for i in range(band_num):  # process each band
                current_noisy_band = noisy[:, :, :, i]
                current_noisy_band = current_noisy_band[:, None]

                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                # Move the last axis to position 1, then insert a singleton axis.
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(1)
                residual = hsid(current_noisy_band,
                                adj_spectral_bands_unsqueezed)
                denoised_band = residual + current_noisy_band
                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, i] = denoised_band_numpy

                test_label_current_band = label[:, :, :, i]

                label_band_numpy = test_label_current_band.cpu().numpy(
                ).astype(np.float32)
                label_band_numpy = np.squeeze(label_band_numpy)

                psnr = PSNR(denoised_band_numpy, label_band_numpy)
                psnr_list.append(psnr)

        mpsnr = np.mean(psnr_list)

        # SSIM and SAM are computed on band-first (bands, width, height) cubes.
        denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
        test_label_hsi_trans = np.squeeze(label.cpu().numpy().astype(
            np.float32)).transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

        print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(mpsnr, mssim, sam))

    if denoised_hsi is None:
        # Empty test set: nothing was denoised, so there is nothing to save.
        print('No test samples found in', test_data_dir)
        return

    # savemat takes a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})
Esempio n. 11
0
def train_model_residual_lowlight_rdn():
    """Train the residual HSIRDN denoiser on low-light hyperspectral patches.

    Each epoch:
      * trains ``net`` to predict the residual ``label - noisy`` with
        ``recon_criterion``, then saves a per-epoch checkpoint;
      * denoises the test set in ``./data/test_lowlight/cubic/`` one band per
        sample;
      * reports mean PSNR, SSIM and SAM against the ``'label'`` cube of the
        reference ``.mat`` file.

    The checkpoint with the best mean PSNR is kept as ``..._best.pth``. The
    most recent state is kept as ``model_latest.pth``, which is used for
    resuming when ``RESUME`` is set.

    Relies on module-level globals such as ``DEVICE``, ``K``, ``BATCH_SIZE``,
    ``INIT_LEARNING_RATE``, ``RESUME``, ``display_step`` and
    ``recon_criterion``. Rebinds the global ``tb_writer``.
    """
    device = DEVICE
    # Prepare the training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the ground-truth cube of the test image.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load the test data.
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'

    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    # One test sample per spectral band (batch_size is 1).
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    save_model_path = './checkpoints/hsirnd_cosine'
    # makedirs also creates './checkpoints' when it is missing; os.mkdir would fail.
    os.makedirs(save_model_path, exist_ok=True)

    # Build the model.
    net = HSIRDN(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # Optimizer and cosine-annealed learning-rate schedule.
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = CosineAnnealingLR(hsid_optimizer, T_max=600)

    start_epoch = 1
    is_resume = RESUME
    # Resume training from the latest checkpoint.
    if is_resume:
        path_chk_rest = dir_utils.get_last_path(save_model_path,
                                                'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)

        # Fast-forward the scheduler to the resumed epoch.
        for _ in range(1, start_epoch):
            scheduler.step()
        # get_last_lr() is the LR actually in effect; get_lr() is only valid inside step().
        new_lr = scheduler.get_last_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    num_epoch = 600

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # NOTE(review): stepping the scheduler before the optimizer skips the
        # initial LR; kept as-is so existing/resumed schedules are unchanged.
        scheduler.step()
        print('epoch = ', epoch,
              'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print(scheduler.get_last_lr())

        gen_epoch_loss = 0

        net.train()
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()

            # The network predicts the residual between the clean and noisy input.
            residual = net(noisy, cubic)
            loss = recon_criterion(residual, label - noisy)
            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # One step per processed batch.
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            },
            f"{save_model_path}/hsid_rdn_4rdb_conise_l1_loss_600epoch_patchsize32_{epoch}.pth"
        )

        # Evaluate on the test set, one band per sample.
        net.eval()
        psnr_list = []
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

            test_label_current_band = test_label_hsi[:, :, batch_idx]
            psnr = PSNR(denoised_band_numpy, test_label_current_band)
            psnr_list.append(psnr)

        mpsnr = np.mean(psnr_list)

        # SSIM and SAM are computed on band-first cubes.
        denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
        test_label_hsi_trans = test_label_hsi.transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

        print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.4f}".
              format(mpsnr, mssim, sam))
        # Logged per epoch so the best-performing epoch can be identified.
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': mpsnr,
            'average SSIM': mssim,
            'avarage SAM': sam
        }, epoch)

        # Keep the checkpoint with the best mean PSNR.
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_4rdb_conise_l1_loss_600epoch_patchsize32_best.pth"
            )

        # Report this epoch's mean PSNR (not just the last band's value).
        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_last_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        # Save the latest state for resuming.
        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join(save_model_path, "model_latest.pth"))
    tb_writer.close()
    testloader = DataLoader(testset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=num_workers)
    # load full net
    full_net = SR_ITM_full_net(channels=feature_channels, scale=scale)
    load_name = osp.join(pretrained_model_dir,
                         'base_net_{:03d}.pth'.format(pretrained_epoch))
    print('loading checkpoint: {}'.format(load_name))
    checkpoint = torch.load(load_name, map_location=torch.device('cpu'))
    full_net.load_state_dict(checkpoint['model'], strict=False)

    criterions = {
        'mse': nn.MSELoss(reduction='mean'),
        'psnr': PSNR(peakval=1.0),
        'ssim': SSIM(data_range=1.0),
        'ms_ssim': MS_SSIM(data_range=1.0)
    }

    optimizer = optim.Adam(full_net.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=milestones,
                                               gamma=gamma)

    if use_cuda is True:
        full_net.cuda()

    if load is True or resume is True:
        load_name = osp.join(checkpoint_dir,
def train_model_residual_lowlight_twostage_gan_best():

    #设置超参数
    batchsize = 128
    init_lr = 0.001
    K_adjacent_band = 36
    display_step = 20
    display_band = 20
    is_resume = False
    lambda_recon = 10

    start_epoch = 1

    device = DEVICE
    #准备数据
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=batchsize,
                              shuffle=True)

    #加载测试label数据
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    #加载测试数据
    test_batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=test_batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    #创建模型
    net = HSIDDenseNetTwoStage(K_adjacent_band)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    #创建discriminator
    disc = DiscriminatorABC(2, 4)
    init_params(disc)
    disc = disc.to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=init_lr)

    num_epoch = 100
    print('epoch count == ', num_epoch)

    #创建优化器
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0,999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=init_lr)

    #Scheduler
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)
    warmup_epochs = 3
    #scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(hsid_optimizer, num_epoch-warmup_epochs+40, eta_min=1e-7)
    #scheduler = GradualWarmupScheduler(hsid_optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    #scheduler.step()

    #唤醒训练
    if is_resume:
        model_dir = './checkpoints'
        path_chk_rest = dir_utils.get_last_path(model_dir, 'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)
        model_utils.load_disc_checkpoint(disc, path_chk_rest)
        model_utils.load_disc_optim(disc_opt, path_chk_rest)

        for i in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    #定义loss 函数
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        scheduler.step()
        #print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            ### Update discriminator ###
            disc_opt.zero_grad(
            )  # Zero out the gradient before backpropagation
            with torch.no_grad():
                fake, fake_stage2 = net(noisy, cubic)
            #print('noisy shape =', noisy.shape, fake_stage2.shape)
            #fake.detach()
            disc_fake_hat = disc(fake_stage2.detach() + noisy,
                                 noisy)  # Detach generator
            disc_fake_loss = adv_criterion(disc_fake_hat,
                                           torch.zeros_like(disc_fake_hat))
            disc_real_hat = disc(label, noisy)
            disc_real_loss = adv_criterion(disc_real_hat,
                                           torch.ones_like(disc_real_hat))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward(retain_graph=True)  # Update gradients
            disc_opt.step()  # Update optimizer

            ### Update generator ###
            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual, residual_stage2 = net(noisy, cubic)
            disc_fake_hat = disc(residual_stage2 + noisy, noisy)
            gen_adv_loss = adv_criterion(disc_fake_hat,
                                         torch.ones_like(disc_fake_hat))

            alpha = 0.2
            beta = 0.2
            rec_loss = beta * (alpha*loss_fuction(residual, label-noisy) + (1-alpha) * recon_criterion(residual, label-noisy)) \
             + (1-beta) * (alpha*loss_fuction(residual_stage2, label-noisy) + (1-alpha) * recon_criterion(residual_stage2, label-noisy))

            loss = gen_adv_loss + lambda_recon * rec_loss

            loss.backward()  # calcu gradient
            hsid_optimizer.step()  # update parameter

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                    print(
                        f"rec_loss =  {rec_loss.item()}, gen_adv_loss = {gen_adv_loss.item()}"
                    )

                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            #step ++,每一次循环,每一个batch的处理,叫做一个step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict()
            }, f"checkpoints/two_stage_hsid_dense_gan_{epoch}.pth")

        #测试代码
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, axis=0)
                    residual_stage2_squeezed = torch.squeeze(residual_stage2,
                                                             axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual_stage2",
                                        residual_stage2_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        #计算pnsr和ssim
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)  #通过这个我就可以看到,那个epoch的性能是最好的

        #保存best模型
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                    'disc': disc.state_dict(),
                    'disc_opt': disc_opt.state_dict()
                }, f"checkpoints/two_stage_hsid_dense_gan_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))

    tb_writer.close()
def train_model_residual_lowlight_rdn():
    """Train ``HSIRDNECA_Denoise`` on ``train_washington8.npy`` with per-epoch validation.

    Every epoch:
      * regenerates training patches from ``./data/denoise/train_washington8.npy``
        via ``datagenerator`` / ``DenoisingDataset`` (noise level ``SIGMA``);
      * trains the network to predict the residual ``clean - noisy`` with
        ``recon_criterion``;
      * steps the ``MultiStepLR`` schedule once (LR halved at milestones 200
        and 400);
      * saves a per-epoch checkpoint under ``save_model_path``;
      * denoises each test HSI in ``./data/denoise/test/`` band by band and keeps
        the checkpoint with the best mean band PSNR as ``*_best.pth``.
    """
    device = DEVICE

    # Training cube; the transpose moves the last (channel) axis to the front.
    train = np.load('./data/denoise/train_washington8.npy')
    train = train.transpose((2, 1, 0))

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'
    os.makedirs(save_model_path, exist_ok=True)

    # Model
    net = HSIRDNECA_Denoise(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # Optimizer and LR schedule
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[200, 400], gamma=0.5)

    # The validation set is fixed, so build it once instead of every epoch.
    test_data_dir = './data/denoise/test/'
    test_set = HsiTrainDataset(test_data_dir)
    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    test_result_output_path = './data/denoise/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    gen_epoch_loss_list = []
    mpsnr_list = []

    cur_step = 0  # number of optimizer steps taken so far
    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 600
    channels = 191  # number of spectral bands passed to datagenerator

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # LR actually used for this epoch (get_last_lr is safe outside step()).
        current_lr = scheduler.get_last_lr()[0]
        print('epoch = ', epoch, 'lr={:.6f}'.format(current_lr))

        gen_epoch_loss = 0.0
        net.train()

        # Fresh patches each epoch.
        data_patches, data_cubic_patches = datagenerator(train, channels)
        data_patches = torch.from_numpy(data_patches.transpose((0, 3, 1, 2)))
        data_cubic_patches = torch.from_numpy(
            data_cubic_patches.transpose((0, 4, 1, 2, 3)))

        train_dataset = DenoisingDataset(data_patches, data_cubic_patches,
                                         SIGMA)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True)
        num_batches = data_patches.size(0) // BATCH_SIZE

        for step, x_y in enumerate(train_loader):
            batch_x_noise = x_y[0].to(device)
            batch_y_noise = x_y[1].to(device)
            batch_x = x_y[2].to(device)

            hsid_optimizer.zero_grad()
            # The network predicts the residual (clean - noisy), not the clean band.
            residual = net(batch_x_noise, batch_y_noise)
            loss = recon_criterion(residual, batch_x - batch_x_noise)
            loss.backward()
            hsid_optimizer.step()

            gen_epoch_loss += loss.item()
            cur_step += 1

            if step % 10 == 0:
                print('%4d %4d / %4d loss = %2.8f' %
                      (epoch, step, num_batches, loss.item() / BATCH_SIZE))

        gen_epoch_loss_list.append(gen_epoch_loss)

        # Step the schedule once per epoch, after the optimizer updates
        # (the order required since PyTorch 1.1).
        scheduler.step()

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            },
            f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_{epoch}.pth"
        )

        # Validation
        mpsnr = _evaluate_rdn_on_test_set(net, test_dataloader)
        mpsnr_list.append(mpsnr)

        # Keep the best model by mean validation PSNR.
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth"
            )

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     current_lr))
        print(
            "------------------------------------------------------------------"
        )


def _evaluate_rdn_on_test_set(net, test_dataloader):
    """Denoise each test HSI band by band; print per-image metrics and return mean PSNR.

    Each band is fed to ``net`` together with its ``K`` adjacent bands from
    ``get_adjacent_spectral_bands``, and the predicted residual is added back
    to the noisy band. For each image this prints the mean band PSNR, SSIM and
    SAM. The return value is the mean PSNR over all bands of all images, or
    0.0 if the loader yields nothing.
    """
    net.eval()
    all_band_psnrs = []
    for noisy, label in test_dataloader:
        noisy = noisy.type(torch.FloatTensor).to(DEVICE)
        label = label.type(torch.FloatTensor).to(DEVICE)

        _, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))
        label_numpy = np.squeeze(label.cpu().numpy().astype(np.float32))

        # Per-image PSNRs, so the printed average is not cumulative across images.
        image_psnrs = []
        with torch.no_grad():
            for i in range(band_num):
                current_noisy_band = noisy[:, :, :, i][:, None]

                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(1)

                residual = net(current_noisy_band,
                               adj_spectral_bands_unsqueezed)
                denoised_band = residual + current_noisy_band
                denoised_band_numpy = np.squeeze(
                    denoised_band.cpu().numpy().astype(np.float32))
                denoised_hsi[:, :, i] = denoised_band_numpy

                image_psnrs.append(
                    PSNR(denoised_band_numpy, label_numpy[:, :, i]))

        all_band_psnrs.extend(image_psnrs)

        denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
        test_label_hsi_trans = label_numpy.transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(np.mean(image_psnrs), mssim, sam))

    if not all_band_psnrs:
        return 0.0
    return float(np.mean(all_band_psnrs))
Esempio n. 15
0
def train():
    """Train (or continue training) a ResUNet upsampling model with Keras.

    Flags: ``--model`` is an optional H5 model to resume from (otherwise a new
    ``ResUNet`` is built), and ``--scale`` is the upsampling factor (default 2).
    Training data comes from ``./frames`` and validation data from
    ``./test_data``. A CSV log and per-epoch checkpoints go to ``results/``.
    Training stops early after 10 epochs without a validation-PSNR improvement.
    """
    cli = argparse.ArgumentParser(description='Trainer')
    cli.add_argument('--model', help='H5 TF upsample model', default=None)
    cli.add_argument('--scale', help='Model scale', default=2, type=int)
    opts = cli.parse_args()

    scale = opts.scale
    crop_size, repeat_count, batch_size = 96, 1, 16
    in_channels, out_channels = 9, 3
    epochs, lr = 100, 1e-4

    out_dir = 'results'
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    if opts.model is None:
        model = ResUNet(scale, in_channels, out_channels)
        print('New model created')
    else:
        model = tf.keras.models.load_model(opts.model, compile=False)
        model.summary()
        print('Model loaded')

    def make_dataset(folder):
        # Same cropping / batching settings for both splits.
        return get_dataset(folder, batch_size, crop_size, scale, repeat_count)

    train_dataset = make_dataset('./frames')
    valid_dataset = make_dataset('./test_data')

    adam = tf.keras.optimizers.Adam(lr, beta_1=.5)
    model.compile(loss=tf.losses.MeanAbsoluteError(),
                  optimizer=adam,
                  metrics=[PSNR(), SSIM()])

    callbacks = [
        tf.keras.callbacks.CSVLogger(
            os.path.join(out_dir, 'train_resunet.log')),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=out_dir + '/resunet_{epoch:03d}-{val_PSNR:.2f}.h5',
            monitor='val_PSNR'),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_PSNR', patience=10, mode='max'),
    ]

    model.fit(train_dataset,
              epochs=epochs,
              callbacks=callbacks,
              validation_data=valid_dataset,
              verbose=1)
def predict_lowlight_hsid_origin():
    """Denoise the low-light test HSI with ``HSID_origin`` and print PSNR/SSIM/SAM.

    Loads ``./checkpoints/hsid_origin_best.pth`` into ``HSID_origin(24)``.
    Each sample from ``HsiCubicLowlightTestDataset('./data/test_lowlight/cuk12/')``
    (batch size 1) is one band: a noisy band, its adjacent-band cube and its
    label. The model predicts a residual, which is added to the noisy band.
    The denoised bands are stacked as ``denoised_hsi[:, :, band_index]``, saved
    to ``./data/testresult/result.mat`` under ``'denoised'``, and compared
    with the ``'label'`` array of ``soup_bigcorn_orange_1ms.mat``.
    """
    hsid = HSID_origin(24).to(DEVICE)
    # map_location follows DEVICE so the checkpoint also loads on CPU-only
    # hosts (a hard-coded 'cuda:0' fails there).
    hsid.load_state_dict(
        torch.load('./checkpoints/hsid_origin_best.pth',
                   map_location=DEVICE)['gen'])

    # Ground-truth cube used for the final metrics.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Test data: one band per batch.
    batch_size = 1
    test_data_dir = './data/test_lowlight/cuk12/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    hsid.eval()
    for batch_idx, (noisy_test, cubic_test,
                    label_test) in enumerate(test_dataloader):
        noisy_test = noisy_test.type(torch.FloatTensor).to(DEVICE)
        cubic_test = cubic_test.type(torch.FloatTensor).to(DEVICE)

        with torch.no_grad():
            # The model predicts the residual; add it back to the noisy band.
            residual = hsid(noisy_test, cubic_test)
            denoised_band = noisy_test + residual

            denoised_band_numpy = denoised_band.cpu().numpy().astype(
                np.float32)
            denoised_hsi[:, :, batch_idx] = np.squeeze(denoised_band_numpy)

    psnr = PSNR(denoised_hsi, test_label_hsi)
    ssim = SSIM(denoised_hsi, test_label_hsi)
    sam = SAM(denoised_hsi, test_label_hsi)

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
def _encam_denoise_band_tiled(model, band, adj_bands):
    """Denoise one band by running ``model`` separately on its four spatial quadrants.

    ``band`` is sliced as ``band[:, :, rows, cols]`` and ``adj_bands`` as
    ``adj_bands[:, :, :, rows, cols]``, so the last two axes of both must
    match. Each quadrant becomes ``tile + model(tile, adj_tile)`` (the model
    predicts a residual). The four results are concatenated back (columns on
    dim 3, rows on dim 2), so the output has the same shape as ``band``. With
    odd sizes the extra row/column goes to the second half, so no pixels are
    dropped.
    """
    n_rows, n_cols = band.shape[2], band.shape[3]
    row_halves = (slice(0, n_rows // 2), slice(n_rows // 2, n_rows))
    col_halves = (slice(0, n_cols // 2), slice(n_cols // 2, n_cols))

    stitched_rows = []
    for rows in row_halves:
        tiles = []
        for cols in col_halves:
            band_tile = band[:, :, rows, cols]
            adj_tile = adj_bands[:, :, :, rows, cols]
            tiles.append(band_tile + model(band_tile, adj_tile))
        stitched_rows.append(torch.cat(tiles, dim=3))
    return torch.cat(stitched_rows, dim=2)


def predict_lowlight_residual():
    """Denoise the low-light test HSI band by band with ENCAM, in four tiles per band.

    Loads ``./checkpoints/encam_best_08_27.pth`` into ``ENCAM()``. For every
    band of every test image from
    ``HsiLowlightTestDataset('../HSID/data/test_lowlight/origin/')``, the band
    and its ``K`` adjacent bands (``get_adjacent_spectral_bands``) are split
    into four spatial quadrants. Each quadrant is denoised on its own and the
    results are stitched back; the reason for tiling is presumably to bound
    GPU memory (TODO confirm). The denoised cube of the last test image is
    saved to ``./data/testresult/result.mat`` under ``'denoised'``, and
    PSNR/SSIM/SAM against the ``'label'`` array of
    ``soup_bigcorn_orange_1ms.mat`` are printed.

    NOTE(review): another function named ``predict_lowlight_residual`` is
    defined later in this file and replaces this one at import time.
    """
    encam = ENCAM().to(DEVICE)
    encam.eval()
    # map_location follows DEVICE so the checkpoint also loads on CPU-only hosts.
    encam.load_state_dict(
        torch.load('./checkpoints/encam_best_08_27.pth',
                   map_location=DEVICE)['gen'])

    mat_src_path = '../HSID/data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']

    test_data_dir = '../HSID/data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    for noisy, _label in test_dataloader:
        noisy = noisy.type(torch.FloatTensor)

        _, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):
                # Current band with a singleton channel axis inserted at dim 1.
                current_noisy_band = noisy[:, :, :, i][:, None]

                # Adjacent bands come back as (batch, width, height, n_adj)
                # per the original comment; move bands to dim 1 and add a
                # singleton axis -> (batch, 1, n_adj, width, height).
                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands = torch.unsqueeze(adj_spectral_bands,
                                                     1).to(DEVICE)

                denoised_band = _encam_denoise_band_tiled(
                    encam, current_noisy_band, adj_spectral_bands)

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_hsi[:, :, i] = np.squeeze(denoised_band_numpy)

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
def predict_lowlight_residual():
    """Denoise the low-light test HSI band by band with ``MultiStageHSIDUpscale``.

    Loads ``./checkpoints/hsid_multistage_upscale_patchsize64_best.pth`` into
    ``MultiStageHSIDUpscale(36)``. For every band of every test image from
    ``HsiLowlightTestDataset('./data/test_lowlight/origin/')``, the model gets
    the band and its ``K`` adjacent bands from ``get_adjacent_spectral_bands``.
    Its first output (``residual[0]``) is added to the noisy band. The
    denoised cube of the last test image is saved to
    ``./data/testresult/result.mat`` under ``'denoised'``, and PSNR/SSIM/SAM
    against the ``'label'`` array of ``soup_bigcorn_orange_1ms.mat`` are
    printed.
    """
    hsid = MultiStageHSIDUpscale(36).to(DEVICE)
    # map_location follows DEVICE so the checkpoint also loads on CPU-only hosts.
    hsid.load_state_dict(
        torch.load('./checkpoints/hsid_multistage_upscale_patchsize64_best.pth',
                   map_location=DEVICE)['gen'])
    # Inference mode. The model was previously left in training mode, so any
    # dropout / batch-norm layers behaved as in training during prediction.
    hsid.eval()

    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']

    test_data_dir = './data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    for noisy, _label in test_dataloader:
        noisy = noisy.type(torch.FloatTensor)

        _, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):
                # Current band with a singleton channel axis inserted at dim 1.
                current_noisy_band = noisy[:, :, :, i][:, None]

                # Adjacent bands with the band axis moved to dim 1.
                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands = adj_spectral_bands.to(DEVICE)

                # The model returns several outputs; the first is the residual.
                residual = hsid(current_noisy_band, adj_spectral_bands)
                denoised_band = current_noisy_band + residual[0]

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_hsi[:, :, i] = np.squeeze(denoised_band_numpy)

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
Esempio n. 19
0
    graph = tf.get_default_graph()

    prediction = sess.run([prediction], feed_dict={inputs: input_test})
    prediction = np.uint8(np.squeeze(prediction) * 255)

    avg_frame = np.uint8((frame1 / 2. + frame3 / 2.))

    loss = sum(sum(sum((prediction / 255. - frame2 / 255.)**2))) / 2

    psnr = PSNR(prediction, frame2)
    # avg_psnr = PSNR(avg_frame, frame2)
    pre_gray = prediction[:, :, 0]
    frame2_gray = frame2[:, :, 0]
    # avg_frame_gray = avg_frame[:,:,0]
    ssim = SSIM(pre_gray, frame2_gray).mean()
    ms_ssim = MSSSIM(pre_gray, frame2_gray)
    # avg_SSIM = SSIM(avg_frame_gray, frame2_gray).mean()
    # avg_MS_SSIM = MSSSIM(avg_frame_gray, frame2_gray)

    # print "Loss = %.2f" % loss
    # print "avg_PSNR = %.2f, avg_SSIM = %.4f, avg_MS-SSIM = %.4f" % ( avg_psnr, avg_SSIM, avg_MS_SSIM)
    print "Loss = %.2f, PSNR = %.2f, SSIM = %.4f, MS-SSIM = %.4f" % (
        loss, psnr, ssim, ms_ssim)

    # print "loss = %f, PSNR = %f" % (loss, psnr)

    plt.subplot(221)
    plt.title("frame1")
    plt.imshow(frame1)
    plt.imsave("results/frame1.png", frame1)
Esempio n. 20
0
class SRGan(tf.keras.Model):
    """SRGAN model that trains a generator and a discriminator jointly.

    ``train_step`` performs one adversarial update of both networks per batch
    of ``(lr_images, hr_images)``. The generator loss is VGG content loss plus
    ``1e-3`` times the adversarial term. PSNR/SSIM of the generator output are
    tracked against the HR targets. Calling the model runs the generator.
    """

    def __init__(self, upscale_factor=4, generator_weights=None, **kwargs):
        # NOTE(review): upscale_factor is accepted but unused; Generator is
        # built only from generator_weights.
        super(SRGan, self).__init__(**kwargs)
        self.generator = Generator(weights=generator_weights)
        self.discriminator = Discriminator()
        self.vgg_loss = VGGLoss()

    def call(self, inputs, training=None):
        """Return the generator's super-resolved output for ``inputs``.

        Keras' default ``test_step``/``predict_step`` call ``self(x)``. Without
        ``call``, ``evaluate``, ``predict`` and validation inside ``fit`` hit the
        unimplemented ``tf.keras.Model.call``.
        """
        return self.generator(inputs, training=training)

    def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
        """Store the optimizers and adversarial loss, and create metric trackers.

        ``loss_fn(labels, logits)`` is used for both the discriminator and the
        generator adversarial terms. Remaining ``kwargs`` are forwarded to
        ``tf.keras.Model.compile``.
        """
        super(SRGan, self).compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self._psnr = PSNR(max_val=1.0)
        self._ssim = SSIM(max_val=1.0)

    def train_step(self, data):
        """Run one generator and one discriminator update on ``(lr_images, hr_images)``."""
        lr_images, hr_images = data
        batch_size = tf.shape(lr_images)[0]

        # Target labels: 1 for real HR images, 0 for generated ones.
        ones = tf.ones([batch_size])
        zeros = tf.zeros([batch_size])

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            sr_images = self.generator(lr_images, training=True)

            fake_logits = self.discriminator(sr_images, training=True)
            real_logits = self.discriminator(hr_images, training=True)

            d_loss_fake = tf.reduce_mean(self.loss_fn(zeros, fake_logits))
            d_loss_real = tf.reduce_mean(self.loss_fn(ones, real_logits))
            d_loss = d_loss_fake + d_loss_real

            # Generator: perceptual (VGG) loss plus a small adversarial term
            # that rewards fooling the discriminator.
            content_loss = self.vgg_loss(hr_images, sr_images)
            g_loss = tf.reduce_mean(self.loss_fn(ones, fake_logits))
            perceptual_loss = content_loss + 1e-3 * g_loss

            # Divide by the replica count so gradients summed across
            # replicas by the distribution strategy are averaged.
            d_loss_scaled = \
                d_loss / self.distribute_strategy.num_replicas_in_sync
            perceptual_loss_scaled = \
                perceptual_loss / self.distribute_strategy.num_replicas_in_sync

        d_grads = d_tape.gradient(d_loss_scaled,
                                  self.discriminator.trainable_weights)
        g_grads = g_tape.gradient(perceptual_loss_scaled,
                                  self.generator.trainable_weights)

        self.d_optimizer.apply_gradients(
            zip(d_grads, self.discriminator.trainable_weights))
        self.g_optimizer.apply_gradients(
            zip(g_grads, self.generator.trainable_weights))

        self._psnr.update_state(hr_images, sr_images)
        self._ssim.update_state(hr_images, sr_images)

        return {
            'psnr': self._psnr.result(),
            'ssim': self._ssim.result(),
            'perceptual_loss': perceptual_loss,
            'content_loss': content_loss,
            'g_loss': g_loss,
            'd_loss_real': d_loss_real,
            'd_loss_fake': d_loss_fake
        }
Esempio n. 21
0
class EvalMVC(object):
    """Evaluate classical block-matching motion estimation / compensation.

    Each video is treated as an I, P, P, ... sequence: motion vectors are
    estimated with ``sk_m.blockMotion`` and the frames are reconstructed from
    them with ``sk_m.blockComp``. The class reports PSNR, SSIM, VMAF and EPE
    of the compensated frames, the bits-per-pixel needed to transmit the
    motion vectors, and the run time, each averaged over the dataset.
    """

    def __init__(self,
                 video_dir,
                 vid_ext="mp4",
                 frame_size=(224, 320),
                 num_frames=24,
                 method="DS",
                 mb_size=16,
                 search_dist=7):
        """Build the video dataset and the metric objects.

        Args:
            video_dir: root directory passed to ``VideoDataset``.
            vid_ext: video file extension passed to ``VideoDataset``.
            frame_size: (height, width) every frame is resized to.
            num_frames: number of frames kept per video.
            method: search method name forwarded to ``sk_m.blockMotion``.
            mb_size: macro-block size forwarded to block motion / compensation.
            search_dist: search distance ``p`` forwarded to ``blockMotion``.
        """
        # per frame transform
        frame_transform = tf.Compose(
            [NpFrame2PIL("RGB"), tf.Resize(frame_size)])

        # composed transform: crop the sequence, transform every frame, restack
        video_transform = tf.Compose([
            CropVideoSequence(num_frames=num_frames),
            tf.Lambda(lambda frames: np.stack(
                [frame_transform(frame) for frame in frames]))
        ])

        # video dataset
        self.video_dataset = VideoDataset(root_dir=video_dir,
                                          vid_ext=vid_ext,
                                          transform=video_transform)

        self.num_videos = len(self.video_dataset)

        # motion parameters
        self.frame_size = frame_size
        self.num_frames = num_frames
        self.method = method
        self.mb_size = mb_size
        self.search_dist = search_dist

        # evaluation metrics (frames are in the 0..255 range here)
        self.ssim = SSIM(
            data_range=255,
            multichannel=True,
            gaussian_weights=True,
        )

        # EPE using LiteFlowNet
        self.epe = EPE(standard=False)

        self.psnr = PSNR(data_range=255)

        self.vmaf = VMAF()

    def _compensate(self, r_vid):
        # estimate motion for r_vid and return the motion-compensated video
        return self.ipp_bmc(r_vid, self.ipp_bme(r_vid))

    def avg_time(self):
        """Return the mean time (s) per video of motion estimation + compensation."""
        total_time = 0.0

        for r_vid in self.video_dataset:
            start_time = timer()
            self._compensate(r_vid)
            total_time += timer() - start_time

        return total_time / self.num_videos

    def avg_vmaf(self):
        """Return mean VMAF of the compensated P-frames (frames scaled to [0, 1])."""
        total_vmaf = 0.0

        for r_vid in self.video_dataset:
            # frame 0 is the I-frame and is excluded
            total_vmaf += self.calc_vmaf(r_vid[1:] / 255,
                                         self._compensate(r_vid)[1:] / 255)

        return total_vmaf / self.num_videos

    def avg_ssim(self):
        """Return mean SSIM of the compensated P-frames."""
        total_ssim = 0.0

        for r_vid in self.video_dataset:
            total_ssim += self.calc_ssim(r_vid[1:],
                                         self._compensate(r_vid)[1:])

        return total_ssim / self.num_videos

    def avg_psnr(self):
        """Return mean PSNR of the compensated P-frames."""
        total_psnr = 0.0

        for r_vid in self.video_dataset:
            total_psnr += self.calc_psnr(r_vid[1:],
                                         self._compensate(r_vid)[1:])

        return total_psnr / self.num_videos

    def avg_epe(self):
        """Return mean EPE over whole videos (frames scaled to [0, 1])."""
        total_epe = 0.0

        for r_vid in self.video_dataset:
            total_epe += self.calc_epe(r_vid / 255,
                                       self._compensate(r_vid) / 255)

        return total_epe / self.num_videos

    def avg_bpp(self):
        """Return mean motion-vector bits-per-pixel over the dataset."""
        total_bpp = 0.0

        for r_vid in self.video_dataset:
            total_bpp += self.calc_bpp(self.ipp_bme(r_vid))

        return total_bpp / self.num_videos

    def ipp_bme(self, videodata):
        """I, P, P, P block motion estimation.

        Returns the motion vectors from ``sk_m.blockMotion``; per the original
        note their shape is (numFrames - 1, height / mbSize, width / mbSize, 2).
        """
        return sk_m.blockMotion(videodata,
                                method=self.method,
                                mbSize=self.mb_size,
                                p=self.search_dist)

    def ipp_bmc(self, videodata, motion):
        """I, P, P, P block motion compensation of ``videodata`` using ``motion``."""
        return sk_m.blockComp(videodata, motionVect=motion, mbSize=self.mb_size)

    def display_avg_stats(self):
        """Print all dataset-averaged scores."""
        print("Bpp  : {}".format(self.avg_bpp()))
        print("PSNR : {}".format(self.avg_psnr()))
        print("SSIM : {}".format(self.avg_ssim()))
        print("VMAF : {}".format(self.avg_vmaf()))
        print("EPE  : {}".format(self.avg_epe()))
        print("Time (sec) : {}".format(self.avg_time()))
        return

    def display_bmc_video(self, index=0):
        """Display the block-motion-compensated video at ``index`` with its scores."""
        r_vid = self.video_dataset[index]

        # estimate motion once; reuse it for compensation and for bpp
        motion = self.ipp_bme(r_vid)
        c_vid = self.ipp_bmc(r_vid, motion)

        bpp_str = "bpp : {}".format(round(self.calc_bpp(motion), 4))
        psnr_str = "PSNR : {}".format(
            round(self.calc_psnr(r_vid[1:], c_vid[1:]), 2))
        ssim_str = "SSIM : {}".format(
            round(self.calc_ssim(r_vid[1:], c_vid[1:]), 2))
        # VMAF on [0, 1]-scaled frames, matching avg_vmaf
        vmaf_str = "VMAF : {}".format(
            round(self.calc_vmaf(r_vid[1:] / 255, c_vid[1:] / 255), 2))

        # set-up plot (separator keeps the three scores readable)
        x_label = ", ".join([psnr_str, ssim_str, vmaf_str])
        img_t.setup_plot("", y_label=bpp_str, x_label=x_label)

        # display compensated sequence
        vid_t.display_frames(c_vid / 255)

        return

    def calc_vmaf(self, r_vid, c_vid):
        """Return VMAF between reference and compensated video."""
        return self.vmaf.calc_video(r_vid, c_vid)

    def calc_ssim(self, r_vid, c_vid):
        """Return SSIM between reference and compensated video."""
        return self.ssim.calc_video(r_vid, c_vid)

    def calc_psnr(self, r_vid, c_vid):
        """Return PSNR between reference and compensated video."""
        return self.psnr.calc_video(r_vid, c_vid)

    def calc_epe(self, r_vid, c_vid):
        """Return EPE between reference and compensated video."""
        return self.epe.calc_video(r_vid, c_vid)

    def calc_bpp(self, motion_vectors):
        """Return bits-per-pixel needed to convey ``motion_vectors``.

        Direct binarisation without overhead for retaining shape: every
        non-zero vector costs the bits of its two components plus the bits of
        its block coordinates (x, y). The total is divided by the pixel count
        of ``t`` frames of size ``self.frame_size``.
        """
        total_bits = 0.0

        t, h, w, _ = motion_vectors.shape

        for f in range(t):
            for y in range(h):
                for x in range(w):
                    dx, dy = motion_vectors[f, y, x]

                    if dy != 0.0 or dx != 0.0:
                        total_bits += (self.bit_count(dy) + self.bit_count(dx)
                                       + self.bit_count(x) + self.bit_count(y))

        f_h, f_w = self.frame_size
        return total_bits / (f_h * f_w * t)

    def calc_cc(self, metric, save_dir="./"):
        """Compute a metric-vs-bpp compression curve and save it as ``.npy``.

        Raises:
            KeyError: if ``metric`` is not PSNR, SSIM, VMAF or EPE.
        """
        if metric not in ["PSNR", "SSIM", "VMAF", "EPE"]:
            raise KeyError(
                "Specified metric : {}, is not currently supported!".format(
                    metric))

        met, bpp = self._prog_eval(metric)

        curve = {"bpp": bpp, "metric": met}

        file_name = "".join([save_dir, "/", self.method, "_", metric, '.npy'])

        np.save(file_name, curve)

    def _prog_eval(self, metric):
        """Evaluate ``metric`` and bpp for macro-block sizes 4, 8 and 16.

        ``self.mb_size`` is temporarily overwritten and always restored, even
        if an evaluation raises. An unknown metric yields ``None`` values.

        Returns:
            (metric_values, bpp_values), one entry per macro-block size.
        """
        evaluators = {
            "PSNR": self.avg_psnr,
            "SSIM": self.avg_ssim,
            "VMAF": self.avg_vmaf,
            "EPE": self.avg_epe,
        }
        evaluate = evaluators.get(metric)

        m = []
        b = []

        og_mb_size = self.mb_size
        try:
            for mb_size in (4, 8, 16):
                self.mb_size = mb_size
                m.append(evaluate() if evaluate is not None else None)
                b.append(self.avg_bpp())
        finally:
            # reset macro-block size to original
            self.mb_size = og_mb_size

        return m, b

    @staticmethod
    def bit_count(val):
        """Return the number of characters in the binary form of ``int(val)``.

        Note: for negative values ``np.binary_repr`` includes a '-' sign,
        which is counted as one extra bit.
        """
        return len(np.binary_repr(int(val)))
Esempio n. 22
0
class DerainNet:
    """Derain Net (ReMAEN): all implemented layers are included (e.g. MAEB,
    convGRU, shared channel attention, channel attention).

    Params:
        config: the training configuration
        sess: running session
    """
    model_name = 'ReMAEN'

    def __init__(self, config, sess=None):
        # frequently used settings copied from the config object
        self.config = config
        self.channel_dim = self.config.channel_dim
        self.batch_size = self.config.batch_size
        self.patch_size = self.config.patch_size
        self.input_channels = self.config.input_channels

        # metrics
        self.ssim = SSIM(max_val=1.0)
        self.psnr = PSNR(max_val=1.0)

        # session used by train / sample / test / save / load
        self.sess = sess

    def globalAvgPool2D(self, input_x):
        """Apply a Keras GlobalAvgPool2D layer to ``input_x``."""
        global_avgpool2d = tf.contrib.keras.layers.GlobalAvgPool2D()
        return global_avgpool2d(input_x)

    def leakyRelu(self, input_x):
        """Apply LeakyReLU with negative slope 0.2."""
        leaky_relu = tf.contrib.keras.layers.LeakyReLU(alpha=0.2)
        return leaky_relu(input_x)

    def SEBlock(self, input_x, input_dim=32, reduce_dim=8, scope='SEBlock'):
        """Squeeze-and-excitation block.

        Global-average-pools ``input_x``, passes it through two fully
        connected layers (``reduce_dim`` then ``input_dim`` units) and a
        sigmoid, and rescales ``input_x`` channel-wise by the result.
        ``input_dim`` should match the channel count of ``input_x``.
        """
        with tf.variable_scope(scope) as scope:
            # global scale
            global_pl = self.globalAvgPool2D(input_x)
            reduce_fc1 = slim.fully_connected(global_pl, reduce_dim, activation_fn=tf.nn.relu)
            reduce_fc2 = slim.fully_connected(reduce_fc1, input_dim, activation_fn=None)
            g_scale = tf.nn.sigmoid(reduce_fc2)
            # (batch, C) -> (batch, 1, 1, C) so it broadcasts over height/width
            g_scale = tf.expand_dims(g_scale, axis=1)
            g_scale = tf.expand_dims(g_scale, axis=1)
            gs_input = input_x*g_scale
            return gs_input

    def convGRU(self, input_x, h, out_dim, scope='convGRU'):
        """Convolutional GRU cell followed by a shared SE block and LeakyReLU.

        When ``h`` is None (first stage) only the input path is used.
        Returns the new hidden state twice (as output and as state).
        """
        with tf.variable_scope(scope):
            if h is None:
                self.conv_xz = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xz')
                self.conv_xn = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xn')
                z = tf.nn.sigmoid(self.conv_xz)
                f = tf.nn.tanh(self.conv_xn)
                h = z*f
            else:
                self.conv_hz = slim.conv2d(h, out_dim, 3, 1, scope='conv_hz')
                self.conv_hr = slim.conv2d(h, out_dim, 3, 1, scope='conv_hr')

                self.conv_xz = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xz')
                self.conv_xr = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xr')
                self.conv_xn = slim.conv2d(input_x, out_dim, 3, 1, scope='conv_xn')
                r = tf.nn.sigmoid(self.conv_hr+self.conv_xr)
                z = tf.nn.sigmoid(self.conv_hz+self.conv_xz)

                self.conv_hn = slim.conv2d(r*h, out_dim, 3, 1, scope='conv_hn')
                n = tf.nn.tanh(self.conv_xn + self.conv_hn)
                h = (1-z)*h + z*n

        # shared channel attention block
        se = self.SEBlock(h, out_dim, reduce_dim=int(out_dim/4))
        h = self.leakyRelu(se)
        return h, h

    def MAEB(self, input_x, scope_name, dilated_factors=3):
        '''MAEB: multi-scale aggregation and enhancement block
            Params:
                input_x: input data
                scope_name: the scope name of the MAEB (customer definition)
                dilated_factor: the maximum number of dilated factors(default=3, range from 1 to 3)

            Return:
                return the output the MAEB

            Input shape:
                4D tensor with shape '(batch_size, height, width, channels)'

            Output shape:
                4D tensor with shape '(batch_size, height, width, channels)'
        '''
        dilate_c = []
        with tf.variable_scope(scope_name):
            # NOTE(review): every iteration uses scope names 'd1'/'d2'; under the
            # AUTO_REUSE scope opened in derainNet these resolve to the same
            # variables, so the dilation branches appear to share weights.
            for i in range(1,dilated_factors+1):
                d1 = self.leakyRelu(slim.conv2d(input_x, self.channel_dim, 3, 1, rate=i, activation_fn=None, scope='d1'))
                d2 = self.leakyRelu(slim.conv2d(d1, self.channel_dim, 3, 1, rate=i, activation_fn=None, scope='d2'))
                dilate_c.append(d2)

            add = tf.add_n(dilate_c)
            shape = add.get_shape().as_list()
            output = self.SEBlock(add, shape[-1], reduce_dim=int(shape[-1]/4))
            return output

    def derainNet(self, input_x, scope_name='derainNet'):
        '''ReMAEN: recurrent multi-scale aggregation and enhancement network
            Params:
                input_x: input data
                scope_name: the scope name of the ReMAEN (customer definition, default='derainnet')
            Return:
                return the derained results and the residual of the last stage

            Input shape:
                4D tensor with shape '(batch_size, height, width, channels)'

            Output shape:
                4D tensor with shape '(batch_size, height, width, channels)'
        '''
        # reuse: tf.AUTO_REUSE (lets the stages below reuse parameters automatically)
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
            with slim.arg_scope([slim.conv2d,slim.conv2d_transpose], weights_initializer=tf.contrib.layers.xavier_initializer(),
                                              normalizer_fn = None,
                                              activation_fn = None,
                                              padding='SAME'):
                # one recurrent state per convGRU (7 in total)
                old_states = [None for _ in range(7)]
                stages = 3
                derain = input_x

                for i in range(stages):
                    cur_states = []
                    with tf.variable_scope('ReMAEN'):
                        with tf.variable_scope('extracting_path'):
                            MAEB1 = self.MAEB(derain, scope_name='MAEB1')
                            gru1, h1 = self.convGRU(MAEB1, old_states[0], self.channel_dim, scope='convGRU1')
                            cur_states.append(h1)

                            MAEB2 = self.MAEB(gru1, scope_name='MAEB2')
                            gru2, h2 = self.convGRU(MAEB2, old_states[1], self.channel_dim, scope='convGRU2')
                            cur_states.append(h2)

                            MAEB3 = self.MAEB(gru2, scope_name='MAEB3')
                            gru3, h3 = self.convGRU(MAEB3, old_states[2], self.channel_dim, scope='convGRU3')
                            cur_states.append(h3)

                            MAEB4 = self.MAEB(gru3, scope_name='MAEB4')
                            gru4, h4 = self.convGRU(MAEB4, old_states[3], self.channel_dim, scope='convGRU4')
                            cur_states.append(h4)

                        with tf.variable_scope('responding_path'):
                            up5 = slim.conv2d(gru4, self.channel_dim, 3, 1, activation_fn=tf.nn.relu, scope='conv5')
                            add5 = tf.add(up5, MAEB3)
                            gru5, h5 = self.convGRU(add5, old_states[4], self.channel_dim, scope='convGRU5')
                            cur_states.append(h5)

                            up6 = slim.conv2d(gru5, self.channel_dim, 3, 1, activation_fn=tf.nn.relu, scope='conv6')
                            add6 = tf.add(up6, MAEB2)
                            gru6, h6 = self.convGRU(add6, old_states[5], self.channel_dim, scope='convGRU6')
                            cur_states.append(h6)

                            up7 = slim.conv2d(gru6, self.channel_dim, 3, 1, activation_fn=tf.nn.relu, scope='conv7')
                            add7 = tf.add(up7, MAEB1)
                            gru7, h7 = self.convGRU(add7, old_states[6], self.channel_dim, scope='convGRU7')
                            cur_states.append(h7)

                    # residual map generator
                    with tf.variable_scope('RMG'):
                        rmg_conv = slim.conv2d(gru7, self.channel_dim, 3, 1)
                        rmg_conv_se = self.leakyRelu(self.SEBlock(rmg_conv, self.channel_dim, reduce_dim=int(self.channel_dim/4)))
                        residual = slim.conv2d(rmg_conv_se, self.input_channels, 3, 1)

                    # each stage subtracts its predicted residual from the current estimate
                    derain = derain - residual
                    old_states = [tf.identity(s) for s in cur_states]

        return derain, residual

    def build(self):
        """Build placeholders, the network, metrics, losses, optimizer, summaries and saver."""
        # placeholder
        self.rain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='rain')
        self.norain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='norain')
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')

        # derainnet
        self.out, self.residual = self.derainNet(self.rain)
        self.finer_out = tf.clip_by_value(self.out, 0, 1.0)
        self.finer_residual = tf.clip_by_value(tf.abs(self.residual), 0, 1)

        # metrics (unclipped and clipped output)
        self.ssim_finer_tensor = tf.reduce_mean(self.ssim._ssim(self.norain, self.out, 0, 0))
        self.psnr_finer_tensor = tf.reduce_mean(self.psnr.compute_psnr(self.norain, self.out))
        self.ssim_val = tf.reduce_mean(self.ssim._ssim(self.norain, self.finer_out, 0, 0))
        self.psnr_val = tf.reduce_mean(self.psnr.compute_psnr(self.norain, self.finer_out))

        # loss function
        # MSE loss
        self.l2_loss = tf.reduce_mean(tf.square(self.out - self.norain))
        # edge loss, kernel is imported from settings
        self.norain_edge = tf.nn.relu(tf.nn.conv2d(tf.image.rgb_to_grayscale(self.norain), kernel, [1,1,1,1],padding='SAME'))
        self.derain_edge = tf.nn.relu(tf.nn.conv2d(tf.image.rgb_to_grayscale(self.out), kernel, [1,1,1,1],padding='SAME'))
        self.edge_loss = tf.reduce_mean(tf.square(self.norain_edge-self.derain_edge))
        # total loss
        self.total_loss = self.l2_loss + 0.1*self.edge_loss

        # optimization (only variables under the derainNet scope)
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if 'derainNet' in var.name]
        self.train_ops = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=self.config.beta1, beta2=self.config.beta2).minimize(self.total_loss, var_list=g_vars)

        # summary
        self.l2_loss_summary = tf.summary.scalar('l2_loss', self.l2_loss)
        self.total_loss_summary = tf.summary.scalar('total_loss', self.total_loss)
        self.edge_loss_summary = tf.summary.scalar('edge_loss', self.edge_loss)
        self.ssim_summary = tf.summary.scalar('ssim', self.ssim_val)
        self.psnr_summary = tf.summary.scalar('psnr', self.psnr_val)
        self.summaries = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(self.config.logs_dir, self.sess.graph)

        # saver
        global_variables = tf.global_variables()
        var_to_store = [var for var in global_variables if 'derainNet' in var.name]
        self.saver = tf.train.Saver(var_list=var_to_store)

        # count trainable parameters
        num_params = 0
        for var in g_vars:
            tmp_num = 1
            for i in var.get_shape().as_list():
                tmp_num = tmp_num*i
            num_params = num_params + tmp_num
        print('numbers of trainable parameters:{}'.format(num_params))

    def train(self):
        """Train for ``config.iterations`` iterations.

        Samples every 100 iterations, checkpoints every 500 iterations and
        once more at the end. The learning rate is divided by 10 once, at
        iteration 30000.
        """
        # initialize variables (older TF releases only have initialize_all_variables)
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            tf.initialize_all_variables().run()

        # load training model
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] load model successfully')
        else:
            print('[***] fail to load model')

        lr_ = self.config.lr
        start_time = time.time()
        for counter in range(self.config.iterations):
            if counter == 30000:
                lr_ = 0.1*lr_

            # obtain training image pairs
            img, label = read_data(self.config.train_dataset, self.config.data_path, self.batch_size, self.patch_size, self.config.trainset_size)
            _, total_loss, summaries, ssim, psnr = self.sess.run([self.train_ops,
                                                               self.total_loss,
                                                               self.summaries,
                                                               self.ssim_val,
                                                               self.psnr_val], feed_dict={self.rain:img,
                                                                                           self.norain:label,
                                                                                           self.lr:lr_})

            print('Iteration:{}, phase:{}, loss:{:.4f}, ssim:{:.4f}, psnr:{:.4f}, lr:{}, iterations:{}'.format(counter,
                                                                                                                 self.config.phase,
                                                                                                                 total_loss,
                                                                                                                 ssim,
                                                                                                                 psnr,
                                                                                                                 lr_,
                                                                                                                 self.config.iterations))

            self.summary_writer.add_summary(summaries, global_step=counter)
            if np.mod(counter, 100)==0:
                self.sample(self.config.sample_dir, counter)

            if np.mod(counter, 500)==0:
                self.save_model()

        # save final model; guarded so zero iterations does not touch an unbound `counter`
        if self.config.iterations > 0:
            self.save_model()

        # training time
        end_time = time.time()
        print('training time:{} hours'.format((end_time-start_time)/3600.0))

    def sample(self, sample_dir, iterations):
        """Run the network on a test batch and save an image grid into ``sample_dir``.

        Each row concatenates rainy input, ground truth, output and |residual|.
        """
        test_img, test_label = read_data(self.config.test_dataset, self.config.data_path, self.batch_size, self.patch_size, self.config.testset_size)
        finer_out, finer_residual = self.sess.run([self.finer_out, self.finer_residual], feed_dict={self.rain:test_img})

        # save sampling images
        test_img_uint8 = np.uint8(test_img*255.0)
        test_label_uint8 = np.uint8(test_label*255.0)
        finer_out_uint8 = np.uint8(finer_out*255.0)
        finer_residual = np.uint8(finer_residual*255.0)
        sample = np.concatenate([test_img_uint8, test_label_uint8, finer_out_uint8, finer_residual], 2)
        grid = int(np.sqrt(self.batch_size))+1
        save_images(sample, [grid, grid], '{}/{}_{}_{:04d}.jpg'.format(sample_dir,
                                                                      self.config.test_dataset,
                                                                      self.config.phase,
                                                                      iterations))

    def test(self):
        """Evaluate on every image of ``config.test_dataset`` and save the outputs.

        Prints per-image SSIM/PSNR/time, their means, and the mean inference
        time over all evaluated images.

        Raises:
            FileNotFoundError: if a test image or label cannot be read.
        """
        rain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='test_rain')
        norain = tf.placeholder(tf.float32, [None, None, None, self.input_channels], name='test_norain')

        out, residual = self.derainNet(rain)
        finer_out = tf.clip_by_value(out, 0, 1.0)
        finer_residual = tf.clip_by_value(tf.abs(residual), 0, 1.0)

        ssim_val = tf.reduce_mean(self.ssim._ssim(norain, finer_out, 0, 0))
        psnr_val = tf.reduce_mean(self.psnr.compute_psnr(norain, finer_out))

        # load model
        self.saver = tf.train.Saver()
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] load model successfully')
        else:
            try:
                tf.global_variables_initializer().run()
            except AttributeError:
                tf.initialize_all_variables().run()
            print('[***] fail to load model')

        try:
            test_num, test_data_format, test_label_format = test_dic[self.config.test_dataset]
        except KeyError:
            print('no testing dataset named {}'.format(self.config.test_dataset))
            return

        test_root = self.config.test_path.format(self.config.test_dataset)
        ssim = []
        psnr = []
        total_time = 0.0
        for index in range(1, test_num+1):
            test_data_fn = test_data_format.format(index)
            test_label_fn = test_label_format.format(index)

            test_data_path = os.path.join(test_root, test_data_fn)
            test_label_path = os.path.join(test_root, test_label_fn)

            test_data_uint8 = cv2.imread(test_data_path)
            test_label_uint8 = cv2.imread(test_label_path)
            # cv2.imread returns None instead of raising on unreadable files
            if test_data_uint8 is None or test_label_uint8 is None:
                raise FileNotFoundError('cannot read test pair: {} / {}'.format(test_data_path, test_label_path))

            test_data_float = test_data_uint8/255.0
            test_label_float = test_label_uint8/255.0

            test_data = np.expand_dims(test_data_float, 0)
            test_label = np.expand_dims(test_label_float, 0)

            s_t = time.time()
            finer_out_val, finer_residual_val, tmp_ssim, tmp_psnr = self.sess.run([finer_out,
                                                                                   finer_residual,
                                                                                   ssim_val,
                                                                                   psnr_val] , feed_dict={rain:test_data,
                                                                                                          norain:test_label})
            total_t = time.time() - s_t
            # accumulate across images (previously reset every iteration)
            total_time += total_t

            # save psnr and ssim metrics
            ssim.append(tmp_ssim)
            psnr.append(tmp_psnr)
            # save testing image
            test_label = np.uint8(test_label*255)
            finer_out_val = np.uint8(finer_out_val*255)
            finer_residual_val = np.uint8(finer_residual_val*255)
            save_images(finer_out_val, [1,1], '{}/{}_{}'.format(self.config.test_dir, self.config.test_dataset, test_data_fn))
            save_images(test_label, [1,1], '{}/{}'.format(self.config.test_dir, test_data_fn))
            save_images(finer_residual_val, [1,1], '{}/residual_{}'.format(self.config.test_dir, test_data_fn))
            print('test image {}: ssim:{}, psnr:{} time:{:.4f}'.format(test_data_fn, tmp_ssim, tmp_psnr, total_t))

        mean_ssim = np.mean(ssim)
        mean_psnr = np.mean(psnr)
        print('Test phase: ssim:{}, psnr:{}'.format(mean_ssim, mean_psnr))
        if test_num > 0:
            print('Average time:{}'.format(total_time / test_num))

    @property
    def model_dir(self):
        """Checkpoint sub-directory name: ``{model_name}_{train_dataset}_{batch_size}``."""
        return "{}_{}_{}".format(
            self.model_name, self.config.train_dataset,
            self.batch_size)

    @property
    def model_pos(self):
        """Checkpoint path prefix: ``{checkpoint_dir}/{model_dir}/{model_name}``."""
        return '{}/{}/{}'.format(self.config.checkpoint_dir, self.model_dir, self.model_name)

    def save_model(self):
        """Save variables to ``model_pos``, creating its parent directory if needed."""
        # the saver writes into checkpoint_dir/model_dir, which must exist
        os.makedirs(os.path.dirname(self.model_pos), exist_ok=True)
        self.saver.save(self.sess, self.model_pos)

    def load_model(self):
        """Restore from ``model_pos`` if a 'checkpoint' file exists; return whether it did."""
        if not os.path.isfile(os.path.join(self.config.checkpoint_dir, self.model_dir,'checkpoint')):
            return False
        else:
            self.saver.restore(self.sess, self.model_pos)
            return True
Esempio n. 23
0
    # NOTE(review): this is a fragment; the enclosing function header is not
    # part of this snippet. `model`, `blur` and `orig` come from that missing
    # scope. `blur`/`orig` are used as arrays with a 4-D `.shape` and `.reshape`.
    # Predict restored images from the blurred inputs.
    pred = model.predict(x=blur)

    batch, h, w, c = blur.shape

    # Per-image MSE: each image is flattened to a single row before MSE, so
    # the result has one value per image.
    mse_pred_orig = tf.keras.losses.MSE(orig.reshape(batch, h * w * c),
                                        pred.reshape(batch,
                                                     h * w * c)).numpy()
    mse_blur_orig = tf.keras.losses.MSE(orig.reshape(batch, h * w * c),
                                        blur.reshape(batch,
                                                     h * w * c)).numpy()

    # PSNR / SSIM of the prediction and of the blurred input against the
    # original. PSNR and SSIM are defined elsewhere; presumably they return one
    # value per image, since the results are boolean-masked below.
    psnr_pred_orig = PSNR(tf.Variable(orig), tf.Variable(pred)).numpy()
    psnr_blur_orig = PSNR(tf.Variable(orig), tf.Variable(blur)).numpy()

    ssim_pred_orig = SSIM(tf.Variable(orig), tf.Variable(pred)).numpy()
    ssim_blur_orig = SSIM(tf.Variable(orig), tf.Variable(blur)).numpy()

    # I need the best model to extract images

    # Select the images whose blurred input has PSNR below 30 (presumably dB).
    mask = psnr_blur_orig < 30

    # NOTE(review): the expression statements below discard their results; they
    # only show values in an interactive session (notebook/REPL).
    mse_pred_orig[mask].mean()
    psnr_pred_orig[mask].mean()
    ssim_pred_orig[mask].mean()

    # Positive where the prediction has a higher SSIM than the blurred input.
    idx = (ssim_pred_orig - ssim_blur_orig)

    psnr_pred_orig[6]

    psnr_pred_orig[idx > 0]
def train_model_residual_lowlight_twostage_unet():
    """Train the two-stage HSID U-Net on low-light hyperspectral cubes.

    Each epoch:
    1. trains on ``./data/train_lowlight/`` with a TV-regularised loss on both
       stages' residual predictions (target ``label - noisy``);
    2. saves ``checkpoints/two_stage_unet_hsid_{epoch}.pth``;
    3. denoises the test cube band by band and logs PSNR / SSIM / SAM against
       ``soup_bigcorn_orange_1ms.mat`` to TensorBoard;
    4. advances the MultiStepLR schedule (after the epoch's optimizer steps).
    """
    import os

    learning_rate = INIT_LEARNING_RATE * 0.5
    device = DEVICE

    # prepare training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # load the reference (label) cube for testing
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load test data
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    # spatial size of one test sample, used to allocate the output cube
    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    # one test batch per band: band `batch_idx` is written below
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # create model
    net = TwoStageHSIDWithUNet(K)
    init_params(net)
    net = net.to(device)

    # optimizer and learning-rate schedule
    hsid_optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    # torch.save does not create missing directories
    os.makedirs('checkpoints', exist_ok=True)

    cur_step = 0

    for epoch in range(NUM_EPOCHS):
        # learning rate used during this epoch
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        gen_epoch_loss = 0

        net.train()
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()

            # both stages predict the residual between clean and noisy input
            residual, residual_stage2 = net(noisy, cubic)
            target_residual = label - noisy
            loss = loss_function_with_tvloss(
                residual, target_residual) + loss_function_with_tvloss(
                    residual_stage2, target_residual)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # one step = one processed batch
            cur_step += 1

        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/two_stage_unet_hsid_{epoch}.pth")

        # evaluation: denoise the test cube band by band
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        _label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor).to(device)
            cubic_test = cubic_test.type(torch.FloatTensor).to(device)

            with torch.no_grad():
                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

        # compute PSNR, SSIM and SAM of the whole cube
        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        # logged per epoch so the best-performing epoch can be identified
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)

        # advance the LR schedule after this epoch's optimizer updates; calling
        # it before optimizer.step() skips the first schedule value (PyTorch >= 1.1)
        scheduler.step()

    tb_writer.close()
def train_model():
    """Train HSID_1x3 on cubic patches and validate on the test cube every epoch.

    Each epoch:
      * runs one pass over ``./data/train_cubic/`` with Adam and a MultiStepLR
        schedule (milestones 15/30/45, gamma 0.25), logging the per-step loss
        to TensorBoard;
      * saves the model and optimizer state to ``checkpoints/hsid_{epoch}.pth``;
      * denoises the test data band by band and prints (and logs) PSNR, SSIM
        and SAM against ``./data/origin/test_washington.npy``.

    Relies on module-level configuration and helpers defined elsewhere in the
    project (``DEVICE``, ``BATCH_SIZE``, ``K``, ``NUM_EPOCHS``,
    ``INIT_LEARNING_RATE``, ``display_step``, ``loss_fuction``,
    ``get_adjacent_spectral_bands``, ...). Sets the module-global ``tb_writer``.
    """
    import os

    device = DEVICE
    # Prepare training data.
    train_set = HsiCubicTrainDataset('./data/train_cubic/')
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the ground-truth test cube.
    test_label_hsi = np.load('./data/origin/test_washington.npy')

    # Load the noisy test data.
    test_data_dir = './data/test_level25/'
    test_set = HsiTrainDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Create the model.
    net = HSID_1x3(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # Create the optimizer and learning-rate schedule.
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer,
                            milestones=[15, 30, 45],
                            gamma=0.25)

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    # torch.save does not create missing parent directories.
    os.makedirs("checkpoints", exist_ok=True)

    for epoch in range(NUM_EPOCHS):

        # Sum (not mean) of the per-batch losses over this epoch.
        gen_epoch_loss = 0

        net.train()
        for noisy, cubic, label in train_loader:

            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            denoised_img = net(noisy, cubic)
            loss = loss_fuction(denoised_img, label)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # One step == one processed batch.
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/hsid_{epoch}.pth")

        # Validation: denoise the test data one spectral band at a time.
        net.eval()
        denoised_hsi = None
        for noisy, _label in test_dataloader:
            noisy = noisy.type(torch.FloatTensor)

            _, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))

            noisy = noisy.to(DEVICE)

            with torch.no_grad():
                for i in range(band_num):  # process each band in turn
                    # (batch, width, height) -> (batch, 1, width, height)
                    current_noisy_band = noisy[:, :, :, i][:, None]

                    # Assumed to return (batch, width, height, bands) per the
                    # original comment; swapping dims 1 and 3 then gives
                    # (batch, bands, height, width).
                    adj_spectral_bands = get_adjacent_spectral_bands(
                        noisy, K, i)
                    adj_spectral_bands = torch.transpose(
                        adj_spectral_bands, 3, 1)
                    denoised_band = net(current_noisy_band, adj_spectral_bands)

                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_hsi[:, :, i] = np.squeeze(denoised_band_numpy)

        if denoised_hsi is None:
            # Nothing to evaluate; avoid a NameError on the metric lines.
            print("Test dataloader is empty; skipping validation metrics")
            continue

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # Report PSNR / SSIM / SAM for this epoch.
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'average SAM': sam
        }, epoch)  # lets us see which epoch performs best

    tb_writer.close()
Esempio n. 26
0
images and objective ones.
It simply takes the two optical images as the source and computes the quality indices (PSNR and SSIM) between them.
"""

from config import *
from metrics import PSNR, SSIM

from PIL import Image
import os
import torch
import torchvision.transforms as transforms

# Allow very large source images without Pillow's decompression-bomb guard.
Image.MAX_IMAGE_PIXELS = 1000000000

psnr = PSNR()
ssim = SSIM()

# PIL image -> float tensor (C, H, W); built once and reused for every file.
to_tensor = transforms.ToTensor()


def _load_batch(path):
    """Open the image at *path* and return it as a (1, C, H, W) tensor.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been copied into the tensor.
    """
    with Image.open(path) as image:
        return to_tensor(image).unsqueeze(0)


for region in regions:
    o0 = _load_batch(os.path.join(source_folder, region, 'o0.jpg'))
    o1 = _load_batch(os.path.join(source_folder, region, 'o1.jpg'))

    psnr_value = psnr(o1, o0).item()
    ssim_value = ssim(o1, o0).item()

    print('>>>INDICES FOR %s<<<' % (region[2:]))
    print('PSNR: %.4f SSIM: %.4f' % (psnr_value, ssim_value))
Esempio n. 27
0
class DerainNet:
    '''ReHEN derain network: all the implemented layers are included
    (e.g. SEBlock, HEU, REU, ReHEB).

        Params:
            config: the training configuration (channel_dim, batch_size,
                patch_size, input_channels, paths and hyper-parameters)
            sess: running TensorFlow session
    '''
    model_name = 'ReHEN'

    def __init__(self, config, sess=None):
        # config proto
        self.config = config
        self.channel_dim = self.config.channel_dim
        self.batch_size = self.config.batch_size
        self.patch_size = self.config.patch_size
        self.input_channels = self.config.input_channels

        # metrics
        self.ssim = SSIM(max_val=1.0)
        self.psnr = PSNR(max_val=1.0)

        # create session
        self.sess = sess

    # global average pooling
    def globalAvgPool2D(self, input_x):
        '''Global average pooling over the spatial axes.'''
        global_avgpool2d = tf.contrib.keras.layers.GlobalAvgPool2D()
        return global_avgpool2d(input_x)

    # leaky relu
    def leakyRelu(self, input_x):
        '''Leaky ReLU with negative slope 0.2.'''
        leaky_relu = tf.contrib.keras.layers.LeakyReLU(alpha=0.2)
        return leaky_relu(input_x)

    # squeeze-and-excitation block
    def SEBlock(self, input_x, input_dim=32, reduce_dim=8, scope='SEBlock'):
        '''Squeeze-and-excitation: rescale each channel of ``input_x``.

        Global average pooling -> FC(reduce_dim, relu) -> FC(input_dim)
        -> sigmoid yields one weight per channel. The weights are expanded to
        (batch, 1, 1, input_dim) and multiplied into ``input_x``. This assumes
        NHWC input whose last dimension equals ``input_dim``.
        '''
        with tf.variable_scope(scope) as scope:
            # global scale
            global_pl = self.globalAvgPool2D(input_x)
            reduce_fc1 = slim.fully_connected(global_pl,
                                              reduce_dim,
                                              activation_fn=tf.nn.relu)
            reduce_fc2 = slim.fully_connected(reduce_fc1,
                                              input_dim,
                                              activation_fn=None)
            g_scale = tf.nn.sigmoid(reduce_fc2)
            g_scale = tf.expand_dims(g_scale, axis=1)
            g_scale = tf.expand_dims(g_scale, axis=1)
            gs_input = input_x * g_scale
            return gs_input

    # recurrent enhancement unit
    def REU(self, input_x, h, out_dim, scope='REU'):
        '''Recurrent enhancement unit: a GRU-style convolutional cell.

        With no previous state (``h is None``), the state is computed as
        sigmoid(conv_xz(x)) * tanh(conv_xn(x)). Otherwise the standard GRU
        update is used with 3x3 convolutions: reset gate r, update gate z,
        candidate n, and h = (1 - z) * h + z * n. The state then passes
        through an SE block and a leaky ReLU.

        Returns:
            (output, new_state). Both are the same tensor.
        '''
        with tf.variable_scope(scope):
            if h is None:
                self.conv_xz = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xz')
                self.conv_xn = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xn')
                z = tf.nn.sigmoid(self.conv_xz)
                f = tf.nn.tanh(self.conv_xn)
                h = z * f
            else:
                self.conv_hz = slim.conv2d(h, out_dim, 3, 1, scope='conv_hz')
                self.conv_hr = slim.conv2d(h, out_dim, 3, 1, scope='conv_hr')

                self.conv_xz = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xz')
                self.conv_xr = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xr')
                self.conv_xn = slim.conv2d(input_x,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_xn')
                r = tf.nn.sigmoid(self.conv_hr + self.conv_xr)
                z = tf.nn.sigmoid(self.conv_hz + self.conv_xz)

                self.conv_hn = slim.conv2d(r * h,
                                           out_dim,
                                           3,
                                           1,
                                           scope='conv_hn')
                n = tf.nn.tanh(self.conv_xn + self.conv_hn)
                h = (1 - z) * h + z * n

        # channel attention block
        se = self.SEBlock(h, out_dim, reduce_dim=int(out_dim / 4))
        h = self.leakyRelu(se)
        return h, h

    # hierarchy enhancement unit
    def HEU(self, input_x, is_training=False, scope='HEU'):
        '''Hierarchy enhancement unit.

        The unit runs two residual (conv-BN-ReLU x2) blocks. The input and
        each block's output are concatenated densely along the channel axis.
        The concatenation is fused by a 3x3 transition conv with BN, SE and
        ReLU, and the result is added back to ``input_x``. The residual add
        requires ``input_x`` to have ``self.channel_dim`` channels.
        '''
        with tf.variable_scope(scope) as scope:
            local_shortcut = input_x
            dense_shortcut = input_x

            for i in range(1, 3):
                with tf.variable_scope('ResBlock_{}'.format(i)):
                    with tf.variable_scope('Conv1'):
                        conv_tmp1 = slim.conv2d(local_shortcut,
                                                self.channel_dim, 3, 1)
                        conv_tmp1_bn = bn(conv_tmp1, is_training,
                                          UPDATE_G_OPS_COLLECTION)
                        out_tmp1 = tf.nn.relu(conv_tmp1_bn)

                    with tf.variable_scope('Conv2'):
                        conv_tmp2 = slim.conv2d(out_tmp1, self.channel_dim, 3,
                                                1)
                        conv_tmp2_bn = bn(conv_tmp2, is_training,
                                          UPDATE_G_OPS_COLLECTION)
                        out_tmp2 = tf.nn.relu(conv_tmp2_bn)
                        conv_shortcut = tf.add(local_shortcut, out_tmp2)

                dense_shortcut = tf.concat([dense_shortcut, conv_shortcut], -1)
                local_shortcut = conv_shortcut

            with tf.variable_scope('Trans'):
                conv_tmp3 = slim.conv2d(dense_shortcut, self.channel_dim, 3, 1)
                conv_tmp3_bn = bn(conv_tmp3, is_training,
                                  UPDATE_G_OPS_COLLECTION)
                conv_tmp3_se = self.SEBlock(conv_tmp3_bn,
                                            self.channel_dim,
                                            reduce_dim=int(self.channel_dim /
                                                           4))
                out_tmp3 = tf.nn.relu(conv_tmp3_se)
                heu_f = tf.add(input_x, out_tmp3)

            return heu_f

    # recurrent hierarchy enhancement block
    def ReHEB(self, input_x, h, is_training=False, scope='ReHEB'):
        '''Recurrent hierarchy enhancement block: an optional HEU followed by an REU.

        A 3-channel input (the image fed to the first block of each stage)
        bypasses the HEU and goes straight into the REU.
        NOTE(review): the hard-coded 3 presumably means ``self.input_channels``;
        confirm before changing it, because the result affects the graph
        (and therefore existing checkpoints).
        '''
        with tf.variable_scope(scope):
            if input_x.get_shape().as_list()[-1] == 3:
                heu = input_x
            else:
                heu = self.HEU(input_x, is_training=is_training)
            reheb, h = self.REU(heu, h, out_dim=self.channel_dim)
        return reheb, h

    # recurrent hierarchy and enhancement network
    def derainNet(self, input_x, is_training=False, scope_name='derainNet'):
        '''ReHEN: recurrent hierarchy and enhancement network
            Params:
                input_x: input data
                is_training: training phase or testing phase
                scope_name: the scope name of the ReHEN (customer definition, default='derainNet')
            Return:
                (oups, shallow_f, neg_residual): the per-stage outputs, the
                final-stage output, and the final-stage negative residual

            Input shape:
                4D tensor with shape '(batch_size, height, width, channels)'

            Output shape:
                4D tensor with shape '(batch_size, height, width, channels)'
        '''
        # reuse: tf.AUTO_REUSE(such setting will enable the network to reuse parameters automatically)
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
            # convert is_training variable to tensor type
            is_training = tf.convert_to_tensor(is_training,
                                               dtype='bool',
                                               name='is_training')
            with slim.arg_scope(
                [slim.conv2d, slim.conv2d_transpose],
                    weights_initializer=tf.contrib.layers.xavier_initializer(),
                    normalizer_fn=None,
                    activation_fn=None,
                    padding='SAME'):

                stages = 4
                block_num = 5
                old_states = [None for _ in range(block_num)]
                oups = []
                ori = input_x
                shallow_f = input_x

                for stg in range(stages):
                    # recurrent hierarchy enhancement block (ReHEB)
                    with tf.variable_scope('ReHEB'):
                        states = []
                        for i in range(block_num):
                            sp = 'ReHEB_{}'.format(i)
                            shallow_f, st = self.ReHEB(shallow_f,
                                                       old_states[i],
                                                       is_training=is_training,
                                                       scope=sp)
                            states.append(st)

                    further_f = shallow_f

                    # residual map generator (RMG)
                    with tf.variable_scope('RMG'):
                        rm_conv = slim.conv2d(further_f, self.channel_dim, 3,
                                              1)
                        rm_conv_se = self.SEBlock(rm_conv,
                                                  self.channel_dim,
                                                  reduce_dim=int(
                                                      self.channel_dim / 4))
                        rm_conv_a = self.leakyRelu(rm_conv_se)
                        neg_residual_conv = slim.conv2d(
                            rm_conv_a, self.input_channels, 3, 1)
                        neg_residual = neg_residual_conv
                    # Each stage refines the ORIGINAL input, not the previous output.
                    shallow_f = ori - neg_residual
                    oups.append(shallow_f)
                    # Per-block recurrent states carried into the next stage.
                    old_states = [tf.identity(s) for s in states]

        return oups, shallow_f, neg_residual

    def build(self):
        '''Build placeholders, the network, metrics, losses, the train op,
        summaries and the saver.'''
        # placeholder
        self.rain = tf.placeholder(tf.float32,
                                   [None, None, None, self.input_channels],
                                   name='rain')
        self.norain = tf.placeholder(tf.float32,
                                     [None, None, None, self.input_channels],
                                     name='norain')
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')

        # derainnet
        self.oups, self.out, self.residual = self.derainNet(
            self.rain, is_training=self.config.is_training)
        self.finer_out = tf.clip_by_value(self.out, 0, 1.0)
        self.finer_residual = tf.clip_by_value(tf.abs(self.residual), 0, 1)

        # metrics (unclipped output feeds the losses, clipped feeds reporting)
        self.ssim_finer_tensor = tf.reduce_mean(
            self.ssim._ssim(self.norain, self.out, 0, 0))
        self.psnr_finer_tensor = tf.reduce_mean(
            self.psnr.compute_psnr(self.norain, self.out))
        self.ssim_val = tf.reduce_mean(
            self.ssim._ssim(self.norain, self.finer_out, 0, 0))
        self.psnr_val = tf.reduce_mean(
            self.psnr.compute_psnr(self.norain, self.finer_out))

        # loss function
        # MSE loss, summed over every stage's output
        self.l2_loss = tf.reduce_sum([
            tf.reduce_mean(tf.square(out - self.norain)) for out in self.oups
        ])
        # SSIM loss
        self.ssim_loss = tf.log(1.0 / (self.ssim_finer_tensor + 1e-5))
        # PSNR loss
        self.psnr_loss = 1.0 / (self.psnr_finer_tensor + 1e-3)
        # total loss
        self.total_loss = self.l2_loss + 0.001 * self.ssim_loss + 0.1 * self.psnr_loss

        # optimization
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if 'derainNet' in var.name]
        loss_train_ops = tf.train.AdamOptimizer(
            learning_rate=self.lr,
            beta1=self.config.beta1,
            beta2=self.config.beta2).minimize(self.total_loss, var_list=g_vars)

        # batchnorm training ops
        batchnorm_ops = tf.get_collection(UPDATE_G_OPS_COLLECTION)
        bn_update_ops = tf.group(*batchnorm_ops)
        self.train_ops = tf.group(loss_train_ops, bn_update_ops)

        # summary
        self.l2_loss_summary = tf.summary.scalar('l2_loss', self.l2_loss)
        self.total_loss_summary = tf.summary.scalar('total_loss',
                                                    self.total_loss)
        self.ssim_loss_summary = tf.summary.scalar('ssim_loss', self.ssim_loss)
        self.psnr_loss_summary = tf.summary.scalar('psnr_loss', self.psnr_loss)
        # Backward-compatible alias: previously this attribute was assigned
        # twice and ended up holding the psnr_loss summary.
        self.edge_loss_summary = self.psnr_loss_summary
        self.ssim_summary = tf.summary.scalar('ssim', self.ssim_val)
        self.psnr_summary = tf.summary.scalar('psnr', self.psnr_val)
        self.summaries = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(self.config.logs_dir,
                                                    self.sess.graph)

        # saver (only the network's own variables)
        global_variables = tf.global_variables()
        var_to_store = [
            var for var in global_variables if 'derainNet' in var.name
        ]
        self.saver = tf.train.Saver(var_list=var_to_store)

        # trainable variables
        num_params = 0
        for var in g_vars:
            tmp_num = 1
            for i in var.get_shape().as_list():
                tmp_num = tmp_num * i
            num_params = num_params + tmp_num
        print('numbers of trainable parameters:{}'.format(num_params))

    # training phase
    def train(self):
        '''Train for ``config.iterations`` steps.

        The learning rate is cut by 10x at step 50000. A sample grid is saved
        every 100 steps and a checkpoint every 500 steps.
        '''
        # initialize variables (fallback for TF versions that predate
        # global_variables_initializer)
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            tf.initialize_all_variables().run()

        # load training model
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] load model successfully')
        else:
            print('[***] fail to load model')

        lr_ = self.config.lr
        start_time = time.time()
        for counter in range(self.config.iterations):
            if counter == 50000:
                lr_ = 0.1 * lr_

            # obtain training image pairs
            img, label = read_data(self.config.train_dataset,
                                   self.config.data_path, self.batch_size,
                                   self.patch_size, self.config.trainset_size)
            _, total_loss, summaries, ssim, psnr = self.sess.run(
                [
                    self.train_ops, self.total_loss, self.summaries,
                    self.ssim_val, self.psnr_val
                ],
                feed_dict={
                    self.rain: img,
                    self.norain: label,
                    self.lr: lr_
                })

            print(
                'Iteration:{}, phase:{}, loss:{:.4f}, ssim:{:.4f}, psnr:{:.4f}, lr:{}, iterations:{}'
                .format(counter, self.config.phase, total_loss, ssim, psnr,
                        lr_, self.config.iterations))

            self.summary_writer.add_summary(summaries, global_step=counter)
            if np.mod(counter, 100) == 0:
                self.sample(self.config.sample_dir, counter)

            if np.mod(counter, 500) == 0:
                self.save_model()

        # save final model (the loop variable does not exist when
        # config.iterations == 0, so test the configured count instead)
        if self.config.iterations > 0:
            self.save_model()

        # training time
        end_time = time.time()
        print('training time:{} hours'.format(
            (end_time - start_time) / 3600.0))

    # sampling phase
    def sample(self, sample_dir, iterations):
        '''Save a grid of (rainy, clean, derained, |residual|) test patches
        into ``sample_dir``.'''
        # obtaining sampling image pairs
        test_img, test_label = read_data(self.config.test_dataset,
                                         self.config.data_path,
                                         self.batch_size, self.patch_size,
                                         self.config.testset_size)
        finer_out, finer_residual = self.sess.run(
            [self.finer_out, self.finer_residual],
            feed_dict={self.rain: test_img})

        # save sampling images
        test_img_uint8 = np.uint8(test_img * 255.0)
        test_label_uint8 = np.uint8(test_label * 255.0)
        finer_out_uint8 = np.uint8(finer_out * 255.0)
        finer_residual = np.uint8(finer_residual * 255.0)
        sample = np.concatenate([
            test_img_uint8, test_label_uint8, finer_out_uint8, finer_residual
        ], 2)
        grid_side = int(np.sqrt(self.batch_size)) + 1
        save_images(
            sample, [grid_side, grid_side],
            '{}/{}_{}_{:04d}.jpg'.format(sample_dir, self.config.test_dataset,
                                         self.config.phase, iterations))

    # testing phase
    def test(self):
        '''Evaluate on ``config.test_dataset``.

        Saves the derained outputs, labels and residuals to
        ``config.test_dir``, and prints per-image and mean SSIM/PSNR plus
        the mean inference time.
        '''
        rain = tf.placeholder(tf.float32,
                              [None, None, None, self.input_channels],
                              name='test_rain')
        norain = tf.placeholder(tf.float32,
                                [None, None, None, self.input_channels],
                                name='test_norain')

        oups, out, residual = self.derainNet(
            rain, is_training=self.config.is_training)
        finer_out = tf.clip_by_value(out, 0, 1.0)
        finer_residual = tf.clip_by_value(tf.abs(residual), 0, 1.0)

        ssim_val = tf.reduce_mean(self.ssim._ssim(norain, finer_out, 0, 0))
        psnr_val = tf.reduce_mean(self.psnr.compute_psnr(norain, finer_out))

        # load model
        self.saver = tf.train.Saver()
        check_bool = self.load_model()
        if check_bool:
            print('[!!!] load model successfully')
        else:
            try:
                tf.global_variables_initializer().run()
            except AttributeError:
                tf.initialize_all_variables().run()
            print('[***] fail to load model')

        try:
            test_num, test_data_format, test_label_format = test_dic[
                self.config.test_dataset]
        except KeyError:
            print('no testing dataset named {}'.format(
                self.config.test_dataset))
            return

        ssim = []
        psnr = []
        # Accumulated inference time over all test images.
        t = 0
        for index in range(1, test_num + 1):
            test_data_fn = test_data_format.format(index)
            test_label_fn = test_label_format.format(index)

            test_data_path = os.path.join(
                self.config.test_path.format(self.config.test_dataset),
                test_data_fn)
            test_label_path = os.path.join(
                self.config.test_path.format(self.config.test_dataset),
                test_label_fn)

            test_data_uint8 = cv2.imread(test_data_path)
            test_label_uint8 = cv2.imread(test_label_path)
            # cv2.imread returns None instead of raising on unreadable files.
            if test_data_uint8 is None or test_label_uint8 is None:
                raise FileNotFoundError(
                    'cannot read test pair {} / {}'.format(
                        test_data_path, test_label_path))

            test_data_float = test_data_uint8 / 255.0
            test_label_float = test_label_uint8 / 255.0

            test_data = np.expand_dims(test_data_float, 0)
            test_label = np.expand_dims(test_label_float, 0)

            s_t = time.time()
            finer_out_val, finer_residual_val, tmp_ssim, tmp_psnr = self.sess.run(
                [finer_out, finer_residual, ssim_val, psnr_val],
                feed_dict={
                    rain: test_data,
                    norain: test_label
                })

            e_t = time.time()
            total_t = e_t - s_t
            t = t + total_t

            # save psnr and ssim metrics
            ssim.append(tmp_ssim)
            psnr.append(tmp_psnr)
            # save testing image
            test_label = np.uint8(test_label * 255)
            finer_out_val = np.uint8(finer_out_val * 255)
            finer_residual_val = np.uint8(finer_residual_val * 255)
            save_images(
                finer_out_val, [1, 1],
                '{}/{}_{}'.format(self.config.test_dir,
                                  self.config.test_dataset, test_data_fn))
            save_images(test_label, [1, 1],
                        '{}/{}'.format(self.config.test_dir, test_data_fn))
            save_images(
                finer_residual_val, [1, 1],
                '{}/residual_{}'.format(self.config.test_dir, test_data_fn))
            print('test image {}: ssim:{}, psnr:{} time:{:.4f}'.format(
                test_data_fn, tmp_ssim, tmp_psnr, total_t))

        mean_ssim = np.mean(ssim)
        mean_psnr = np.mean(psnr)
        print('Test phase: ssim:{}, psnr:{}'.format(mean_ssim, mean_psnr))
        print('Average time:{}'.format(t / max(test_num, 1)))

    # save model
    @property
    def model_dir(self):
        '''Checkpoint sub-directory name: ``<model>_<train_dataset>_<batch>``.'''
        return "{}_{}_{}".format(self.model_name, self.config.train_dataset,
                                 self.batch_size)

    @property
    def model_pos(self):
        '''Checkpoint prefix passed to the saver.'''
        return '{}/{}/{}'.format(self.config.checkpoint_dir, self.model_dir,
                                 self.model_name)

    def save_model(self):
        '''Save the session to ``model_pos``, creating its directory first.'''
        # The saver does not create the parent directory of the checkpoint
        # prefix (checkpoint_dir/model_dir), so create the full path.
        os.makedirs(os.path.dirname(self.model_pos), exist_ok=True)
        self.saver.save(self.sess, self.model_pos)

    def load_model(self):
        '''Restore from ``model_pos`` if a checkpoint index exists.

        Returns True when a checkpoint was restored, otherwise False.
        '''
        if not os.path.isfile(
                os.path.join(self.config.checkpoint_dir, self.model_dir,
                             'checkpoint')):
            return False
        else:
            self.saver.restore(self.sess, self.model_pos)
            return True
Esempio n. 28
0
def evaluate(dataloader, D, G, sample_size, device, now, region):
    """
    Evaluate the generator on a test dataset and save its predictions.

    This is a test function rather than a validation one. It runs on a
    dataset distinct from the training one to measure how well the model
    performs. Hyperparameters are taken from the paper the idea comes from.
    Each sample from ``dataloader`` is split into the ground truth (first
    three channels) and the generator input (remaining channels).

    Both the ground truth and the prediction are mapped from [-1, 1] to
    [0, 1], because G has no final sigmoid. PSNR and SSIM are then computed
    and appended to ``eval_data['psnr_list']`` and ``eval_data['ssim_list']``.
    Predicted and real images are saved as JPEGs under
    ``save_path/region/now/switch[1]/{fake,real}``.

    Args:
        dataloader (obj): dataloader over the test dataset; each item is
            ``(img, _)`` with a leading batch dimension that is squeezed away
            (so a batch size of 1 is assumed)
        D (obj): discriminator model (not used here; kept for interface parity)
        G (obj): generator model, acts as a function
        sample_size (int): width and height of the samples (not used here)
        device (obj): device on which inputs and predictions are placed
        now (str): stores the name for the folder where to save images
        region (str): name of the evaluated region; also a path component
    """

    # Lists from config file are loaded to save performance data
    # Modules saved as variables to run evaluation indices
    global eval_data
    psnr = PSNR()
    ssim = SSIM()

    # (v - (-1)) / 2 maps [-1, 1] to [0, 1]; built once, reused per sample.
    to_unit_range = transforms.Normalize(mean=[-1, -1, -1], std=[2, 2, 2])
    to_pil = transforms.ToPILImage()
    fake_dir = os.path.join(save_path, region, now, switch[1], 'fake')
    real_dir = os.path.join(save_path, region, now, switch[1], 'real')

    i = 0
    start = time.time()
    # To not use the gradient attribute
    with torch.no_grad():
        # Cycle over the dataloader
        print('^^^^^^^^^^^^^^^^')
        print('>>>TEST PHASE<<<')
        print('Evaluation loop on region ' + region)
        for img, _ in dataloader:
            # Get back to three dimensions
            img = img.squeeze(0).to(device)
            if i % 10 == 0 or i == 1:
                print('Validation on sample ', i)

            # Ground truth directly prepared for comparison (range [0:1])
            y = to_unit_range(img[:3, :, :]).unsqueeze(0)

            # Backup images: the generator input
            x = img[3:, :, :].unsqueeze(0)

            # Prediction normalized in range [0:1]; honour the requested
            # device instead of forcing CUDA.
            y_pred = G(x).to(device)
            y_pred = to_unit_range(y_pred.squeeze(0)).unsqueeze(0)

            # Data set in local variables
            psnr_value = psnr(y_pred, y).item()
            ssim_value = ssim(y_pred, y).item()

            # Data showed during test
            if i % 10 == 0:
                print('Iter: [%d/%d]\tPSNR: %.4f\tSSIM: %.4f\t' %
                      (i, len(dataloader), psnr_value, ssim_value))

            # Data saved in lists
            eval_data['psnr_list'].append(psnr_value)
            eval_data['ssim_list'].append(ssim_value)

            # Images saved in the correct folder
            y = y.squeeze(0)
            y_pred = y_pred.squeeze(0)

            # ToPILImage scales float tensors by 255 and casts to uint8, so
            # out-of-range values would wrap around; clamp for saving only.
            to_pil(y_pred.detach().cpu().clamp(0, 1)).save(
                os.path.join(fake_dir, 'save' + str(i) + '.jpg'))
            to_pil(y.detach().cpu().clamp(0, 1)).save(
                os.path.join(real_dir, 'save' + str(i) + '.jpg'))

            i += 1
    end = time.time()
    print('Elapsed time in minutes:', (end - start) / 60)
Esempio n. 29
0
def train_model_multistage_lowlight():
    """Train MultiStageHSID on low-light HSI patches, validating after every epoch.

    Uses module-level settings defined elsewhere in this file (``DEVICE``,
    ``BATCH_SIZE``, ``K``, ``INIT_LEARNING_RATE``, ``display_step``) and the
    training loss ``loss_fuction``. Validation denoises the test cube one band
    per batch, reassembles it, and scores the whole cube with PSNR, SSIM and
    SAM against the ``label`` array of the reference ``.mat`` file.

    Side effects: (re)binds the global ``tb_writer`` to a TensorBoard writer
    logging under ``logs`` (closed when training ends, even on error), and
    writes checkpoints under ``./checkpoints``: one per epoch, the best-PSNR
    model and ``model_latest.pth``.
    """
    device = DEVICE

    # Training data: only the 32x32-patch dataset is used at the moment.
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    print('total training example:', len(train_set))
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Ground-truth cube used to score the reassembled validation output.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Test data: batch_size=1 so each batch is one band and batch_idx doubles
    # as the band index of the reconstructed cube.
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Model, optimizer and learning-rate schedule.
    net = MultiStageHSID(K)
    init_params(net)
    net = net.to(device)

    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    # Every checkpoint below is saved with a relative path inside this folder.
    os.makedirs('checkpoints', exist_ok=True)

    cur_step = 0
    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 100

    try:
        for epoch in range(start_epoch, num_epoch + 1):
            epoch_start_time = time.time()
            # Learning rate applied to every optimizer step of this epoch.
            epoch_lr = scheduler.get_last_lr()[0]
            print(epoch, 'lr={:.6f}'.format(epoch_lr))
            gen_epoch_loss = 0

            net.train()
            for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
                noisy = noisy.to(device)
                label = label.to(device)
                cubic = cubic.to(device)

                hsid_optimizer.zero_grad()
                residual = net(noisy, cubic)
                # Every stage output is supervised against the label. The
                # built-in sum keeps the result a torch tensor (np.sum would
                # try to convert autograd tensors to NumPy).
                # NOTE(review): here each output is compared directly with
                # `label`, but validation below treats residual[0] as a
                # residual and adds the noisy input back; one of the two is
                # likely inconsistent -- confirm what MultiStageHSID returns.
                loss = sum(loss_fuction(stage_out, label)
                           for stage_out in residual)
                loss.backward()
                hsid_optimizer.step()

                loss_value = loss.item()
                gen_epoch_loss += loss_value

                if cur_step % display_step == 0:
                    if cur_step > 0:
                        print(
                            f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss_value}"
                        )
                    else:
                        print("Pretrained initial state")

                tb_writer.add_scalar("MSE loss", loss_value, cur_step)

                # One step == one processed training batch.
                cur_step += 1

            tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

            torch.save(
                {
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/hsid_multistage_patchsize64_{epoch}.pth")

            # Validation: denoise every band and rebuild the full cube.
            net.eval()
            with torch.no_grad():
                for batch_idx, (noisy_test, cubic_test,
                                label_test) in enumerate(test_dataloader):
                    noisy_test = noisy_test.type(torch.FloatTensor).to(device)
                    label_test = label_test.type(torch.FloatTensor).to(device)
                    cubic_test = cubic_test.type(torch.FloatTensor).to(device)

                    residual = net(noisy_test, cubic_test)
                    denoised_band = noisy_test + residual[0]

                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_hsi[:, :, batch_idx] = np.squeeze(
                        denoised_band_numpy)

                    if batch_idx == 49:
                        # Log one representative band to TensorBoard.
                        band_images = {
                            'restored': denoised_band,
                            'residual': residual[0],
                            'label': label_test,
                            'noisy': noisy_test,
                        }
                        for tag, image in band_images.items():
                            tb_writer.add_image(f"images/{epoch}_{tag}",
                                                torch.squeeze(image, dim=0),
                                                1,
                                                dataformats='CHW')

            psnr = PSNR(denoised_hsi, test_label_hsi)
            ssim = SSIM(denoised_hsi, test_label_hsi)
            sam = SAM(denoised_hsi, test_label_hsi)

            print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                  format(psnr, ssim, sam))
            # Per-epoch metrics make it easy to see which epoch performed best.
            # (The 'avarage SAM' tag is kept as-is so existing logs still line up.)
            tb_writer.add_scalars("validation metrics", {
                'average PSNR': psnr,
                'average SSIM': ssim,
                'avarage SAM': sam
            }, epoch)

            # Keep the checkpoint with the best validation PSNR.
            if psnr > best_psnr:
                best_psnr = psnr
                best_epoch = epoch
                best_iter = cur_step
                torch.save(
                    {
                        'epoch': epoch,
                        'gen': net.state_dict(),
                        'gen_opt': hsid_optimizer.state_dict(),
                    }, f"checkpoints/hsid_multistage_patchsize64_best.pth")

            print(
                "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
                % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

            print(
                "------------------------------------------------------------------"
            )
            print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
                  format(epoch,
                         time.time() - epoch_start_time, gen_epoch_loss,
                         epoch_lr))
            print(
                "------------------------------------------------------------------"
            )

            # Always keep the most recent model as well.
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict()
                }, os.path.join('./checkpoints', "model_latest.pth"))

            # Advance the LR schedule once per epoch, after all optimizer
            # steps. Stepping before optimizer.step() (as was done at the top
            # of the epoch) skips the first value of the schedule.
            scheduler.step()
    finally:
        tb_writer.close()
Esempio n. 30
0
def fit_model(model, data_loaders, channels, criterion, optimizer, scheduler, device, n_epochs, val_freq, checkpoint_dir, model_name):
    """
    Training of the denoiser model.

    Runs a training phase every epoch and a validation phase every ``val_freq``
    epochs. Whenever validation PSNR improves, the weights are saved to
    ``checkpoint_dir``. The final weights are always saved at the end. Per-epoch
    metrics are appended to ``<model_name>_logfile.csv`` in ``checkpoint_dir``.

    :param model: torch Module
        Neural network to fit.
    :param data_loaders: dict
        Dictionary with torch DataLoaders with training and validation datasets
        (keys 'train' and 'val').
    :param channels: int
        Number of image channels
    :param criterion: torch Module
        Loss function.
    :param optimizer: torch Optimizer
        Gradient descent optimization algorithm.
    :param scheduler: torch lr_scheduler or None
        Learning rate scheduler. A ReduceLROnPlateau scheduler is stepped with the
        validation PSNR after each validation phase. Any other scheduler is stepped
        once after each training phase.
    :param device: torch device
        Device used during training (CPU/GPU).
    :param n_epochs: int
        Number of epochs to fit the model (must be >= 1).
    :param val_freq: int
        How many training epochs to run between validations (must be >= 1).
    :param checkpoint_dir: str
        Path to the directory where the model checkpoints and CSV log files will be stored.
    :param model_name: str
        Prefix name of the trained model saved in checkpoint_dir.
    :return: None
    :raises ValueError: if n_epochs or val_freq is smaller than 1.
    """
    if n_epochs < 1:
        raise ValueError('n_epochs must be >= 1, got {}'.format(n_epochs))
    if val_freq < 1:
        raise ValueError('val_freq must be >= 1, got {}'.format(val_freq))

    psnr = PSNR(data_range=1., reduction='sum')
    ssim = SSIM(channels, data_range=1., reduction='sum')
    os.makedirs(checkpoint_dir, exist_ok=True)
    logfile_path = os.path.join(checkpoint_dir, ''.join([model_name, '_logfile.csv']))
    model_path = os.path.join(checkpoint_dir, ''.join([model_name, '-{:03d}-{:.4e}-{:.4f}-{:.4f}.pth']))
    file_logger = FileLogger(logfile_path)
    best_model_path, best_psnr = '', -np.inf
    # ReduceLROnPlateau.step() requires the monitored metric, so it can only be
    # stepped after validation. Every other scheduler steps per training epoch.
    plateau_scheduler = isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
    since = time.time()

    for epoch in range(1, n_epochs + 1):
        lr = optimizer.param_groups[0]['lr']
        epoch_logger = EpochLogger()
        epoch_log = dict()

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                print('\nEpoch: {}/{} - Learning rate: {:.4e}'.format(epoch, n_epochs, lr))
                description = 'Training - Loss:{:.5e} - PSNR:{:.5f} - SSIM:{:.5f}'
            elif phase == 'val' and epoch % val_freq == 0:
                model.eval()
                description = 'Validation - Loss:{:.5e} - PSNR:{:.5f} - SSIM:{:.5f}'
            else:
                # Not a validation epoch: skip the 'val' phase.
                break

            iterator = tqdm(enumerate(data_loaders[phase], 1), total=len(data_loaders[phase]), ncols=110)
            iterator.set_description(description.format(0, 0, 0))
            n_samples = 0

            for step, (inputs, targets) in iterator:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                batch_size = inputs.size()[0]
                n_samples += batch_size
                # Metrics are for monitoring only; keep them out of the autograd graph.
                # Loss is scaled by the batch size so the logger can average per sample.
                with torch.no_grad():
                    metrics = {
                        'loss': loss.item() * batch_size,
                        'psnr': psnr(outputs, targets).item(),
                        'ssim': ssim(outputs, targets).item()
                    }
                epoch_logger.update_log(metrics, phase)
                log = epoch_logger.get_log(n_samples, phase)
                iterator.set_description(description.format(log[phase + ' loss'], log[phase + ' psnr'], log[phase + ' ssim']))

            if phase == 'val':
                # Apply Reduce LR On Plateau if it is the case and save the model if the validation PSNR is improved.
                if plateau_scheduler:
                    scheduler.step(log['val psnr'])
                if log['val psnr'] > best_psnr:
                    best_psnr = log['val psnr']
                    best_model_path = model_path.format(epoch, log['val loss'], log['val psnr'], log['val ssim'])
                    torch.save(model.state_dict(), best_model_path)

            elif scheduler is not None and not plateau_scheduler:
                # Apply any other scheduler at epoch level. A plateau scheduler
                # cannot be stepped here because it needs a metric.
                scheduler.step()

            epoch_log = {**epoch_log, **log}

        # Save the current epoch metrics in a CSV file.
        epoch_data = {'epoch': epoch, 'learning rate': lr, **epoch_log}
        file_logger(epoch_data)

    # Save the last model. Name it after the final epoch's validation metrics when
    # that epoch was validated. Otherwise use its training metrics, so the name
    # always describes the saved weights (and no KeyError when n_epochs % val_freq != 0).
    metric_phase = 'val' if 'val loss' in epoch_log else 'train'
    last_model_path = model_path.format(epoch,
                                        epoch_log[metric_phase + ' loss'],
                                        epoch_log[metric_phase + ' psnr'],
                                        epoch_log[metric_phase + ' ssim'])
    torch.save(model.state_dict(), last_model_path)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best PSNR: {:4f}'.format(best_psnr))