Пример #1
0
 def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
     """Configure the SRGAN for training.

     Stores separate optimizers for the discriminator and generator,
     the adversarial loss function, and PSNR/SSIM quality metrics
     (images assumed normalised to the [0, 1] range).
     """
     super(SRGan, self).compile(**kwargs)
     # image-quality metrics over a [0, 1] value range
     self._psnr = PSNR(max_val=1.0)
     self._ssim = SSIM(max_val=1.0)
     # per-network optimisers and the adversarial loss
     self.g_optimizer = g_optimizer
     self.d_optimizer = d_optimizer
     self.loss_fn = loss_fn
Пример #2
0
    def _calc_psnr(r_vid, c_vid):
        """Return the PSNR between a reference and a compressed video.

        Both arguments are paths readable by ``sk.vread``; frames are
        treated as 8-bit data (``data_range=255``).
        """
        reference_frames = sk.vread(r_vid)
        compressed_frames = sk.vread(c_vid)
        return PSNR(data_range=255).calc_video(reference_frames,
                                               compressed_frames)
Пример #3
0
    def __init__(self,
                 model,
                 dataloaders,
                 inc_overhead=False,
                 if_codec=None,
                 standard_epe=False):
        """Set up a video-compression evaluator.

        Args:
            model: trained model; moved to GPU when available and put in
                inference mode. Its ``name`` attribute selects the GOP
                accounting below.
            dataloaders: object exposing ``f_s``, ``n_gop`` and
                ``get_data_loaders()`` for the evaluation videos.
            inc_overhead: whether overhead bits are included when counting.
            if_codec: optional I-frame image codec name; when given an
                ``ImageCodec`` is instantiated for it.
            standard_epe: forwarded to ``EPE`` (Farneback vs LiteFlowNet).
        """

        # use GPU if available
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # model to device & inference mode
        self.model = model.to(self.device)
        self.model.train(False)

        # video dataloaders
        vid_dls = dataloaders
        self.f_s = vid_dls.f_s

        # effective GOP length: P/B-frame models consume one or two
        # frames of each GOP as references, so those are excluded
        self.n_gop = vid_dls.n_gop
        if "PFrame" in self.model.name:
            # remove reference frame
            self.n_gop = self.n_gop - 1
        elif "BFrame" in self.model.name:
            # remove reference frames
            self.n_gop = self.n_gop - 2

        self.vid_dls = vid_dls.get_data_loaders()

        # I-Frame image codec
        self.if_codec = if_codec
        if if_codec is not None:
            self.img_codec = ImageCodec(codec=if_codec)

        # include overhead bits
        self.inc_overhead = inc_overhead

        # evaluation metrics

        # SSIM over [0, 1] data, per-channel, Gaussian-weighted window
        self.ssim = SSIM(
            data_range=1,
            multichannel=True,
            gaussian_weights=True,
        )

        # PSNR over [0, 1] data
        self.psnr = PSNR(data_range=1)

        # EPE using Farneback or LiteFlowNet
        self.epe = EPE(standard=standard_epe)

        self.standard_epe = standard_epe

        # VMAF
        self.vmaf = VMAF()
Пример #4
0
    def __init__(self,
                 video_dir,
                 vid_ext="mp4",
                 frame_size=(224, 320),
                 num_frames=24,
                 method="DS",
                 mb_size=16,
                 search_dist=7):
        """Set up a motion-estimation evaluator over a video directory.

        NOTE: ``tf`` here is a transforms module (``Compose``/``Resize``/
        ``Lambda``), not TensorFlow — presumably torchvision; confirm
        against the file's imports.

        Args:
            video_dir: root directory containing the videos.
            vid_ext: video file extension to load.
            frame_size: (height, width) each frame is resized to.
            num_frames: number of frames cropped from each video.
            method: block-matching method identifier (e.g. "DS").
            mb_size: macro-block size in pixels.
            search_dist: block-matching search distance.
        """

        # per frame transform: numpy frame -> PIL RGB -> resized
        frame_transform = tf.Compose(
            [NpFrame2PIL("RGB"), tf.Resize(frame_size)])

        # composed transform: crop to num_frames, then stack per-frame results
        video_transform = tf.Compose([
            CropVideoSequence(num_frames=num_frames),
            tf.Lambda(lambda frames: np.stack(
                [frame_transform(frame) for frame in frames]))
        ])

        # check video directory
        self.video_dataset = VideoDataset(root_dir=video_dir,
                                          vid_ext=vid_ext,
                                          transform=video_transform)

        self.num_videos = len(self.video_dataset)

        # motion parameters
        self.frame_size = frame_size
        self.num_frames = num_frames
        self.method = method
        self.mb_size = mb_size
        self.search_dist = search_dist

        # evaluation metrics

        # SSIM over 8-bit data, per-channel, Gaussian-weighted window
        self.ssim = SSIM(
            data_range=255,
            multichannel=True,
            gaussian_weights=True,
        )

        # EPE using LiteFlowNet
        self.epe = EPE(standard=False)

        # PSNR over 8-bit data
        self.psnr = PSNR(data_range=255)

        # VMAF
        self.vmaf = VMAF()
Пример #5
0
    def __init__(self, config, sess=None):
        """Keep the configuration, build quality metrics, store the session.

        Args:
            config: configuration proto providing ``channel_dim``,
                ``batch_size``, ``patch_size`` and ``input_channels``.
            sess: optional externally created session (may be ``None``).
        """
        # configuration proto plus frequently accessed fields
        self.config = config
        self.input_channels = config.input_channels
        self.patch_size = config.patch_size
        self.batch_size = config.batch_size
        self.channel_dim = config.channel_dim

        # image-quality metrics over a [0, 1] value range
        self.psnr = PSNR(max_val=1.0)
        self.ssim = SSIM(max_val=1.0)

        # session supplied by the caller
        self.sess = sess
Пример #6
0
    def train(self):
        """Train ``self.model`` on the configured dataset.

        Builds train/valid pipelines (random crops for training, center
        crops for validation), compiles the model with MSE loss and
        PSNR/SSIM metrics, then fits with CSV logging, checkpointing on
        ``val_PSNR`` and optional early stopping.
        """
        dataset = self.dataset(self.args.data,
                               self.args.gt,
                               self.args.val_size,
                               crop_div=8)
        out_dir = self._create_out_dir()

        # training uses random crops; repeat_count=None repeats indefinitely
        train_dataset = dataset.get_dataset(data_type='train',
                                            batch_size=self.args.batch,
                                            repeat_count=None,
                                            crop_params=('random',
                                                         self.args.crop))

        # validation uses deterministic center crops, one image per batch
        valid_dataset = dataset.get_dataset(data_type='valid',
                                            batch_size=1,
                                            repeat_count=None,
                                            crop_params=('center',
                                                         self.args.crop))

        optimizer = tf.keras.optimizers.Adam(self.args.lr, beta_1=.5)
        self.model.compile(loss=MSE(),
                           optimizer=optimizer,
                           metrics=[PSNR(), SSIM()])

        # stop when validation PSNR has not improved for 10 epochs
        early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_PSNR',
                                                      patience=10,
                                                      mode='max')
        csv_log = tf.keras.callbacks.CSVLogger(
            os.path.join(out_dir, f'train_{self.args.model}.log'))

        # checkpoint filename embeds the epoch and validation PSNR
        checkpoints_path = os.path.join(out_dir, self.args.model)
        saver = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoints_path + '_{epoch:03d}-{val_PSNR:.2f}.h5',
            monitor='val_PSNR')

        callbacks = [csv_log, saver]
        if not self.args.no_early_stop:
            callbacks.append(early_stop)

        self.model.fit(train_dataset,
                       epochs=self.args.epochs,
                       callbacks=callbacks,
                       validation_data=valid_dataset,
                       verbose=1,
                       validation_steps=1,
                       steps_per_epoch=dataset.numel * self.args.repeat //
                       self.args.batch)
Пример #7
0
def test(opt, netG):
    """Run ``netG`` over the test set, save deblurred images, report PSNR.

    Args:
        opt: options object with ``dataroot`` (paired test images) and
            ``out_dir`` (destination for restored images).
        netG: generator exposing ``forwarda2b``; put into eval mode here.
    """
    aver_psnr = 0.0
    counter = 0

    test = ReadConcat(opt.dataroot, transform=image_transform)
    testset = DataLoader(test, batch_size=1, shuffle=False)
    check_folder(opt.out_dir)
    netG.eval()

    for i, data in enumerate(testset):
        # BUG FIX: was `counter = i`, which undercounts by one and makes
        # the final division a ZeroDivisionError for a 1-image test set.
        counter = i + 1
        data_A = data['A']  # blur
        data_B = data['B']  # sharp
        if torch.cuda.is_available():
            data_A = data_A.cuda()
            data_B = data_B.cuda()
        with torch.no_grad():
            realA = Variable(data_A)
            realB = Variable(data_B)

        fakeB, _ = netG.forwarda2b(realA)
        # image_recovery converts the network output back to image space
        fakeB = image_recovery(fakeB)
        realB = image_recovery(realB)

        aver_psnr += PSNR(fakeB, realB)

        # save the restored image under its original name
        img_path = data['img_name']
        save_path = os.path.join(opt.out_dir, img_path[0])
        save_image(fakeB, save_path)
        print('save successfully {}'.format(save_path))

    # guard against an empty test set
    if counter:
        aver_psnr /= counter
    print('PSNR = %f' % (aver_psnr))
def main():
    """Train an SRResNet generator with MSE loss and PSNR/SSIM metrics.

    Parses CLI args, builds the training dataset and the generator,
    optionally resumes from ``--weights_path``, then fits with
    TensorBoard logging and twice-per-epoch weight checkpoints.
    """
    args = _parse_args()

    model_dir = os.path.join(args.model_dir, 'srresnet')

    train_dataset = get_dataset(args.image_dir, args.batch_size,
                                args.crop_size, args.downscale_factor)

    mse_loss = tf.losses.MeanSquaredError()
    optimizer = tf.optimizers.Adam(1e-4)
    model = Generator(upscale_factor=args.downscale_factor)
    # build with dynamic spatial dims so any input resolution works
    model.build((1, None, None, 3))

    model.compile(optimizer=optimizer,
                  loss=mse_loss,
                  metrics=[PSNR(max_val=1.0),
                           SSIM(max_val=1.0)])

    # optionally resume from pre-trained weights
    if args.weights_path:
        model.load_weights(args.weights_path)
        print('Loaded weights from {}'.format(args.weights_path))

    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(model_dir, 'logs'),
                                       update_freq=500),
        # save weights twice per epoch (save_freq is in batches)
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(model_dir, 'srresnet_weights_epoch_') + '{epoch}',
            save_weights_only=True,
            save_best_only=False,
            monitor='loss',
            verbose=1,
            save_freq=args.step_per_epoch // 2),
    ]

    # total iterations are split into epochs of step_per_epoch batches
    epochs = args.iterations // args.step_per_epoch

    model.fit(train_dataset,
              epochs=epochs,
              steps_per_epoch=args.step_per_epoch,
              callbacks=callbacks)
Пример #9
0
    def evaluate(self):
        """Evaluate the model on the validation (or full) split.

        Builds a padded-crop dataset, predicts image by image and prints
        a running value for each configured metric via the progress bar.
        """
        dataset = self.dataset(self.args.data, self.args.gt,
                               self.args.val_size)

        # val_size of 0 or 1 means no split: evaluate on everything
        if self.args.val_size == 1. or self.args.val_size == 0.:
            dataset_length = dataset.numel
            data_type = 'all'
        else:
            dataset_length = int(dataset.numel * self.args.val_size)
            data_type = 'valid'

        eval_dataset = dataset.get_dataset(data_type=data_type,
                                           batch_size=1,
                                           repeat_count=1,
                                           crop_params=('pad', self.args.crop))

        prog_bar = ProgressBar(dataset_length, title='Evaluation')
        # TODO set metrics from config
        metrics = [PSNR(), SSIM()]

        # BUG FIX: the original tested `self.args.cpbd_mae` three separate
        # times (copy-paste); all three CPBD metrics are gated by that one
        # flag, so consolidate. NOTE(review): if independent cpbd_pred /
        # cpbd_true flags were intended, split this condition accordingly.
        if self.args.cpbd_mae:
            metrics.extend([CPBD_MAE(), CPBD_PRED(), CPBD_TRUE()])
        if self.args.lpips:
            metrics.append(LPIPS())

        for inputs, result in eval_dataset:
            predict = self.model.predict(inputs)
            update_str = ''
            for metric in metrics:
                metric.update_state(result.numpy(), predict)
                update_str += f'{metric.name}={metric.result():.4f} - '
            prog_bar.update(update_str)
Пример #10
0
def get_model(gaussian=False):
    """Construct and compile a Fourier MLP for 2D image regression.

    Args:
        gaussian: whether to use a Gaussian Fourier-feature mapping;
            written back into the (wandb-style) global ``config``.

    Returns:
        The compiled model with Adam optimizer, MSE loss and
        accuracy/PSNR metrics.
    """

    # record the choice in the experiment config (wandb-style update)
    config.update({'gaussian': gaussian}, allow_val_change=True)
    # NOTE(review): 'staddev' is presumably a (mis)spelling shared with the
    # config/FourierMLP signature — confirm before renaming.
    model = FourierMLP(config.num_layers,
                       config.num_units,
                       config.num_units_final,
                       gaussian=config.gaussian,
                       staddev=config.staddev,
                       num_units_FFM=config.num_units_FFM)

    loss_fn = tf.keras.losses.MeanSquaredError()

    # NOTE(review): 'accuracy' is an odd metric for MSE regression —
    # likely only meaningful if targets are discretized; verify intent.
    model.compile(optimizer=tf.keras.optimizers.Adam(
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        epsilon=config.epsilon),
                  loss=loss_fn,
                  metrics=['accuracy', PSNR()])

    return model
def predict_lowlight_residual():
    """Denoise a low-light hyperspectral test set band by band.

    Loads a multi-stage HSID model from a checkpoint, runs it over every
    spectral band (using K adjacent bands as context), adds the predicted
    residual to the noisy band, saves the result to a .mat file and
    reports PSNR/SSIM/SAM against the reference label.
    """

    # load the model
    #hsid = HSID(36)
    hsid = MultiStageHSIDUpscale(36)
    #hsid = nn.DataParallel(hsid).to(DEVICE)
    #device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    hsid = hsid.to(DEVICE)
    hsid.load_state_dict(torch.load('./checkpoints/hsid_multistage_upscale_patchsize64_best.pth', map_location='cuda:0')['gen'])

    # load the reference label used for the final metrics
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']
    #test=test.transpose((2,0,1)) # put the channel dim first: 191*1280*307

    test_data_dir = './data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # output path for the denoised result
    test_result_output_path = './data/testresult/'
    if not os.path.exists(test_result_output_path):
        os.makedirs(test_result_output_path)

    # denoise channel by channel
    # (translation of the Chinese block note below: allocate a numpy array
    # for the denoised result; for every band, fetch its K adjacent
    # spectral bands via get_adjacent_spectral_bands, run hsid to predict
    # the residual, add it to the noisy band to get the output band, and
    # finally save the result in .mat form)
    """
    分配一个numpy数组,存储去噪后的结果
    遍历所有通道,
    对于每个通道,通过get_adjacent_spectral_bands获取其相邻的K个通道
    调用hsid进行预测
    将预测到的residual和输入的noise加起来,得到输出band

    将去噪后的结果保存成mat结构
    """
    # NOTE(review): denoised_hsi is re-allocated every batch and savemat
    # runs after the loop, so only the LAST batch is saved — fine only
    # while the test set yields a single batch; confirm.
    for batch_idx, (noisy, label) in enumerate(test_dataloader):
        noisy = noisy.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)
        
        batch_size, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num): # process every band
                current_noisy_band = noisy[:,:,:,i]
                current_noisy_band = current_noisy_band[:,None]

                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)# shape: batch_size, width, height, band_num
                adj_spectral_bands = adj_spectral_bands.permute(0, 3,1,2)# swap dims 1 and 3 -> batch_size, band_num, height, width
                adj_spectral_bands = adj_spectral_bands.to(DEVICE)
                residual = hsid(current_noisy_band, adj_spectral_bands)
                denoised_band = current_noisy_band + residual[0]

                denoised_band_numpy = denoised_band.cpu().numpy().astype(np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:,:,i] = denoised_band_numpy

    # mdict must map names to numpy arrays
    scio.savemat(test_result_output_path + 'result.mat', {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    # report PSNR, SSIM and SAM
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(psnr, ssim, sam)) 
Пример #12
0
def main(args):
    """Train a UNet photo-enhancement GAN (TensorFlow 1.x graph mode).

    Loads training/test patch data, builds the enhancement graph with a
    combined generator loss (MEON content + adversarial texture + color +
    total variation + an SSIM term) and a discriminator loss, then
    alternates generator/discriminator Adam steps while periodically
    writing summaries, evaluating on the test set, saving visual crops
    and checkpoints, and reloading fresh data.

    Args:
        args: parsed CLI namespace; fields used include dataset paths and
            sizes, patch geometry, loss weights (``w_*``), step intervals
            and checkpoint/summary directories.
    """
    # loading training and test data
    logger.info("Loading test data...")
    test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
    logger.info("Test data was loaded\n")

    logger.info("Loading training data...")
    train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
    logger.info("Training data was loaded\n")

    TEST_SIZE = test_data.shape[0]
    num_test_batches = int(test_data.shape[0] / args.batch_size)

    # defining system architecture
    with tf.Graph().as_default(), tf.Session() as sess:

        # placeholders for training data (flattened patches, reshaped to HWC)
        phone_ = tf.placeholder(tf.float32, [None, args.patch_size])
        phone_image = tf.reshape(phone_, [-1, args.patch_height, args.patch_width, 3])

        dslr_ = tf.placeholder(tf.float32, [None, args.patch_size])
        dslr_image = tf.reshape(dslr_, [-1, args.patch_height, args.patch_width, 3])

        # adv_ carries the swap labels fed to the discriminator loss
        adv_ = tf.placeholder(tf.float32, [None, 1])
        enhanced = unet(phone_image)
        [w, h, d] = enhanced.get_shape().as_list()[1:]

        # # learning rate exponential_decay
        # global_step = tf.Variable(0)
        # learning_rate = tf.train.exponential_decay(args.learning_rate, global_step, decay_steps=args.train_size / args.batch_size, decay_rate=0.98, staircase=True)

        ## loss introduce
        '''
        content loss three ways : 
        1. vgg_loss: mat model load;
        2. vgg_loss: npy model load;
        3. iqa model(meon_loss): feature and scores
        '''
        # vgg = vgg19_loss.Vgg19(vgg_path=args.pretrain_weights) #  # load vgg models
        # vgg_content = 2000*tf.reduce_mean(tf.sqrt(tf.reduce_sum(
        #     tf.square((vgg.extract_feature(enhanced) - vgg.extract_feature(dslr_image))))) / (w * h * d))
        # # loss_content = multi_content_loss(args.pretrain_weights, enhanced, dslr_image, args.batch_size) # change another way

        # meon loss
        # with tf.variable_scope('meon_loss') as scope: # load ckpt is not conveient.
        MEON_evaluate_model, loss_content = meon_loss(dslr_image, enhanced)

        # adversarial texture loss; the discriminator minimises its negation
        loss_texture, discim_accuracy = texture_loss(enhanced, dslr_image, args.patch_width, args.patch_height, adv_)
        loss_discrim = -loss_texture

        loss_color = color_loss(enhanced, dslr_image, args.batch_size)
        loss_tv = variation_loss(enhanced, args.patch_width, args.patch_height, args.batch_size)

        loss_psnr = PSNR(enhanced, dslr_image)
        loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

        # weighted generator objective; the (1 - ssim) term pushes structure
        loss_generator = args.w_content * loss_content + args.w_texture * loss_texture + args.w_tv * loss_tv + 1000 * (
                    1 - loss_ssim) + args.w_color * loss_color

        # optimize parameters of image enhancement (generator) and discriminator networks
        generator_vars = [v for v in tf.global_variables() if v.name.startswith("generator")]
        discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]
        meon_vars = [v for v in tf.global_variables() if v.name.startswith("conv") or v.name.startswith("subtask")]

        # train_step_gen = tf.train.AdamOptimizer(args.learning_rate).minimize(loss_generator, var_list=generator_vars)
        # train_step_disc = tf.train.AdamOptimizer(args.learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

        train_step_gen = tf.train.AdamOptimizer(5e-5).minimize(loss_generator, var_list=generator_vars)
        train_step_disc = tf.train.AdamOptimizer(5e-5).minimize(loss_discrim, var_list=discriminator_vars)

        saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)
        meon_saver = tf.train.Saver(var_list=meon_vars)

        logger.info('Initializing variables')
        sess.run(tf.global_variables_initializer())
        logger.info('Training network')
        train_loss_gen = 0.0
        train_acc_discrim = 0.0
        all_zeros = np.reshape(np.zeros((args.batch_size, 1)), [args.batch_size, 1])
        test_crops = test_data[np.random.randint(0, TEST_SIZE, 5), :]  # choose five images to visual

        # summary ,add the scalar you want to see
        tf.summary.scalar('loss_generator', loss_generator),
        tf.summary.scalar('loss_content', loss_content),
        tf.summary.scalar('loss_color', loss_color),
        tf.summary.scalar('loss_texture', loss_texture),
        tf.summary.scalar('loss_tv', loss_tv),
        tf.summary.scalar('discim_accuracy', discim_accuracy),
        tf.summary.scalar('psnr', loss_psnr),
        tf.summary.scalar('ssim', loss_ssim),
        tf.summary.scalar('learning_rate', args.learning_rate),
        merge_summary = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'train', args.exp_name), sess.graph,
                                             filename_suffix=args.exp_name)
        test_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'test', args.exp_name), sess.graph,
                                            filename_suffix=args.exp_name)
        tf.global_variables_initializer().run()

        '''load ckpt models'''
        ckpt = tf.train.get_checkpoint_state(args.checkpoint_dir)
        start_i = 0
        if ckpt and ckpt.model_checkpoint_path:
            logger.info('loading checkpoint:' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            import re
            # NOTE(review): the pattern should be a raw string
            # (r"_(\d+).ckpt") to avoid an invalid-escape warning.
            start_i = int(re.findall("_(\d+).ckpt", ckpt.model_checkpoint_path)[0])
        MEON_evaluate_model.initialize(sess, meon_saver,
                                       args.meod_ckpt_path)  # initialize with anohter model pretrained weights

        '''start training...'''
        for i in range(start_i, args.iter_max):

            iter_start = time.time()
            # train generator
            idx_train = np.random.randint(0, args.train_size, args.batch_size)
            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]

            [loss_temp, temp] = sess.run([loss_generator, train_step_gen],
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros})
            train_loss_gen += loss_temp / args.eval_step

            # train discriminator
            idx_train = np.random.randint(0, args.train_size, args.batch_size)

            # generate image swaps (dslr or enhanced) for discriminator
            swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])

            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]
            # sess.run(train_step_disc)=train_step_disc.compute_gradients(loss,var)+train_step_disc.apply_gradients(var) @20190105
            [accuracy_temp, temp] = sess.run([discim_accuracy, train_step_disc],
                                             feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
            train_acc_discrim += accuracy_temp / args.eval_step

            if i % args.summary_step == 0:
                # summary intervals
                # enhance_f1_, enhance_f2_, enhance_s_, vgg_content_ = sess.run([enhance_f1, enhance_f2, enhance_s,vgg_content],
                #                          feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                # loss_content1_, loss_content2_, loss_content3_ = sess.run([loss_content1,loss_content2,loss_content3],
                #                          feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                # print("-----------------------------------------------")
                # print(enhance_f1_, enhance_f2_, enhance_s_,vgg_content_,loss_content1_, loss_content2_, loss_content3_)
                # print("-----------------------------------------------")
                train_summary = sess.run(merge_summary,
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                train_writer.add_summary(train_summary, i)

            if i % args.eval_step == 0:
                # test generator and discriminator CNNs
                test_losses_gen = np.zeros((1, 7))
                test_accuracy_disc = 0.0

                for j in range(num_test_batches):
                    be = j * args.batch_size
                    en = (j + 1) * args.batch_size

                    swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
                    phone_images = test_data[be:en]
                    dslr_images = test_answ[be:en]

                    [enhanced_crops, accuracy_disc, losses] = sess.run([enhanced, discim_accuracy, \
                                                                        [loss_generator, loss_content, loss_color,
                                                                         loss_texture, loss_tv, loss_psnr, loss_ssim]], \
                                                                       feed_dict={phone_: phone_images,
                                                                                  dslr_: dslr_images, adv_: swaps})

                    test_losses_gen += np.asarray(losses) / num_test_batches
                    test_accuracy_disc += accuracy_disc / num_test_batches

                logs_disc = "step %d/%d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
                            (i, args.iter_max, args.dataset, train_acc_discrim, test_accuracy_disc)
                logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, texture: %.4g, tv: %.4g | psnr: %.4g, ssim: %.4g\n" % \
                           (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2],
                            test_losses_gen[0][3], test_losses_gen[0][4], test_losses_gen[0][5], test_losses_gen[0][6])

                logger.info(logs_disc)
                logger.info(logs_gen)

                test_summary = sess.run(merge_summary,
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                test_writer.add_summary(test_summary, i)

                # save visual results for several test image crops
                # NOTE(review): dslr_images here is whatever the last test
                # batch left behind — it is unused for the visual output
                # but still fed; confirm intent.
                if args.save_visual_result:
                    enhanced_crops = sess.run(enhanced,
                                              feed_dict={phone_: test_crops, dslr_: dslr_images, adv_: all_zeros})
                    idx = 0
                    for crop in enhanced_crops:
                        before_after = np.hstack(
                            (np.reshape(test_crops[idx], [args.patch_height, args.patch_width, 3]), crop))
                        misc.imsave(
                            os.path.join(args.checkpoint_dir, str(args.dataset) + str(idx) + '_iteration_' + str(i) +
                                         '.jpg'), before_after)
                        idx += 1

                # save the model that corresponds to the current iteration
                if args.save_ckpt_file:
                    saver.save(sess,
                               os.path.join(args.checkpoint_dir, str(args.dataset) + '_iteration_' + str(i) + '.ckpt'),
                               write_meta_graph=False)

                train_loss_gen = 0.0
                train_acc_discrim = 0.0
                # reload a different batch of training data
                del train_data
                del train_answ
                del test_data
                del test_answ
                test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
                train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
Пример #13
0
def main(args, data_params):
    """Train a UNet enhancer with adversarial texture loss (TF1 Supervisor).

    Builds train (and optional eval) data pipelines, the inference graph
    with the composite loss from ``compute_loss.total_loss``, separate
    Adam minimizers for generator and discriminator variables, moving
    averages for display, scalar summaries, and runs the training loop
    under ``tf.train.Supervisor`` with periodic PSNR/SSIM evaluation.

    Args:
        args: parsed CLI namespace (checkpoint/eval dirs, batch size,
            augmentation flags, learning rate, intervals, etc.).
        data_params: extra parameters forwarded to the data pipeline.
    """
    procname = os.path.basename(args.checkpoint_dir)

    log.info('Preparing summary and checkpoint directory {}'.format(
        args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    tf.set_random_seed(1234)  # Make experiments repeatable

    # Select an architecture

    # Add model parameters to the graph (so they are saved to disk at checkpoint)

    # --- Train/Test datasets ---------------------------------------------------
    data_pipe = getattr(dp, args.data_pipeline)
    with tf.variable_scope('train_data'):
        train_data_pipeline = data_pipe(
            args.data_dir,
            shuffle=True,
            batch_size=args.batch_size,
            nthreads=args.data_threads,
            fliplr=args.fliplr,
            flipud=args.flipud,
            rotate=args.rotate,
            random_crop=args.random_crop,
            params=data_params,
            output_resolution=args.output_resolution,
            scale=args.scale)
        train_samples = train_data_pipeline.samples

    if args.eval_data_dir is not None:
        with tf.variable_scope('eval_data'):
            eval_data_pipeline = data_pipe(
                args.eval_data_dir,
                shuffle=True,
                batch_size=args.batch_size,
                nthreads=args.data_threads,
                fliplr=False,
                flipud=False,
                rotate=False,
                random_crop=False,
                params=data_params,
                output_resolution=args.output_resolution,
                scale=args.scale)
            eval_samples = eval_data_pipeline.samples
    # ---------------------------------------------------------------------------
    # NOTE(review): swaps is sampled ONCE at graph-build time and baked in
    # as a constant tensor, so the same swap pattern is reused every
    # training step — confirm this is intended.
    swaps = np.reshape(np.random.randint(0, 2, args.batch_size),
                       [args.batch_size, 1])
    swaps = tf.convert_to_tensor(swaps)
    swaps = tf.cast(swaps, tf.float32)
    # Training graph
    with tf.variable_scope('inference'):
        prediction = unet(train_samples['image_input'])
        loss,loss_content,loss_texture,loss_color,loss_Mssim,loss_tv,discim_accuracy =\
          compute_loss.total_loss(train_samples['image_output'], prediction, swaps, args.batch_size)
        psnr = PSNR(train_samples['image_output'], prediction)
        loss_ssim = MultiScaleSSIM(train_samples['image_output'], prediction)

    # Evaluation graph (weights shared with training via reuse=True)
    if args.eval_data_dir is not None:
        with tf.name_scope('eval'):
            with tf.variable_scope('inference', reuse=True):
                eval_prediction = unet(eval_samples['image_input'])
            eval_psnr = PSNR(eval_samples['image_output'], eval_prediction)
            eval_ssim = MultiScaleSSIM(eval_samples['image_output'],
                                       eval_prediction)

    # Optimizer: generator vs discriminator variable partitions
    model_vars1 = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/generator")
    ]
    discriminator_vars1 = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/l2_loss/discriminator")
    ]

    global_step = tf.contrib.framework.get_or_create_global_step()
    with tf.name_scope('optimizer'):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        updates = tf.group(*update_ops, name='update_ops')
        log.info("Adding {} update ops".format(len(update_ops)))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses and args.weight_decay is not None and args.weight_decay > 0:
            print("Regularization losses:")
            for rl in reg_losses:
                print(" ", rl.name)
            opt_loss = loss + args.weight_decay * sum(reg_losses)
        else:
            print("No regularization.")
            opt_loss = loss

        # NOTE(review): both minimizers increment the same global_step, so
        # it advances twice per training iteration — confirm intent.
        with tf.control_dependencies([updates]):
            opt = tf.train.AdamOptimizer(args.learning_rate)
            minimize = opt.minimize(opt_loss,
                                    name='optimizer',
                                    global_step=global_step,
                                    var_list=model_vars1)
            minimize_discrim = opt.minimize(-loss_texture,
                                            name='discriminator',
                                            global_step=global_step,
                                            var_list=discriminator_vars1)

    # Average loss and psnr for display
    with tf.name_scope("moving_averages"):
        ema = tf.train.ExponentialMovingAverage(decay=0.99)
        update_ma = ema.apply([
            loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv,
            discim_accuracy, psnr, loss_ssim
        ])
        loss = ema.average(loss)
        loss_content = ema.average(loss_content)
        loss_texture = ema.average(loss_texture)
        loss_color = ema.average(loss_color)
        loss_Mssim = ema.average(loss_Mssim)
        loss_tv = ema.average(loss_tv)
        discim_accuracy = ema.average(discim_accuracy)
        psnr = ema.average(psnr)
        loss_ssim = ema.average(loss_ssim)

    # Training stepper operation
    train_op = tf.group(minimize, update_ma)
    train_discrim_op = tf.group(minimize_discrim, update_ma)

    # Save a few graphs to
    summaries = [
        tf.summary.scalar('loss', loss),
        tf.summary.scalar('loss_content', loss_content),
        tf.summary.scalar('loss_color', loss_color),
        tf.summary.scalar('loss_texture', loss_texture),
        tf.summary.scalar('loss_ssim', loss_Mssim),
        tf.summary.scalar('loss_tv', loss_tv),
        tf.summary.scalar('discim_accuracy', discim_accuracy),
        tf.summary.scalar('psnr', psnr),
        tf.summary.scalar('ssim', loss_ssim),
        tf.summary.scalar('learning_rate', args.learning_rate),
        tf.summary.scalar('batch_size', args.batch_size),
    ]

    log_fetches = {
        "loss_content": loss_content,
        "loss_texture": loss_texture,
        "loss_color": loss_color,
        "loss_Mssim": loss_Mssim,
        "loss_tv": loss_tv,
        "discim_accuracy": discim_accuracy,
        "step": global_step,
        "loss": loss,
        "psnr": psnr,
        "loss_ssim": loss_ssim
    }

    # checkpoint everything EXCEPT discriminator vars (re-initialized below)
    model_vars = [
        v for v in tf.global_variables()
        if not v.name.startswith("inference/l2_loss/discriminator")
    ]
    discriminator_vars = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/l2_loss/discriminator")
    ]

    # Train config
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not canibalize the entire GPU

    sv = tf.train.Supervisor(
        saver=tf.train.Saver(var_list=model_vars, max_to_keep=100),
        local_init_op=tf.initialize_variables(discriminator_vars),
        logdir=args.checkpoint_dir,
        save_summaries_secs=args.summary_interval,
        save_model_secs=args.checkpoint_interval)
    # Train loop
    with sv.managed_session(config=config) as sess:
        sv.loop(args.log_interval, log_hook, (sess, log_fetches))
        last_eval = time.time()
        while True:
            if sv.should_stop():
                log.info("stopping supervisor")
                break
            try:
                # one generator step, then one discriminator step
                step, _ = sess.run([global_step, train_op])
                _ = sess.run(train_discrim_op)
                since_eval = time.time() - last_eval

                if args.eval_data_dir is not None and since_eval > args.eval_interval:
                    log.info("Evaluating on {} images at step {}".format(
                        3, step))

                    # average PSNR/SSIM over 3 eval batches
                    p_ = 0
                    s_ = 0
                    for it in range(3):
                        p_ += sess.run(eval_psnr)
                        s_ += sess.run(eval_ssim)
                    p_ /= 3
                    s_ /= 3

                    sv.summary_writer.add_summary(tf.Summary(value=[
                        tf.Summary.Value(tag="psnr/eval", simple_value=p_)
                    ]),
                                                  global_step=step)

                    sv.summary_writer.add_summary(tf.Summary(value=[
                        tf.Summary.Value(tag="ssim/eval", simple_value=s_)
                    ]),
                                                  global_step=step)

                    log.info("  Evaluation PSNR = {:.2f} dB".format(p_))
                    log.info("  Evaluation SSIM = {:.4f} ".format(s_))

                    last_eval = time.time()

            except tf.errors.AbortedError:
                log.error("Aborted")
                break
            except KeyboardInterrupt:
                break
        chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
        log.info("Training complete, saving chkpt {}".format(chkpt_path))
        sv.saver.save(sess, chkpt_path)
        sv.request_stop()
Пример #14
0
def main(_run):
    """Entry point of the (sacred-style) training run.

    Trains generator ``G`` with ``GLoss``, optionally under
    DistributedDataParallel, logging losses/images to TensorBoard and
    checkpointing at each epoch ("latest") and on best val loss ("best").

    Args:
        _run: experiment run object; ``_run.config`` holds all hyperparameters
            (presumably a sacred ``Run`` — TODO confirm against the caller).
    """
    args = tupperware(_run.config)

    # Create output/checkpoint directories (only rank 0 logs).
    dir_init(args, is_local_rank_0=is_local_rank_0)

    # Silence warnings on non-zero ranks to keep multi-GPU logs readable.
    if not is_local_rank_0:
        warnings.filterwarnings("ignore")

    # Multi-GPU setup: under DDP the local rank comes from the launcher's
    # LOCAL_RANK env var; otherwise fall back to a single configured device.
    if args.distdataparallel:
        rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        world_size = dist.get_world_size()
    else:
        rank = args.device
        world_size = 1

    # Dataloaders (train / optional val / optional test).
    data = get_dataloaders(args, is_local_rank_0=is_local_rank_0)

    # Generator model on this rank's device.
    G = get_model.model(args).to(rank)

    # Optimiser + LR scheduler for the generator.
    g_optimizer, g_lr_scheduler = get_optimisers(G, args)

    # Restore weights/optimiser/step/best-loss from checkpoint if present.
    G, g_optimizer, global_step, start_epoch, loss = load_models(
        G, g_optimizer, args, is_local_rank_0=is_local_rank_0)

    if args.distdataparallel:
        # Wrap with Distributed Data Parallel (after checkpoint load, so
        # state dict keys match the unwrapped model).
        G = torch.nn.parallel.DistributedDataParallel(G,
                                                      device_ids=[rank],
                                                      output_device=rank)

    # Rank 0 owns all logging: GPU count, TensorBoard writer, progress bars.
    if is_local_rank_0:
        world_size = int(
            os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
        logging.info("Using {} GPUs".format(world_size))

        writer = SummaryWriter(log_dir=str(args.run_dir))
        writer.add_text("Args", pprint_args(args))

        # Progress bars sized in samples (batches * batch_size).
        train_pbar = tqdm(range(len(data.train_loader) * args.batch_size),
                          dynamic_ncols=True)

        val_pbar = (tqdm(range(len(data.val_loader) * args.batch_size),
                         dynamic_ncols=True) if data.val_loader else None)

        test_pbar = (tqdm(range(len(data.test_loader) * args.batch_size),
                          dynamic_ncols=True) if data.test_loader else None)

    # Composite generator loss (image + CoBi RGB terms, per keys below).
    g_loss = GLoss(args).to(rank)

    # Compatibility with checkpoints saved before global_step existed:
    # reconstruct it from the epoch count.
    if not global_step:
        global_step = start_epoch * len(data.train_loader) * args.batch_size

    start_epoch = global_step // len(data.train_loader.dataset)

    # Exponentially-averaged training losses (exp_loss) and plain-averaged
    # validation metrics (avg_metrics).
    loss_dict = {
        "total_loss": 0.0,
        "image_loss": 0.0,
        "cobi_rgb_loss": 0.0,
        "train_PSNR": 0.0,
    }

    metric_dict = {"PSNR": 0.0, "total_loss": 0.0}
    avg_metrics = AvgLoss_with_dict(loss_dict=metric_dict, args=args)
    exp_loss = ExpLoss_with_dict(loss_dict=loss_dict, args=args)

    try:
        for epoch in range(start_epoch, args.num_epochs):
            # Train mode
            G.train()

            if is_local_rank_0:
                train_pbar.reset()

            if args.distdataparallel:
                # Reshuffle the distributed sampler each epoch.
                data.train_loader.sampler.set_epoch(epoch)

            for i, batch in enumerate(data.train_loader):
                # Allows for interrupted training: on the resumed epoch, skip
                # ahead once global_step wraps to the epoch boundary.
                if ((global_step + 1) %
                    (len(data.train_loader) * args.batch_size)
                        == 0) and (epoch == start_epoch):
                    break

                loss_dict = defaultdict(float)

                source, target, filename = batch
                source, target = (source.to(rank), target.to(rank))

                # ------------------------------- #
                # Update Gen
                # ------------------------------- #
                G.zero_grad()
                output = G(source)

                g_loss(output=output, target=target)

                g_loss.total_loss.backward()
                g_optimizer.step()

                # Update lr schedulers (fractional epoch for warm-restart
                # style schedules).
                g_lr_scheduler.step(epoch + i / len(data.train_loader))

                # if is_local_rank_0:
                # Train PSNR
                loss_dict["train_PSNR"] += PSNR(output, target)

                # Accumulate all losses
                loss_dict["total_loss"] += g_loss.total_loss
                loss_dict["image_loss"] += g_loss.image_loss
                loss_dict["cobi_rgb_loss"] += g_loss.cobi_rgb_loss

                # All-reduce across ranks before folding into the EMA.
                exp_loss += reduce_loss_dict(loss_dict, world_size=world_size)

                # global_step counts samples across all ranks.
                global_step += args.batch_size * world_size

                if is_local_rank_0:
                    train_pbar.update(args.batch_size)
                    train_pbar.set_description(
                        f"Epoch: {epoch + 1} | Gen loss: {exp_loss.loss_dict['total_loss']:.3f} "
                    )

                # Write lr rates and metrics every log_interval batches.
                if is_local_rank_0 and i % (args.log_interval) == 0:
                    gen_lr = g_optimizer.param_groups[0]["lr"]
                    writer.add_scalar("lr/gen", gen_lr, global_step)

                    for metric in exp_loss.loss_dict:
                        writer.add_scalar(
                            f"Train_Metrics/{metric}",
                            exp_loss.loss_dict[metric],
                            global_step,
                        )

                    # Log up to 3 sample images; .mul(0.5).add(0.5) maps
                    # [-1, 1] tensors back to [0, 1] for display.
                    n = np.min([3, args.batch_size])
                    for e in range(n):
                        source_vis = source[e].mul(0.5).add(0.5)
                        target_vis = target[e].mul(0.5).add(0.5)
                        output_vis = output[e].mul(0.5).add(0.5)

                        writer.add_image(
                            f"Source/Train_{e + 1}",
                            source_vis.cpu().detach(),
                            global_step,
                        )

                        writer.add_image(
                            f"Target/Train_{e + 1}",
                            target_vis.cpu().detach(),
                            global_step,
                        )

                        writer.add_image(
                            f"Output/Train_{e + 1}",
                            output_vis.cpu().detach(),
                            global_step,
                        )

                        writer.add_text(f"Filename/Train_{e + 1}", filename[e],
                                        global_step)

            if is_local_rank_0:
                # Save ckpt at end of epoch
                logging.info(
                    f"Saving weights at epoch {epoch + 1} global step {global_step}"
                )

                # Save weights
                save_weights(
                    epoch=epoch,
                    global_step=global_step,
                    G=G,
                    g_optimizer=g_optimizer,
                    loss=loss,
                    tag="latest",
                    args=args,
                )

                train_pbar.refresh()

            # Run val and test only occasionally
            if epoch % args.val_test_epoch_interval != 0:
                continue

            # Val and test
            with torch.no_grad():
                G.eval()

                if data.val_loader:
                    avg_metrics.reset()
                    if is_local_rank_0:
                        val_pbar.reset()

                    filename_static = []

                    for i, batch in enumerate(data.val_loader):
                        metrics_dict = defaultdict(float)

                        source, target, filename = batch
                        source, target = (source.to(rank), target.to(rank))

                        output = G(source)
                        g_loss(output=output, target=target)

                        # Total loss
                        metrics_dict["total_loss"] += g_loss.total_loss
                        # PSNR
                        metrics_dict["PSNR"] += PSNR(output, target)

                        avg_metrics += reduce_loss_dict(metrics_dict,
                                                        world_size=world_size)

                        # Remember the batch holding the fixed reference image
                        # so the same image can be logged every epoch.
                        if args.static_val_image in filename:
                            filename_static = filename
                            source_static = source
                            target_static = target
                            output_static = output

                        if is_local_rank_0:
                            val_pbar.update(args.batch_size)
                            val_pbar.set_description(
                                f"Val Epoch : {epoch + 1} Step: {global_step}| PSNR: {avg_metrics.loss_dict['PSNR']:.3f}"
                            )
                    if is_local_rank_0:
                        for metric in avg_metrics.loss_dict:
                            writer.add_scalar(
                                f"Val_Metrics/{metric}",
                                avg_metrics.loss_dict[metric],
                                global_step,
                            )

                        # Log samples from the last val batch.
                        n = np.min([3, args.batch_size])
                        for e in range(n):
                            source_vis = source[e].mul(0.5).add(0.5)
                            target_vis = target[e].mul(0.5).add(0.5)
                            output_vis = output[e].mul(0.5).add(0.5)

                            writer.add_image(
                                f"Source/Val_{e+1}",
                                source_vis.cpu().detach(),
                                global_step,
                            )
                            writer.add_image(
                                f"Target/Val_{e+1}",
                                target_vis.cpu().detach(),
                                global_step,
                            )
                            writer.add_image(
                                f"Output/Val_{e+1}",
                                output_vis.cpu().detach(),
                                global_step,
                            )

                            writer.add_text(f"Filename/Val_{e + 1}",
                                            filename[e], global_step)

                        # Log the fixed reference image, if it was seen.
                        for e, name in enumerate(filename_static):
                            if name == args.static_val_image:
                                source_vis = source_static[e].mul(0.5).add(0.5)
                                target_vis = target_static[e].mul(0.5).add(0.5)
                                output_vis = output_static[e].mul(0.5).add(0.5)

                                writer.add_image(
                                    f"Source/Val_Static",
                                    source_vis.cpu().detach(),
                                    global_step,
                                )
                                writer.add_image(
                                    f"Target/Val_Static",
                                    target_vis.cpu().detach(),
                                    global_step,
                                )
                                writer.add_image(
                                    f"Output/Val_Static",
                                    output_vis.cpu().detach(),
                                    global_step,
                                )

                                writer.add_text(
                                    f"Filename/Val_Static",
                                    filename_static[e],
                                    global_step,
                                )

                                break

                        logging.info(
                            f"Saving weights at END OF epoch {epoch + 1} global step {global_step}"
                        )

                        # Track whether this is the best (lowest) val loss.
                        if avg_metrics.loss_dict["total_loss"] < loss:
                            is_min = True
                            loss = avg_metrics.loss_dict["total_loss"]
                        else:
                            is_min = False

                        # Save weights
                        save_weights(
                            epoch=epoch,
                            global_step=global_step,
                            G=G,
                            g_optimizer=g_optimizer,
                            loss=loss,
                            is_min=is_min,
                            args=args,
                            tag="best",
                        )

                        val_pbar.refresh()

                # Test (no targets: batches are (source, filename) only).
                if data.test_loader:
                    filename_static = []

                    if is_local_rank_0:
                        test_pbar.reset()

                    for i, batch in enumerate(data.test_loader):
                        source, filename = batch
                        source = source.to(rank)

                        output = G(source)

                        # Remember the batch holding the fixed reference image.
                        if args.static_test_image in filename:
                            filename_static = filename
                            source_static = source
                            output_static = output

                        if is_local_rank_0:
                            test_pbar.update(args.batch_size)
                            test_pbar.set_description(
                                f"Test Epoch : {epoch + 1} Step: {global_step}"
                            )

                    if is_local_rank_0:
                        n = np.min([3, args.batch_size])
                        for e in range(n):
                            source_vis = source[e].mul(0.5).add(0.5)
                            output_vis = output[e].mul(0.5).add(0.5)

                            writer.add_image(
                                f"Source/Test_{e+1}",
                                source_vis.cpu().detach(),
                                global_step,
                            )

                            writer.add_image(
                                f"Output/Test_{e+1}",
                                output_vis.cpu().detach(),
                                global_step,
                            )

                            writer.add_text(f"Filename/Test_{e + 1}",
                                            filename[e], global_step)

                        # NOTE(review): unlike the val path, the static test
                        # tensors are logged without the [-1,1]->[0,1] rescale
                        # — confirm whether that is intentional.
                        for e, name in enumerate(filename_static):
                            if name == args.static_test_image:
                                source_vis = source_static[e]
                                output_vis = output_static[e]

                                writer.add_image(
                                    f"Source/Test_Static",
                                    source_vis.cpu().detach(),
                                    global_step,
                                )

                                writer.add_image(
                                    f"Output/Test_Static",
                                    output_vis.cpu().detach(),
                                    global_step,
                                )

                                writer.add_text(
                                    f"Filename/Test_Static",
                                    filename_static[e],
                                    global_step,
                                )

                                break

                        test_pbar.refresh()

    except KeyboardInterrupt:
        # Graceful Ctrl-C: flush progress bars and checkpoint before exit.
        if is_local_rank_0:
            logging.info("-" * 89)
            logging.info("Exiting from training early. Saving models")

            for pbar in [train_pbar, val_pbar, test_pbar]:
                if pbar:
                    pbar.refresh()

            save_weights(
                epoch=epoch,
                global_step=global_step,
                G=G,
                g_optimizer=g_optimizer,
                loss=loss,
                is_min=True,
                args=args,
            )
Пример #15
0
# Top-level training setup: dataloaders, model, loss/metric, optimiser,
# LR schedule, and optional checkpoint resume.
# NOTE(review): train_dataset, args, transform_test, XXXNet, RRLoss and
# device are defined elsewhere in this file — verify before reuse.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=0,
                                           drop_last=False)  #
valid_dataset = KKDataset(args.valid_dataset,
                          is_trainval=True,
                          transform=transform_test)  #
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=args.val_batch_size,
                                           shuffle=True,
                                           num_workers=0)  #
# model & loss
model = XXXNet().to(device)  #
lossfunc = RRLoss()  #
criterion = PSNR()  # PSNR used as an evaluation metric, not for backprop
# lr & optimizer
optimizer = optim.SGD(model.parameters(),
                      lr=args.init_lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
# Step the LR down 10x at epochs 50 and 70.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[50, 70],
                                                 gamma=0.1)

# Resume from checkpoint if a path was given and exists.
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
def predict_lowlight_residual():
    """Denoise a low-light hyperspectral test set with a pretrained ENCAM model.

    Loads the ENCAM generator checkpoint, iterates over the test dataloader
    band by band, denoises each band in four spatial quadrants (to bound GPU
    memory), stitches the quadrants back to full resolution, saves the result
    as a .mat file, and reports PSNR / SSIM / SAM against the reference label.

    Side effects: writes ``./data/testresult/result.mat`` and prints metrics.
    """
    # Load the model checkpoint (generator weights only) and switch to eval.
    encam = ENCAM()
    encam = encam.to(DEVICE)
    encam.eval()
    encam.load_state_dict(
        torch.load('./checkpoints/encam_best_08_27.pth',
                   map_location='cuda:0')['gen'])

    # Reference label used only for the final metric computation.
    mat_src_path = '../HSID/data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']

    test_data_dir = '../HSID/data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Output directory for the denoised result.
    test_result_output_path = './data/testresult/'
    if not os.path.exists(test_result_output_path):
        os.makedirs(test_result_output_path)

    for batch_idx, (noisy, label) in enumerate(test_dataloader):
        noisy = noisy.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)

        # assumes noisy is (batch, width, height, bands) — TODO confirm
        batch_size, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):  # denoise each band independently
                current_noisy_band = noisy[:, :, :, i]
                current_noisy_band = current_noisy_band[:, None]

                # K adjacent bands give the network spectral context.
                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                # (B, H, W, K) -> (B, 1, K, H, W)
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands = torch.unsqueeze(adj_spectral_bands, 1)
                adj_spectral_bands = adj_spectral_bands.to(DEVICE)

                # Split the band into four spatial quadrants, denoise each,
                # and stitch them back together.
                # BUG FIX: the original code sliced the top-left quadrant for
                # all four predictions (and reused adj_spectral_bands_00 where
                # adj_spectral_bands_11 was intended), so the stitched output
                # was four copies of the denoised top-left corner.
                h_half = current_noisy_band.shape[2] // 2
                w_half = current_noisy_band.shape[3] // 2
                rows = []
                for h_slice in (slice(0, h_half), slice(h_half, None)):
                    cols = []
                    for w_slice in (slice(0, w_half), slice(w_half, None)):
                        noisy_q = current_noisy_band[:, :, h_slice, w_slice]
                        adj_q = adj_spectral_bands[:, :, :, h_slice, w_slice]
                        # The network predicts the residual; add it back to
                        # the noisy input to get the denoised quadrant.
                        cols.append(noisy_q + encam(noisy_q, adj_q))
                    rows.append(torch.cat(cols, dim=3))  # stitch along width
                denoised_band = torch.cat(rows, dim=2)   # stitch along height

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, i] = denoised_band_numpy

    # mdict values must be numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    # Compute PSNR / SSIM / SAM against the reference label.
    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
Пример #17
0
    # Per-image inpainting evaluation: run the model on each (image, mask)
    # pair and report L1 / PSNR / MSSIM against the labels.
    # NOTE(review): model, data, masks, labels, l1 and compute_MSSIM come
    # from the enclosing (unseen) scope — verify before reuse.
    with torch.no_grad():
        outputs = []
        masks_out = []

        # Per-image error accumulators (one slot per label).
        avg_l1 = np.zeros(labels.shape[0])
        avg_PSNR = np.zeros(labels.shape[0])
        avg_MSSIM = np.zeros(labels.shape[0])

        for i in range(labels.shape[0]):
            # Forward one image (first 3 channels) with its mask;
            # [None] adds the batch dimension.
            output = model(data[i, :3, :, :][None, :, :, :],
                           masks[i, :, :, :][None, :, :, :])

            l1_error = l1(output[0].numpy(),
                          labels[i, :, :, :][None, :, :, :].numpy())
            PSNR_error = PSNR(output[0].numpy(),
                              labels[i, :, :, :][None, :, :, :].numpy())
            MSSIM_error = compute_MSSIM(
                output[0].numpy(), labels[i, :, :, :][None, :, :, :].numpy())

            print('l1 error of the model : {}'.format(l1_error))
            print('PSNR error of the model : {}'.format(PSNR_error))
            print('MSSIM error of the model : {}'.format(MSSIM_error))

            # Show and save the degraded input image (CHW -> HWC for display).
            des_im = data[:, :3, :, :][i].cpu().numpy()
            des_im = np.transpose(des_im, (1, 2, 0))
            plt.imshow(des_im)
            plt.axis('off')
            plt.show()

            matplotlib.image.imsave('image_destroyed{}.png'.format(i), des_im)
Пример #18
0
def train():
    """Train a ResUNet super-resolution model with Keras.

    Parses ``--model`` (optional H5 checkpoint to resume from) and ``--scale``
    from the command line, builds train/valid tf.data pipelines, compiles the
    model with MAE loss and PSNR/SSIM metrics, and fits with CSV logging,
    per-epoch checkpointing and early stopping on validation PSNR.

    Side effects: creates ``results/`` and writes logs/checkpoints there.
    """
    parser = argparse.ArgumentParser(description='Trainer')
    parser.add_argument('--model', help='H5 TF upsample model', default=None)
    parser.add_argument('--scale', help='Model scale', default=2, type=int)
    args = parser.parse_args()

    # Training hyperparameters.
    scale = args.scale
    train_folder = './frames'
    valid_folder = './test_data'
    crop_size = 96
    repeat_count = 1
    batch_size = 16
    in_channels = 9   # e.g. 3 stacked RGB frames — TODO confirm
    out_channels = 3

    epochs = 100
    lr = 1e-4
    out_dir = 'results'
    # makedirs(exist_ok=True) is race-free and handles nested paths,
    # unlike the original exists-check + os.mkdir.
    os.makedirs(out_dir, exist_ok=True)

    if args.model is not None:
        # Resume from an existing Keras H5 model (recompiled below).
        model = tf.keras.models.load_model(args.model, compile=False)
        model.summary()
        print('Model loaded')
    else:
        model = ResUNet(scale, in_channels, out_channels)
        print('New model created')

    train_dataset = get_dataset(
        train_folder,
        batch_size,
        crop_size,
        scale,
        repeat_count
    )

    valid_dataset = get_dataset(
        valid_folder,
        batch_size,
        crop_size,
        scale,
        repeat_count
    )

    optimizer = tf.keras.optimizers.Adam(lr, beta_1=.5)
    model.compile(
        loss=tf.losses.MeanAbsoluteError(),
        optimizer=optimizer,
        metrics=[PSNR(), SSIM()])

    # Stop when validation PSNR (higher is better) stops improving.
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_PSNR', patience=10, mode='max')
    # Plain string: the original f-string had no placeholders.
    csv_log = tf.keras.callbacks.CSVLogger(
        os.path.join(out_dir, 'train_resunet.log'))

    # mode='max' makes the monitored metric direction explicit; 'auto'
    # would guess 'min' for an unrecognized metric name like val_PSNR.
    saver = tf.keras.callbacks.ModelCheckpoint(
        filepath=out_dir + '/resunet_{epoch:03d}-{val_PSNR:.2f}.h5',
        monitor='val_PSNR',
        mode='max')

    callbacks = [csv_log, saver, early_stop]

    model.fit(
        train_dataset,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=valid_dataset,
        verbose=1
    )
Пример #19
0
    # adv_ = tf.placeholder(tf.float32, [None, 1])

    # Enhance phone images with the EDSR generator.
    enhanced = EDSR(phone_image)
    # NOTE(review): Python 2 print statement — this file will not parse
    # under Python 3 as-is.
    print enhanced.shape

    # Loss terms: only the L1 (absolute difference) generator loss is
    # active; the GAN texture/content/color/TV terms are commented out.
    # loss_texture, discim_accuracy = texture_loss(enhanced,dslr_image,PATCH_WIDTH,PATCH_HEIGHT,adv_)
    # loss_discrim = -loss_texture
    # loss_content = content_loss(vgg_dir,enhanced,dslr_image,batch_size)
    # loss_color = color_loss(enhanced, dslr_image, batch_size)
    # loss_tv = variation_loss(enhanced,PATCH_WIDTH,PATCH_HEIGHT,batch_size)

    # loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv
    loss_generator = tf.losses.absolute_difference(labels=dslr_image,
                                                   predictions=enhanced)
    # PSNR / MS-SSIM tracked as evaluation metrics only.
    loss_psnr = PSNR(enhanced, dslr_, PATCH_SIZE, batch_size)
    loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

    # Optimize only the generator variables (discriminator path disabled).
    generator_vars = [
        v for v in tf.global_variables() if v.name.startswith("generator")
    ]
    # discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]

    train_step_gen = tf.train.AdamOptimizer(learning_rate).minimize(
        loss_generator, var_list=generator_vars)
    # train_step_disc = tf.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

    # Saver for generator checkpoints (keep up to 100).
    saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)

    print('Initializing variables')
def train_model_residual_lowlight_rdn():
    """Train HSIRDNECA_Denoise to predict the noise residual of HSI patches.

    Each epoch regenerates noisy training patches from the Washington cube,
    minimizes an L1 loss between the predicted residual and (clean - noisy),
    then denoises the test set band by band to track PSNR/SSIM/SAM.  A
    checkpoint is written every epoch and the best-PSNR weights are kept
    separately.
    """

    device = DEVICE
    # Prepare training data: stored as (H, W, C); move the channel axis first.
    train = np.load('./data/denoise/train_washington8.npy')
    train = train.transpose((2, 1, 0))

    # NOTE(review): the test cube is loaded from the *training* file — confirm intended.
    test = np.load('./data/denoise/train_washington8.npy')
    #test=test.transpose((2,1,0))
    test = test.transpose((2, 1, 0))  # put the channel axis first

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'
    if not os.path.exists(save_model_path):
        os.mkdir(save_model_path)

    # Build the model; K is presumably the number of adjacent spectral bands
    # used as context — TODO confirm against HSIRDNECA_Denoise.
    net = HSIRDNECA_Denoise(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)
    #net = net.to(device)

    # Optimizer with step-decay LR schedule.
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0,999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[200, 400], gamma=0.5)

    # Loss function (L1 via the module-level recon_criterion below).
    #criterion = nn.MSELoss()

    gen_epoch_loss_list = []

    cur_step = 0

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 600

    mpsnr_list = []
    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # NOTE(review): stepping the scheduler before any optimizer.step()
        # shifts the schedule by one epoch and warns on modern PyTorch.
        scheduler.step()
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())

        gen_epoch_loss = 0  # NOTE(review): never accumulated in this variant, so the epoch summary prints 0.

        net.train()

        channels = 191  # 191 channels
        data_patches, data_cubic_patches = datagenerator(train, channels)

        # (N, H, W, C) -> (N, C, H, W) patches and (N, D, H, W, C) -> (N, C, D, H, W) cubes.
        data_patches = torch.from_numpy(data_patches.transpose((
            0,
            3,
            1,
            2,
        )))
        data_cubic_patches = torch.from_numpy(
            data_cubic_patches.transpose((0, 4, 1, 2, 3)))

        DDataset = DenoisingDataset(data_patches, data_cubic_patches, SIGMA)

        print('yes')
        DLoader = DataLoader(dataset=DDataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True)  # the loader was problematic here

        epoch_loss = 0  # NOTE(review): unused in this variant.
        start_time = time.time()

        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for step, x_y in enumerate(DLoader):
            #print('batch_idx=', batch_idx)
            batch_x_noise, batch_y_noise, batch_x = x_y[0], x_y[1], x_y[2]

            batch_x_noise = batch_x_noise.to(device)
            batch_y_noise = batch_y_noise.to(device)
            batch_x = batch_x.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            # The network predicts the residual; the target is clean - noisy.
            residual = net(batch_x_noise, batch_y_noise)
            alpha = 0.8  # only used by the commented-out blended loss below
            loss = recon_criterion(residual, batch_x - batch_x_noise)
            #loss = alpha*recon_criterion(residual, label-noisy) + (1-alpha)*loss_function_mse(residual, label-noisy)
            #loss = recon_criterion(residual, label-noisy)
            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            if step % 10 == 0:
                print('%4d %4d / %4d loss = %2.8f' %
                      (epoch + 1, step, data_patches.size(0) // BATCH_SIZE,
                       loss.item() / BATCH_SIZE))

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        # Per-epoch checkpoint (model weights + optimizer state).
        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            },
            f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_{epoch}.pth"
        )

        # Evaluation.
        net.eval()
        """
        channel_s = 191  # 设置多少波段
        data_patches, data_cubic_patches = datagenerator(test, channel_s)

        data_patches = torch.from_numpy(data_patches.transpose((0, 3, 1, 2,)))
        data_cubic_patches = torch.from_numpy(data_cubic_patches.transpose((0, 4, 1, 2, 3)))

        DDataset = DenoisingDataset(data_patches, data_cubic_patches, SIGMA)
        DLoader = DataLoader(dataset=DDataset, batch_size=BATCH_SIZE, shuffle=True)
        epoch_loss = 0
        
        for step, x_y in enumerate(DLoader):
            batch_x_noise, batch_y_noise, batch_x = x_y[0], x_y[1], x_y[2]

            batch_x_noise = batch_x_noise.to(DEVICE)
            batch_y_noise = batch_y_noise.to(DEVICE)
            batch_x = batch_x.to(DEVICE)
            residual = net(batch_x_noise, batch_y_noise)

            loss = loss_fuction(residual, batch_x-batch_x_noise)

            epoch_loss += loss.item()

            if step % 10 == 0:
                print('%4d %4d / %4d test loss = %2.4f' % (
                    epoch + 1, step, data_patches.size(0) // BATCH_SIZE, loss.item() / BATCH_SIZE))
        """
        # Load the test data.
        test_data_dir = './data/denoise/test/'
        test_set = HsiTrainDataset(test_data_dir)

        test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

        # Output directory for denoised results.
        test_result_output_path = './data/denoise/testresult/'
        if not os.path.exists(test_result_output_path):
            os.makedirs(test_result_output_path)

        # Band-by-band denoising.
        """
        分配一个numpy数组,存储去噪后的结果
        遍历所有通道,
        对于每个通道,通过get_adjacent_spectral_bands获取其相邻的K个通道
        调用hsid进行预测
        将预测到的residual和输入的noise加起来,得到输出band

        将去噪后的结果保存成mat结构
        """
        psnr_list = []  # NOTE(review): shared across all test images, so mpsnr below is a running mean.
        for batch_idx, (noisy, label) in enumerate(test_dataloader):
            noisy = noisy.type(torch.FloatTensor)
            label = label.type(torch.FloatTensor)

            batch_size, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))

            noisy = noisy.to(DEVICE)
            label = label.to(DEVICE)

            with torch.no_grad():
                for i in range(band_num):  # process each band in turn
                    current_noisy_band = noisy[:, :, :, i]
                    current_noisy_band = current_noisy_band[:, None]

                    adj_spectral_bands = get_adjacent_spectral_bands(
                        noisy, K, i)
                    #adj_spectral_bands = torch.transpose(adj_spectral_bands,3,1) #swap the channel dim into position 1
                    adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                    adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(
                        1)
                    #print(current_noisy_band.shape, adj_spectral_bands.shape)
                    residual = net(current_noisy_band,
                                   adj_spectral_bands_unsqueezed)
                    denoised_band = residual + current_noisy_band
                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_band_numpy = np.squeeze(denoised_band_numpy)

                    denoised_hsi[:, :, i] += denoised_band_numpy

                    test_label_current_band = label[:, :, :, i]

                    label_band_numpy = test_label_current_band.cpu().numpy(
                    ).astype(np.float32)
                    label_band_numpy = np.squeeze(label_band_numpy)

                    #print(denoised_band_numpy.shape, label_band_numpy.shape, label.shape)
                    psnr = PSNR(denoised_band_numpy, label_band_numpy)
                    psnr_list.append(psnr)

            mpsnr = np.mean(psnr_list)
            mpsnr_list.append(mpsnr)

            denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
            test_label_hsi_trans = np.squeeze(label.cpu().numpy().astype(
                np.float32)).transpose(2, 0, 1)
            mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
            sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

            # Report PSNR / SSIM / SAM for this test image.
            print(
                "=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                format(mpsnr, mssim, sam))

        # Keep the best checkpoint by mean PSNR.
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth"
            )

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     INIT_LEARNING_RATE))
        print(
            "------------------------------------------------------------------"
        )
Пример #21
0
def fit_model(model, data_loaders, channels, criterion, optimizer, scheduler, device, n_epochs, val_freq, checkpoint_dir, model_name):
    """
    Training of the denoiser model.
    :param model: torch Module
        Neural network to fit.
    :param data_loaders: dict
        Dictionary with torch DataLoaders with training and validation datasets.
    :param channels: int
        Number of image channels
    :param criterion: torch Module
        Loss function.
    :param optimizer: torch Optimizer
        Gradient descent optimization algorithm.
    :param scheduler: torch lr_scheduler
        Learning rate scheduler.
    :param device: torch device
        Device used during training (CPU/GPU).
    :param n_epochs: int
        Number of epochs to fit the model.
    :param val_freq: int
        How many training epochs to run between validations.
    :param checkpoint_dir: str
        Path to the directory where the model checkpoints and CSV log files will be stored.
    :param model_name: str
        Prefix name of the trained model saved in checkpoint_dir.
    :return: None
    """
    # Sum-reduced metrics: the per-batch sums are divided by the running
    # sample count (presumably inside EpochLogger.get_log — TODO confirm).
    psnr = PSNR(data_range=1., reduction='sum')
    ssim = SSIM(channels, data_range=1., reduction='sum')
    os.makedirs(checkpoint_dir, exist_ok=True)
    logfile_path = os.path.join(checkpoint_dir,  ''.join([model_name, '_logfile.csv']))
    # Checkpoint name pattern: <name>-<epoch>-<loss>-<psnr>-<ssim>.pth
    model_path = os.path.join(checkpoint_dir, ''.join([model_name, '-{:03d}-{:.4e}-{:.4f}-{:.4f}.pth']))
    file_logger = FileLogger(logfile_path)
    best_model_path, best_psnr = '', -np.inf
    since = time.time()

    for epoch in range(1, n_epochs + 1):
        lr = optimizer.param_groups[0]['lr']
        epoch_logger = EpochLogger()
        epoch_log = dict()

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                print('\nEpoch: {}/{} - Learning rate: {:.4e}'.format(epoch, n_epochs, lr))
                description = 'Training - Loss:{:.5e} - PSNR:{:.5f} - SSIM:{:.5f}'
            elif phase == 'val' and epoch % val_freq == 0:
                model.eval()
                description = 'Validation - Loss:{:.5e} - PSNR:{:.5f} - SSIM:{:.5f}'
            else:
                # Not a validation epoch: skip the 'val' phase entirely.
                break

            iterator = tqdm(enumerate(data_loaders[phase], 1), total=len(data_loaders[phase]), ncols=110)
            iterator.set_description(description.format(0, 0, 0))
            n_samples = 0

            for step, (inputs, targets) in iterator:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # Weight the loss by batch size so the running mean is per-sample.
                n_samples += inputs.size()[0]
                metrics = {
                    'loss': loss.item() * inputs.size()[0],
                    'psnr': psnr(outputs, targets).item(),
                    'ssim': ssim(outputs, targets).item()
                }
                epoch_logger.update_log(metrics, phase)
                log = epoch_logger.get_log(n_samples, phase)
                iterator.set_description(description.format(log[phase + ' loss'], log[phase + ' psnr'], log[phase + ' ssim']))

            if phase == 'val':
                # Apply Reduce LR On Plateau if it is the case and save the model if the validation PSNR is improved.
                if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(log['val psnr'])
                if log['val psnr'] > best_psnr:
                    best_psnr = log['val psnr']
                    best_model_path = model_path.format(epoch, log['val loss'], log['val psnr'], log['val ssim'])
                    torch.save(model.state_dict(), best_model_path)

            # Apply another scheduler at epoch level.
            # Bug fix: exclude ReduceLROnPlateau here — its step() requires a
            # metrics argument and would raise after every training phase.
            elif scheduler is not None and not isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step()

            epoch_log = {**epoch_log, **log}

        # Save the current epoch metrics in a CVS file.
        epoch_data = {'epoch': epoch, 'learning rate': lr, **epoch_log}
        file_logger(epoch_data)

    # Save the last model and report training time.
    # Bug fix: when the final epoch is not a validation epoch
    # (n_epochs % val_freq != 0), epoch_log has no 'val ...' entries and the
    # original code raised KeyError; fall back to the training metrics.
    metrics_phase = 'val' if 'val loss' in epoch_log else 'train'
    last_model_path = model_path.format(epoch,
                                        epoch_log[metrics_phase + ' loss'],
                                        epoch_log[metrics_phase + ' psnr'],
                                        epoch_log[metrics_phase + ' ssim'])
    torch.save(model.state_dict(), last_model_path)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best PSNR: {:4f}'.format(best_psnr))
Пример #22
0
def evaluate(dataloader, D, G, sample_size, device, now, region):
    """Test-time evaluation of the generator on a held-out region.

    For every sample the first three channels are the ground truth and the
    remaining channels are the generator input.  Both the prediction and the
    ground truth are rescaled from the tanh range [-1, 1] into [0, 1] (G has
    no sigmoid output layer), PSNR and SSIM are computed and appended to the
    global `eval_data` lists, and both images are saved as JPEGs under the
    region/run-specific output folders.

    Args:
        dataloader (obj): loader over the test dataset
        D (obj): discriminator model (unused here, kept for interface parity)
        G (obj): generator model, acts as a function
        sample_size (int): width and height of the samples
        device (obj): device used to process data (to exploit GPUs)
        now (str): name of the run folder where images are saved
        region (str): name of the evaluated region, used in paths and logs
    """
    global eval_data

    psnr = PSNR()
    ssim = SSIM()
    # Normalize is stateless, so a single instance maps both ground truth and
    # prediction from [-1, 1] to [0, 1]; likewise one ToPILImage is enough.
    denorm = transforms.Normalize(mean=[-1, -1, -1], std=[2, 2, 2])
    to_pil = transforms.ToPILImage()

    start = time.time()
    # No gradients are needed at test time.
    with torch.no_grad():
        print('^^^^^^^^^^^^^^^^')
        print('>>>TEST PHASE<<<')
        print('Evaluation loop on region' + region)
        for idx, (img, _) in enumerate(dataloader):
            # Drop the batch dimension added by the loader.
            img = img.squeeze(0).to(device)
            if idx % 10 == 0 or idx == 1:
                print('Validation on sample ', idx)

            # Ground truth: first three channels, rescaled to [0, 1].
            target = denorm(img[:3, :, :].to(device)).unsqueeze(0).to(device)

            # Generator input: the remaining channels.
            source = img[3:, :, :].to(device).unsqueeze(0).to(device)

            # Prediction, rescaled to [0, 1] for comparison.
            prediction = G(source).cuda().squeeze(0).to(device)
            prediction = denorm(prediction).unsqueeze(0).to(device)

            psnr_value = psnr(prediction, target).item()
            ssim_value = ssim(prediction, target).item()

            # Periodic progress report.
            if idx % 10 == 0:
                print('Iter: [%d/%d]\tPSNR: %.4f\tSSIM: %.4f\t' %
                      (idx, len(dataloader), psnr_value, ssim_value))

            # Accumulate per-sample metrics in the global store.
            eval_data['psnr_list'].append(psnr_value)
            eval_data['ssim_list'].append(ssim_value)

            # Save fake/real image pairs for visual inspection.
            target = target.squeeze(0)
            prediction = prediction.squeeze(0)

            fake_img = to_pil(prediction.detach().cpu())
            fake_img.save(
                os.path.join(save_path, region, now, switch[1], 'fake',
                             'save' + str(idx) + '.jpg'))
            real_img = to_pil(target.detach().cpu())
            real_img.save(
                os.path.join(save_path, region, now, switch[1], 'real',
                             'save' + str(idx) + '.jpg'))

    end = time.time()
    print('Elapsed time in minutes:', (end - start) / 60)
Пример #23
0
    # Estrae N patch casuali
    blur, orig = draw_patches(n_clips)

    pred = model.predict(x=blur)

    batch, h, w, c = blur.shape

    mse_pred_orig = tf.keras.losses.MSE(orig.reshape(batch, h * w * c),
                                        pred.reshape(batch,
                                                     h * w * c)).numpy()
    mse_blur_orig = tf.keras.losses.MSE(orig.reshape(batch, h * w * c),
                                        blur.reshape(batch,
                                                     h * w * c)).numpy()

    psnr_pred_orig = PSNR(tf.Variable(orig), tf.Variable(pred)).numpy()
    psnr_blur_orig = PSNR(tf.Variable(orig), tf.Variable(blur)).numpy()

    ssim_pred_orig = SSIM(tf.Variable(orig), tf.Variable(pred)).numpy()
    ssim_blur_orig = SSIM(tf.Variable(orig), tf.Variable(blur)).numpy()

    # MI SERVE IL BEST MODEL PER ESTRARRE IMMAGINI

    mask = psnr_blur_orig < 30

    mse_pred_orig[mask].mean()
    psnr_pred_orig[mask].mean()
    ssim_pred_orig[mask].mean()

    idx = (ssim_pred_orig - ssim_blur_orig)
def predict_lowlight_hsid_origin():
    """Denoise the level-25 low-light test set band by band with a trained
    HSIRDNECA_Denoise checkpoint, printing running PSNR plus SSIM/SAM per
    image and saving the (last) denoised cube to result.mat."""

    # Load the model.
    #hsid = HSID(36)
    hsid = HSIRDNECA_Denoise(K)
    hsid = nn.DataParallel(hsid).to(DEVICE)
    #device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'

    #hsid = hsid.to(DEVICE)
    hsid.load_state_dict(
        torch.load(save_model_path +
                   '/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth',
                   map_location='cuda:0')['gen'])

    # Load the data.
    test_data_dir = './data/denoise/test/level25'
    test_set = HsiTrainDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Output directory for results.
    test_result_output_path = './data/denoise/testresult/'
    if not os.path.exists(test_result_output_path):
        os.makedirs(test_result_output_path)

    # Band-by-band denoising.
    """
    分配一个numpy数组,存储去噪后的结果
    遍历所有通道,
    对于每个通道,通过get_adjacent_spectral_bands获取其相邻的K个通道
    调用hsid进行预测
    将预测到的residual和输入的noise加起来,得到输出band

    将去噪后的结果保存成mat结构
    """
    hsid.eval()
    psnr_list = []  # NOTE(review): accumulates across all test images, so mpsnr is a running mean.
    for batch_idx, (noisy, label) in enumerate(test_dataloader):
        noisy = noisy.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)

        batch_size, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)
        label = label.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):  # process each band in turn
                current_noisy_band = noisy[:, :, :, i]
                current_noisy_band = current_noisy_band[:, None]

                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                #adj_spectral_bands = torch.transpose(adj_spectral_bands,3,1) #swap the channel dim into position 1
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(1)
                #print(current_noisy_band.shape, adj_spectral_bands.shape)
                residual = hsid(current_noisy_band,
                                adj_spectral_bands_unsqueezed)
                denoised_band = residual + current_noisy_band
                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, i] += denoised_band_numpy

                test_label_current_band = label[:, :, :, i]

                label_band_numpy = test_label_current_band.cpu().numpy(
                ).astype(np.float32)
                label_band_numpy = np.squeeze(label_band_numpy)

                #print(denoised_band_numpy.shape, label_band_numpy.shape, label.shape)
                psnr = PSNR(denoised_band_numpy, label_band_numpy)
                psnr_list.append(psnr)

        mpsnr = np.mean(psnr_list)

        denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
        test_label_hsi_trans = np.squeeze(label.cpu().numpy().astype(
            np.float32)).transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

        # Report PSNR / SSIM / SAM for this image.
        print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(mpsnr, mssim, sam))

    # mdict values must be numpy arrays.
    # NOTE(review): only the cube of the *last* test image is saved here.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})
def predict_lowlight_hsid_origin():
    """Denoise the low-light test set band by band with a trained HSID_origin
    model and report overall PSNR / SSIM / SAM on the full cube.

    NOTE(review): this redefines predict_lowlight_hsid_origin (also defined
    above) and shadows it at import time — consider renaming one of the two.
    """

    # Load the model.
    #hsid = HSID(36)
    hsid = HSID_origin(24)
    #hsid = nn.DataParallel(hsid).to(DEVICE)
    #device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    hsid = hsid.to(DEVICE)
    hsid.load_state_dict(
        torch.load('./checkpoints/hsid_origin_best.pth',
                   map_location='cuda:0')['gen'])

    # Load the clean test label cube.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load the test data (one loader batch per spectral band).
    batch_size = 1
    test_data_dir = './data/test_lowlight/cuk12/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Output directory for results.
    test_result_output_path = './data/testresult/'
    if not os.path.exists(test_result_output_path):
        os.makedirs(test_result_output_path)

    # Band-by-band denoising.
    """
    分配一个numpy数组,存储去噪后的结果
    遍历所有通道,
    对于每个通道,通过get_adjacent_spectral_bands获取其相邻的K个通道
    调用hsid进行预测
    将预测到的residual和输入的noise加起来,得到输出band

    将去噪后的结果保存成mat结构
    """
    hsid.eval()
    for batch_idx, (noisy_test, cubic_test,
                    label_test) in enumerate(test_dataloader):
        noisy_test = noisy_test.type(torch.FloatTensor)
        label_test = label_test.type(torch.FloatTensor)
        cubic_test = cubic_test.type(torch.FloatTensor)

        noisy_test = noisy_test.to(DEVICE)
        label_test = label_test.to(DEVICE)
        cubic_test = cubic_test.to(DEVICE)

        with torch.no_grad():

            # Predicted residual plus the noisy input gives the denoised band.
            residual = hsid(noisy_test, cubic_test)
            denoised_band = noisy_test + residual

            denoised_band_numpy = denoised_band.cpu().numpy().astype(
                np.float32)
            denoised_band_numpy = np.squeeze(denoised_band_numpy)

            denoised_hsi[:, :, batch_idx] = denoised_band_numpy

    # Full-cube metrics against the clean label.
    psnr = PSNR(denoised_hsi, test_label_hsi)
    ssim = SSIM(denoised_hsi, test_label_hsi)
    sam = SAM(denoised_hsi, test_label_hsi)

    # mdict values must be numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    # Report PSNR / SSIM / SAM.
    print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
Пример #26
0
def train_model_residual_lowlight_rdn():
    """Train HSIRDN on low-light cubic patches with a cosine LR schedule.

    The network learns the residual between the noisy band and the clean
    label (L1 loss).  After every epoch the test cube is denoised band by
    band, mean PSNR / SSIM / SAM are logged to TensorBoard, and checkpoints
    are written per epoch, for the best mean PSNR, and as model_latest.pth
    (used when resuming via RESUME).

    Bug fix: the per-epoch summary now prints the epoch mean PSNR (mpsnr)
    instead of the PSNR of the last processed band, matching the best-PSNR
    bookkeeping and the sibling training routine above.
    """

    device = DEVICE
    # Prepare the training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    #print('trainset32 training example:', len(train_set32))
    #train_set = HsiCubicTrainDataset('./data/train_lowlight/')

    #train_set_64 = HsiCubicTrainDataset('./data/train_lowlight_patchsize64/')

    #train_set_list = [train_set32, train_set_64]
    #train_set = ConcatDataset(train_set_list)  # samples must share one size or concatenation fails
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the clean test label cube.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load the test data (one loader batch per spectral band).
    batch_size = 1
    #test_data_dir = './data/test_lowlight/cuk12/'
    test_data_dir = './data/test_lowlight/cubic/'

    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    save_model_path = './checkpoints/hsirnd_cosine'
    if not os.path.exists(save_model_path):
        os.mkdir(save_model_path)

    # Build the model; K is presumably the number of adjacent spectral bands.
    net = HSIRDN(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)
    #net = net.to(device)

    # Optimizer with cosine-annealed LR over the 600-epoch run.
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0,999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    #scheduler = MultiStepLR(hsid_optimizer, milestones=[200,400], gamma=0.5)
    scheduler = CosineAnnealingLR(hsid_optimizer, T_max=600)

    # Loss function (L1 via recon_criterion).
    #criterion = nn.MSELoss()

    is_resume = RESUME
    # Resume training from the latest checkpoint if requested.
    if is_resume:
        path_chk_rest = dir_utils.get_last_path(save_model_path,
                                                'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)

        # Fast-forward the scheduler to the resumed epoch.
        for i in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    if not is_resume:
        start_epoch = 1
    num_epoch = 600

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        scheduler.step()
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())

        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            # The network predicts the residual; the target is label - noisy.
            residual = net(noisy, cubic)
            alpha = 0.8  # only used by the commented-out blended loss below
            loss = recon_criterion(residual, label - noisy)
            #loss = alpha*recon_criterion(residual, label-noisy) + (1-alpha)*loss_function_mse(residual, label-noisy)
            #loss = recon_criterion(residual, label-noisy)
            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # One step == one processed batch.
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        # Per-epoch checkpoint (model weights + optimizer state).
        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            },
            f"{save_model_path}/hsid_rdn_4rdb_conise_l1_loss_600epoch_patchsize32_{epoch}.pth"
        )

        # Evaluation.
        net.eval()
        psnr_list = []
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                # Log example images for band 49.
                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

            test_label_current_band = test_label_hsi[:, :, batch_idx]
            psnr = PSNR(denoised_band_numpy, test_label_current_band)
            psnr_list.append(psnr)

        mpsnr = np.mean(psnr_list)

        denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
        test_label_hsi_trans = test_label_hsi.transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

        # Report and log PSNR / SSIM / SAM for this epoch.
        print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.4f}".
              format(mpsnr, mssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': mpsnr,
            'average SSIM': mssim,
            'avarage SAM': sam
        }, epoch)  # makes it easy to see which epoch performed best

        # Keep the best checkpoint by mean PSNR.
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_4rdb_conise_l1_loss_600epoch_patchsize32_best.pth"
            )

        # Bug fix: print mpsnr (epoch mean), not the last band's psnr.
        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     INIT_LEARNING_RATE))
        print(
            "------------------------------------------------------------------"
        )

        # Save the running checkpoint used for resuming.
        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join(save_model_path, "model_latest.pth"))
    tb_writer.close()
def train_model_residual_lowlight_twostage_gan_best():
    """Train the two-stage HSID dense generator with a GAN objective on
    low-light hyperspectral cubes.

    Each epoch: trains generator + discriminator on cubic patches, then
    denoises a held-out scene band-by-band, logs PSNR/SSIM/SAM to
    TensorBoard, and checkpoints both the latest and the best-PSNR model.

    Relies on module-level names defined elsewhere in this file: DEVICE,
    HsiCubicTrainDataset, HsiCubicLowlightTestDataset, HSIDDenseNetTwoStage,
    DiscriminatorABC, init_params, adv_criterion, loss_fuction,
    recon_criterion, PSNR/SSIM/SAM metrics, get_summary_writer,
    dir_utils/model_utils, scio, DataLoader, optim, MultiStepLR.
    """

    # hyper-parameters
    batchsize = 128
    init_lr = 0.001
    K_adjacent_band = 36  # number of adjacent bands in the cubic input
    display_step = 20  # log every N steps
    display_band = 20
    is_resume = False
    lambda_recon = 10  # weight of reconstruction loss vs. adversarial loss

    start_epoch = 1

    device = DEVICE
    # prepare training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=batchsize,
                              shuffle=True)

    # load the test ground-truth (label) hyperspectral image
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load the test data; one spectral band per batch (test_batch_size == 1)
    test_batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=test_batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    # one band per test batch, so loader length equals the band count
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # build the generator
    net = HSIDDenseNetTwoStage(K_adjacent_band)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    # build the discriminator
    disc = DiscriminatorABC(2, 4)
    init_params(disc)
    disc = disc.to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=init_lr)

    num_epoch = 100
    print('epoch count == ', num_epoch)

    # build the generator optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0,999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=init_lr)

    #Scheduler
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)
    warmup_epochs = 3
    #scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(hsid_optimizer, num_epoch-warmup_epochs+40, eta_min=1e-7)
    #scheduler = GradualWarmupScheduler(hsid_optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    #scheduler.step()

    # resume training from the latest checkpoint
    if is_resume:
        model_dir = './checkpoints'
        path_chk_rest = dir_utils.get_last_path(model_dir, 'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)
        model_utils.load_disc_checkpoint(disc, path_chk_rest)
        model_utils.load_disc_optim(disc_opt, path_chk_rest)

        # fast-forward the LR scheduler to the resumed epoch
        for i in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    # loss function (adv_criterion / loss_fuction / recon_criterion are
    # module-level; the MSE below was an earlier alternative)
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # NOTE(review): scheduler.step() at the top of the epoch matches the
        # original recipe, though modern PyTorch recommends stepping after
        # the epoch's optimizer updates.
        scheduler.step()
        #print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            ### Update discriminator ###
            disc_opt.zero_grad(
            )  # Zero out the gradient before backpropagation
            # generator outputs residuals; no grad needed for the disc pass
            with torch.no_grad():
                fake, fake_stage2 = net(noisy, cubic)
            #print('noisy shape =', noisy.shape, fake_stage2.shape)
            #fake.detach()
            # disc sees (restored image, noisy) pairs; restored = noisy + residual
            disc_fake_hat = disc(fake_stage2.detach() + noisy,
                                 noisy)  # Detach generator
            disc_fake_loss = adv_criterion(disc_fake_hat,
                                           torch.zeros_like(disc_fake_hat))
            disc_real_hat = disc(label, noisy)
            disc_real_loss = adv_criterion(disc_real_hat,
                                           torch.ones_like(disc_real_hat))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward(retain_graph=True)  # Update gradients
            disc_opt.step()  # Update optimizer

            ### Update generator ###
            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual, residual_stage2 = net(noisy, cubic)
            disc_fake_hat = disc(residual_stage2 + noisy, noisy)
            gen_adv_loss = adv_criterion(disc_fake_hat,
                                         torch.ones_like(disc_fake_hat))

            # blend two reconstruction criteria (alpha) across both stages (beta);
            # targets are the residual label - noisy
            alpha = 0.2
            beta = 0.2
            rec_loss = beta * (alpha*loss_fuction(residual, label-noisy) + (1-alpha) * recon_criterion(residual, label-noisy)) \
             + (1-beta) * (alpha*loss_fuction(residual_stage2, label-noisy) + (1-alpha) * recon_criterion(residual_stage2, label-noisy))

            loss = gen_adv_loss + lambda_recon * rec_loss

            loss.backward()  # calcu gradient
            hsid_optimizer.step()  # update parameter

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                    print(
                        f"rec_loss =  {rec_loss.item()}, gen_adv_loss = {gen_adv_loss.item()}"
                    )

                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step++: each processed batch counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict()
            }, f"checkpoints/two_stage_hsid_dense_gan_{epoch}.pth")

        # evaluation: denoise the test scene band-by-band
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                # log example images for one fixed band (index 49)
                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, axis=0)
                    residual_stage2_squeezed = torch.squeeze(residual_stage2,
                                                             axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual_stage2",
                                        residual_stage2_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # report PSNR / SSIM / SAM for this epoch
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)  # lets TensorBoard show which epoch performs best

        # save the best model (by PSNR)
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                    'disc': disc.state_dict(),
                    'disc_opt': disc_opt.state_dict()
                }, f"checkpoints/two_stage_hsid_dense_gan_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        # save the latest model (for resuming)
        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))

    tb_writer.close()
# Example #28
# 0
# Restore the latest checkpoint and evaluate one interpolated frame.
# NOTE(review): Python 2 / TF 1.x style code; `model_path`, `inputs`,
# `prediction`, `input_test` and `frame1`..`frame3` are presumably defined
# earlier in the file — confirm before reuse.
saver = tf.train.Saver()
with tf.Session() as sess:

    # restore weights from the newest checkpoint under model_path
    model = tf.train.get_checkpoint_state(model_path)
    saver.restore(sess, model.model_checkpoint_path)

    graph = tf.get_default_graph()

    # run the network; output assumed in [0, 1], scaled to uint8 pixels
    prediction = sess.run([prediction], feed_dict={inputs: input_test})
    prediction = np.uint8(np.squeeze(prediction) * 255)

    # naive baseline: average of the two neighbouring frames
    avg_frame = np.uint8((frame1 / 2. + frame3 / 2.))

    # half sum-of-squared-errors in [0, 1] space (nested sum reduces 3 axes)
    loss = sum(sum(sum((prediction / 255. - frame2 / 255.)**2))) / 2

    psnr = PSNR(prediction, frame2)
    # avg_psnr = PSNR(avg_frame, frame2)
    # SSIM / MS-SSIM computed on the first channel only
    pre_gray = prediction[:, :, 0]
    frame2_gray = frame2[:, :, 0]
    # avg_frame_gray = avg_frame[:,:,0]
    ssim = SSIM(pre_gray, frame2_gray).mean()
    ms_ssim = MSSSIM(pre_gray, frame2_gray)
    # avg_SSIM = SSIM(avg_frame_gray, frame2_gray).mean()
    # avg_MS_SSIM = MSSSIM(avg_frame_gray, frame2_gray)

    # print "Loss = %.2f" % loss
    # print "avg_PSNR = %.2f, avg_SSIM = %.4f, avg_MS-SSIM = %.4f" % ( avg_psnr, avg_SSIM, avg_MS_SSIM)
    print "Loss = %.2f, PSNR = %.2f, SSIM = %.4f, MS-SSIM = %.4f" % (
        loss, psnr, ssim, ms_ssim)

    # print "loss = %f, PSNR = %f" % (loss, psnr)
# Example #29
# 0
def train_model_multistage_lowlight():
    """Train the multi-stage HSID model on low-light hyperspectral cubes.

    Each epoch trains on cubic patches, then denoises a held-out scene
    band-by-band, logs PSNR/SSIM/SAM to TensorBoard, and checkpoints both
    the latest and the best-PSNR model under ./checkpoints.

    Relies on module-level names defined elsewhere in this file: DEVICE,
    BATCH_SIZE, K, INIT_LEARNING_RATE, display_step, HsiCubicTrainDataset,
    HsiCubicLowlightTestDataset, MultiStageHSID, init_params, loss_fuction,
    PSNR/SSIM/SAM metrics, get_summary_writer, scio, DataLoader, optim,
    MultiStepLR.
    """

    device = DEVICE
    # prepare training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    #print('trainset32 training example:', len(train_set32))

    #train_set_64 = HsiCubicTrainDataset('./data/train_lowlight_patchsize64/')

    #train_set_list = [train_set32, train_set_64]
    #train_set = ConcatDataset(train_set_list)  # samples must share one size or concat fails
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # load the test ground-truth (label) hyperspectral image
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load the test data; one spectral band per batch
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    # one band per test batch, so loader length equals the band count
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # build the model
    net = MultiStageHSID(K)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    # optimizer and LR schedule
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0,999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # loss function (loss_fuction is module-level; MSE was an alternative)
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 100

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # NOTE(review): scheduler.step() at the top of the epoch matches the
        # original recipe, though modern PyTorch recommends stepping after
        # the epoch's optimizer updates.
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print(scheduler.get_lr())
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual = net(noisy, cubic)  # one output per stage
            #loss = loss_fuction(residual, label-noisy)
            # FIX: accumulate the per-stage losses with the builtin sum() so
            # the total stays a torch tensor attached to the autograd graph;
            # the original np.sum([...]) relied on object-dtype reduction of
            # tensors, which is fragile and unsupported.
            loss = sum(
                loss_fuction(stage_out, label) for stage_out in residual)
            loss.backward()  # calcu gradient
            hsid_optimizer.step()  # update parameter

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step++: each processed batch counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/hsid_multistage_patchsize64_{epoch}.pth")

        # evaluation: denoise the test scene band-by-band
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual = net(noisy_test, cubic_test)
                # stage-0 output is used as the final residual estimate
                denoised_band = noisy_test + residual[0]

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                # log example images for one fixed band (index 49)
                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual[0], axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # report PSNR / SSIM / SAM for this epoch
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)  # lets TensorBoard show which epoch performs best

        # save the best model (by PSNR)
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/hsid_multistage_patchsize64_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        # save the latest model (for resuming)
        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))
    tb_writer.close()
def quantize_from_codebook(vectors, codebook):
    """Replace each row of *vectors* with its nearest codebook row.

    Parameters
    ----------
    vectors : (n, d) array_like
        Rows to quantize.
    codebook : (k, d) ndarray
        Codebook rows; nearest is chosen by scipy's ``vq`` (Euclidean).

    Returns
    -------
    (n, d) ndarray with the dtype of *vectors* (matching the original
    ``zeros_like``-based behavior).
    """
    vec = np.asarray(vectors)
    codes, _ = vq(vec, codebook)
    # Fancy indexing gathers all selected rows in one vectorized step,
    # replacing the original per-row Python loop.
    return np.asarray(codebook)[codes].astype(vec.dtype)


def codes_from_vectors(vectors, codebook):
    """Return, for each row of *vectors*, the index of its nearest codebook row."""
    assignments, _distortions = vq(vectors, codebook)
    return assignments


def vectors_from_codes(codes, codebook):
    """Decode integer *codes* back into their codebook rows.

    Parameters
    ----------
    codes : sequence of int
        Indices into *codebook*.
    codebook : (k, d) array_like
        Codebook rows.

    Returns
    -------
    (n, d) ndarray. Vectorized fancy indexing replaces the Python-level
    list comprehension; this also makes the empty-input case return the
    consistent shape (0, d) instead of (0,).
    """
    return np.asarray(codebook)[np.asarray(codes, dtype=int)]


if __name__ == "__main__":
    from PIL import Image
    from metrics import PSNR
    from image import load_image, save_image
    from codebooks import random_codebook

    img = load_image("balloon.bmp")
    quantized_img = quantize(img, window_size=4, codebook_fun=random_codebook, codebook_size=32)
    print("PSNR:", PSNR(img, quantized_img))
    Image.fromarray(quantized_img).show()