def predict_lowlight_hsid_origin():
    """Denoise the low-light test cube with the pretrained ``HSID_origin`` model.

    Loads the ``'gen'`` weights from ``./checkpoints/hsid_origin_best.pth``,
    runs the network over every sample of the ``./data/test_lowlight/cuk12/``
    test set (batch size 1, one output band per sample), stores the
    reassembled cube in ``./data/testresult/result.mat`` under the key
    ``'denoised'`` and prints PSNR / SSIM / SAM against the ``'label'`` array
    of the reference ``.mat`` file.
    """
    # Load the model.
    hsid = HSID_origin(24).to(DEVICE)
    # Map the checkpoint onto the device the model actually lives on instead
    # of a hard-coded 'cuda:0', so the function also works when DEVICE is the
    # CPU or another GPU.
    checkpoint = torch.load('./checkpoints/hsid_origin_best.pth',
                            map_location=DEVICE)
    hsid.load_state_dict(checkpoint['gen'])
    hsid.eval()

    # Clean reference cube used for the metrics.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Test data: batch size 1 and no shuffling, so the loader index can be
    # used as the band index of the output cube.
    test_data_dir = './data/test_lowlight/cuk12/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    # The first element of a sample is unpacked as a 4-D tensor whose last
    # two dimensions give the spatial size of the output cube.
    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Output directory for the result file.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising: the network predicts a residual that is added
    # back to the noisy input band.
    with torch.no_grad():
        for band_idx, (noisy_test, cubic_test, _) in enumerate(test_dataloader):
            noisy_test = noisy_test.float().to(DEVICE)
            cubic_test = cubic_test.float().to(DEVICE)

            residual = hsid(noisy_test, cubic_test)
            denoised_band = noisy_test + residual

            denoised_hsi[:, :, band_idx] = np.squeeze(
                denoised_band.cpu().numpy().astype(np.float32))

    psnr = PSNR(denoised_hsi, test_label_hsi)
    ssim = SSIM(denoised_hsi, test_label_hsi)
    sam = SAM(denoised_hsi, test_label_hsi)

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(os.path.join(test_result_output_path, 'result.mat'),
                 {'denoised': denoised_hsi})

    print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
# Example no. 2
# 0
def train_model_multistage_lowlight():
    """Train ``MultiStageHSID`` on low-light patches, validating every epoch.

    Each epoch:

    1. runs one pass over ``./data/train_lowlight_patchsize32/`` with Adam,
       summing ``loss_fuction`` over every output returned by the network;
    2. saves a per-epoch checkpoint under ``checkpoints/``;
    3. denoises the ``./data/test_lowlight/cubic/`` test set band by band and
       reports PSNR / SSIM / SAM against the reference cube;
    4. saves the checkpoint with the best PSNR so far plus ``model_latest.pth``.

    Scalars and a few sample images are written to the module-level
    ``tb_writer`` (created here with ``log_dir='logs'``).
    """
    device = DEVICE

    # Training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    print('total training example:', len(train_set))
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Clean reference cube used by the validation metrics.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Validation data: batch size 1 and no shuffling, so the loader index can
    # be used as the band index of the reassembled cube.
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Model.
    net = MultiStageHSID(K)
    init_params(net)
    net = net.to(device)

    # Optimizer and step-wise learning-rate decay.
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # torch.save does not create missing directories.
    os.makedirs('checkpoints', exist_ok=True)

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    cur_step = 0  # one step == one processed training batch
    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 100

    try:
        for epoch in range(start_epoch, num_epoch + 1):
            epoch_start_time = time.time()
            # NOTE(review): the scheduler is stepped at the top of the epoch,
            # as in the original loop. PyTorch documents calling it after
            # optimizer.step(); moving it would shift the milestone epochs by
            # one, so the order is left unchanged here.
            scheduler.step()
            # get_last_lr() is the supported way to read the current rate;
            # get_lr() outside of step() warns and can misreport the value.
            print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
            print(scheduler.get_last_lr())
            gen_epoch_loss = 0

            net.train()
            for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
                noisy = noisy.to(device)
                label = label.to(device)
                cubic = cubic.to(device)

                hsid_optimizer.zero_grad()

                stage_outputs = net(noisy, cubic)
                # Sum the per-output losses with the builtin sum so the result
                # stays an autograd tensor; np.sum would route the tensors
                # through NumPy's array conversion.
                # NOTE(review): the target here is `label`, while validation
                # below treats the first output as a residual added to the
                # noisy band -- confirm which one the model is meant to emit.
                loss = sum(
                    loss_fuction(stage_out, label)
                    for stage_out in stage_outputs)
                loss.backward()  # compute gradients
                hsid_optimizer.step()  # update parameters

                gen_epoch_loss += loss.item()

                if cur_step % display_step == 0:
                    if cur_step > 0:
                        print(
                            f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                        )
                    else:
                        print("Pretrained initial state")

                tb_writer.add_scalar("MSE loss", loss.item(), cur_step)
                cur_step += 1

            tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

            # Per-epoch checkpoint.
            torch.save(
                {
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/hsid_multistage_patchsize64_{epoch}.pth")

            # Validation: rebuild the denoised cube band by band.
            net.eval()
            with torch.no_grad():
                for batch_idx, (noisy_test, cubic_test,
                                label_test) in enumerate(test_dataloader):
                    noisy_test = noisy_test.float().to(DEVICE)
                    label_test = label_test.float().to(DEVICE)
                    cubic_test = cubic_test.float().to(DEVICE)

                    residual = net(noisy_test, cubic_test)
                    denoised_band = noisy_test + residual[0]

                    denoised_hsi[:, :, batch_idx] = np.squeeze(
                        denoised_band.cpu().numpy().astype(np.float32))

                    # Log the images of one fixed band (loader index 49).
                    if batch_idx == 49:
                        for tag, image in (('restored', denoised_band),
                                           ('residual', residual[0]),
                                           ('label', label_test),
                                           ('noisy', noisy_test)):
                            tb_writer.add_image(f"images/{epoch}_{tag}",
                                                torch.squeeze(image, 0),
                                                1,
                                                dataformats='CHW')

            psnr = PSNR(denoised_hsi, test_label_hsi)
            ssim = SSIM(denoised_hsi, test_label_hsi)
            sam = SAM(denoised_hsi, test_label_hsi)

            print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                  format(psnr, ssim, sam))
            # Logged per epoch so the best-performing epoch can be identified.
            tb_writer.add_scalars("validation metrics", {
                'average PSNR': psnr,
                'average SSIM': ssim,
                'avarage SAM': sam
            }, epoch)

            # Keep the checkpoint with the highest validation PSNR.
            if psnr > best_psnr:
                best_psnr = psnr
                best_epoch = epoch
                best_iter = cur_step
                torch.save(
                    {
                        'epoch': epoch,
                        'gen': net.state_dict(),
                        'gen_opt': hsid_optimizer.state_dict(),
                    }, f"checkpoints/hsid_multistage_patchsize64_best.pth")

            print(
                "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
                % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

            print(
                "------------------------------------------------------------------"
            )
            print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
                  format(epoch,
                         time.time() - epoch_start_time, gen_epoch_loss,
                         scheduler.get_last_lr()[0]))
            print(
                "------------------------------------------------------------------"
            )

            # Always keep the most recent state as well.
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict()
                }, os.path.join('./checkpoints', "model_latest.pth"))
    finally:
        # Flush and close the TensorBoard writer even if training aborts.
        tb_writer.close()
# Example no. 3
# 0
def train_model_residual_lowlight_rdn():
    """Train ``HSIRDN`` to predict the residual ``label - noisy``.

    The network is wrapped in ``nn.DataParallel`` and optimised with Adam
    under a ``CosineAnnealingLR`` schedule (``T_max=600``) for up to 600
    epochs. When the module-level ``RESUME`` flag is set, model, optimizer
    and start epoch are restored from ``model_latest.pth`` in
    ``./checkpoints/hsirnd_cosine``.

    After every epoch the ``./data/test_lowlight/cubic/`` test set is
    denoised band by band; the mean per-band PSNR, SSIM and SAM are printed
    and logged, and the checkpoint with the best mean PSNR is kept alongside
    the per-epoch and ``model_latest.pth`` checkpoints.
    """
    device = DEVICE

    # Training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    print('total training example:', len(train_set))
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Clean reference cube used by the validation metrics.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Validation data: batch size 1 and no shuffling, so the loader index can
    # be used as the band index of the reassembled cube.
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # makedirs also creates a missing parent './checkpoints' directory, which
    # a plain os.mkdir would fail on.
    save_model_path = './checkpoints/hsirnd_cosine'
    os.makedirs(save_model_path, exist_ok=True)

    # Model.
    net = HSIRDN(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # Optimizer and cosine learning-rate schedule.
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = CosineAnnealingLR(hsid_optimizer, T_max=600)

    # Optionally resume from the latest checkpoint.
    start_epoch = 1
    if RESUME:
        path_chk_rest = dir_utils.get_last_path(save_model_path,
                                                'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)

        # Replay the scheduler up to the resumed epoch.
        # NOTE(review): if load_optim also restores the decayed learning
        # rate, replaying the steps may decay it a second time -- confirm.
        for _ in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_last_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    cur_step = 0  # one step == one processed training batch
    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    num_epoch = 600

    try:
        for epoch in range(start_epoch, num_epoch + 1):
            epoch_start_time = time.time()
            # NOTE(review): stepped at the top of the epoch as in the original
            # loop (and consistent with the resume replay above); PyTorch
            # documents calling it after optimizer.step().
            scheduler.step()
            # get_last_lr() is the supported way to read the current rate;
            # get_lr() outside of step() warns and can misreport the value.
            print('epoch = ', epoch,
                  'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
            print(scheduler.get_last_lr())

            gen_epoch_loss = 0

            net.train()
            for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
                noisy = noisy.to(device)
                label = label.to(device)
                cubic = cubic.to(device)

                hsid_optimizer.zero_grad()

                # The network predicts the residual between clean and noisy.
                residual = net(noisy, cubic)
                loss = recon_criterion(residual, label - noisy)
                loss.backward()  # compute gradients
                hsid_optimizer.step()  # update parameters

                gen_epoch_loss += loss.item()

                if cur_step % display_step == 0:
                    if cur_step > 0:
                        print(
                            f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                        )
                    else:
                        print("Pretrained initial state")

                tb_writer.add_scalar("MSE loss", loss.item(), cur_step)
                cur_step += 1

            tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

            # Per-epoch checkpoint.
            torch.save(
                {
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_4rdb_conise_l1_loss_600epoch_patchsize32_{epoch}.pth"
            )

            # Validation: rebuild the denoised cube band by band and collect
            # the PSNR of every band.
            net.eval()
            psnr_list = []
            with torch.no_grad():
                for batch_idx, (noisy_test, cubic_test,
                                label_test) in enumerate(test_dataloader):
                    noisy_test = noisy_test.float().to(DEVICE)
                    label_test = label_test.float().to(DEVICE)
                    cubic_test = cubic_test.float().to(DEVICE)

                    residual = net(noisy_test, cubic_test)
                    denoised_band = noisy_test + residual

                    denoised_band_numpy = np.squeeze(
                        denoised_band.cpu().numpy().astype(np.float32))
                    denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                    # Log the images of one fixed band (loader index 49).
                    if batch_idx == 49:
                        for tag, image in (('restored', denoised_band),
                                           ('residual', residual),
                                           ('label', label_test),
                                           ('noisy', noisy_test)):
                            tb_writer.add_image(f"images/{epoch}_{tag}",
                                                torch.squeeze(image, 0),
                                                1,
                                                dataformats='CHW')

                    test_label_current_band = test_label_hsi[:, :, batch_idx]
                    psnr_list.append(
                        PSNR(denoised_band_numpy, test_label_current_band))

            mpsnr = np.mean(psnr_list)

            # SSIM and SAM are evaluated on band-first copies of both cubes.
            denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
            test_label_hsi_trans = test_label_hsi.transpose(2, 0, 1)
            mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
            sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

            print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.4f}".
                  format(mpsnr, mssim, sam))
            # Logged per epoch so the best-performing epoch can be identified.
            tb_writer.add_scalars("validation metrics", {
                'average PSNR': mpsnr,
                'average SSIM': mssim,
                'avarage SAM': sam
            }, epoch)

            # Keep the checkpoint with the highest mean validation PSNR.
            if mpsnr > best_psnr:
                best_psnr = mpsnr
                best_epoch = epoch
                best_iter = cur_step
                torch.save(
                    {
                        'epoch': epoch,
                        'gen': net.state_dict(),
                        'gen_opt': hsid_optimizer.state_dict(),
                    },
                    f"{save_model_path}/hsid_rdn_4rdb_conise_l1_loss_600epoch_patchsize32_best.pth"
                )

            # Report the mean PSNR (the value best_psnr is compared against),
            # not the PSNR of whichever band happened to be processed last.
            print(
                "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
                % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

            print(
                "------------------------------------------------------------------"
            )
            # Show the scheduler's current rate rather than the constant
            # initial learning rate.
            print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
                  format(epoch,
                         time.time() - epoch_start_time, gen_epoch_loss,
                         scheduler.get_last_lr()[0]))
            print(
                "------------------------------------------------------------------"
            )

            # Always keep the most recent state as well (used for resuming).
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict()
                }, os.path.join(save_model_path, "model_latest.pth"))
    finally:
        # Flush and close the TensorBoard writer even if training aborts.
        tb_writer.close()
def predict_lowlight_residual():
    """Denoise a low-light cube with the pretrained ``MultiStageHSIDUpscale``.

    Loads the ``'gen'`` weights from
    ``./checkpoints/hsid_multistage_upscale_patchsize64_best.pth``, then for
    every cube yielded by the ``./data/test_lowlight/origin/`` test set
    denoises each band from the band itself plus the adjacent bands returned
    by ``get_adjacent_spectral_bands``. The last denoised cube is written to
    ``./data/testresult/result.mat`` under the key ``'denoised'`` and
    PSNR / SSIM / SAM against the reference ``'label'`` array are printed.

    Raises:
        RuntimeError: if the test set yields no sample.
    """
    # Load the model.
    hsid = MultiStageHSIDUpscale(36).to(DEVICE)
    # Map the checkpoint onto the device the model actually lives on instead
    # of a hard-coded 'cuda:0'.
    checkpoint = torch.load(
        './checkpoints/hsid_multistage_upscale_patchsize64_best.pth',
        map_location=DEVICE)
    hsid.load_state_dict(checkpoint['gen'])
    # Switch to inference mode, as the other predict/validation paths in
    # this file do.
    hsid.eval()

    # Clean reference cube used for the metrics.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']

    test_data_dir = './data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Output directory for the result file.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    denoised_hsi = None
    with torch.no_grad():
        for noisy, _ in test_dataloader:
            noisy = noisy.float()
            # The cube is unpacked as (batch, width, height, band).
            _, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))
            noisy = noisy.to(DEVICE)

            for i in range(band_num):  # process every band in turn
                # Select band i and insert a channel axis at position 1.
                current_noisy_band = noisy[:, :, :, i][:, None]

                # NOTE(review): the model is built with 36 while the band
                # window uses the module-level K -- confirm they agree.
                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                # Move the band axis in front of the two spatial axes.
                adj_spectral_bands = adj_spectral_bands.permute(
                    0, 3, 1, 2).to(DEVICE)

                residual = hsid(current_noisy_band, adj_spectral_bands)
                # The first network output is added back to the noisy band.
                denoised_band = current_noisy_band + residual[0]

                denoised_hsi[:, :, i] = np.squeeze(
                    denoised_band.cpu().numpy().astype(np.float32))

    if denoised_hsi is None:
        raise RuntimeError(
            'no test sample found in ' + test_data_dir)

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(os.path.join(test_result_output_path, 'result.mat'),
                 {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(psnr, ssim, sam))
def train_model_residual_lowlight_twostage_gan_best():
    """Adversarially train the two-stage ``HSIDDenseNetTwoStage`` denoiser.

    The generator returns two residual estimates (stage 1 and stage 2) for a
    noisy band. Each training step first updates the ``DiscriminatorABC``
    discriminator on (stage-2 restored, noisy) versus (label, noisy) pairs,
    then updates the generator with the adversarial loss plus
    ``lambda_recon`` times a weighted mix of ``loss_fuction`` and
    ``recon_criterion`` on both stages.

    After every epoch the ``./data/test_lowlight/cubic/`` test set is
    denoised band by band with the stage-2 output; PSNR / SSIM / SAM are
    printed and logged, and per-epoch, best-PSNR and ``model_latest.pth``
    checkpoints (generator, discriminator and both optimizers) are written
    under ``checkpoints/``.
    """
    # Hyper-parameters.
    batchsize = 128
    init_lr = 0.001
    K_adjacent_band = 36
    display_step = 20
    is_resume = False
    lambda_recon = 10
    # Mixing weights of the reconstruction loss: alpha balances the two
    # criteria, beta balances stage 1 against stage 2.
    alpha = 0.2
    beta = 0.2

    start_epoch = 1
    num_epoch = 100

    device = DEVICE

    # Training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))
    train_loader = DataLoader(dataset=train_set,
                              batch_size=batchsize,
                              shuffle=True)

    # Clean reference cube used by the validation metrics.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Validation data: batch size 1 and no shuffling, so the loader index can
    # be used as the band index of the reassembled cube.
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Generator.
    net = HSIDDenseNetTwoStage(K_adjacent_band)
    init_params(net)
    net = net.to(device)

    # Discriminator and its optimizer.
    disc = DiscriminatorABC(2, 4)
    init_params(disc)
    disc = disc.to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=init_lr)

    print('epoch count == ', num_epoch)

    # Generator optimizer and step-wise learning-rate decay.
    hsid_optimizer = optim.Adam(net.parameters(), lr=init_lr)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # torch.save does not create missing directories.
    os.makedirs('checkpoints', exist_ok=True)

    # Optionally resume from the latest checkpoint.
    if is_resume:
        model_dir = './checkpoints'
        path_chk_rest = dir_utils.get_last_path(model_dir, 'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)
        model_utils.load_disc_checkpoint(disc, path_chk_rest)
        model_utils.load_disc_optim(disc_opt, path_chk_rest)

        # Replay the scheduler up to the resumed epoch.
        for _ in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_last_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    cur_step = 0  # one step == one processed training batch
    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    try:
        for epoch in range(start_epoch, num_epoch + 1):
            epoch_start_time = time.time()
            # NOTE(review): stepped at the top of the epoch as in the original
            # loop (and consistent with the resume replay above); PyTorch
            # documents calling it after optimizer.step().
            scheduler.step()
            # get_last_lr() is the supported way to read the current rate;
            # get_lr() outside of step() warns and can misreport the value.
            print('epoch = ', epoch,
                  'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
            print(scheduler.get_last_lr())
            gen_epoch_loss = 0

            net.train()
            for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
                noisy = noisy.to(device)
                label = label.to(device)
                cubic = cubic.to(device)

                ### Update discriminator ###
                disc_opt.zero_grad()
                # The generator output is produced without a graph, so no
                # gradient can flow back into the generator from this update.
                with torch.no_grad():
                    _, fake_stage2 = net(noisy, cubic)
                disc_fake_hat = disc(fake_stage2 + noisy, noisy)
                disc_fake_loss = adv_criterion(disc_fake_hat,
                                               torch.zeros_like(disc_fake_hat))
                disc_real_hat = disc(label, noisy)
                disc_real_loss = adv_criterion(disc_real_hat,
                                               torch.ones_like(disc_real_hat))
                disc_loss = (disc_fake_loss + disc_real_loss) / 2
                # This graph is never reused (the generator update below runs
                # a fresh forward pass), so it does not need to be retained.
                disc_loss.backward()
                disc_opt.step()

                ### Update generator ###
                hsid_optimizer.zero_grad()

                residual, residual_stage2 = net(noisy, cubic)
                disc_fake_hat = disc(residual_stage2 + noisy, noisy)
                gen_adv_loss = adv_criterion(disc_fake_hat,
                                             torch.ones_like(disc_fake_hat))

                # Both stages regress the same residual target.
                target = label - noisy
                stage1_loss = (alpha * loss_fuction(residual, target) +
                               (1 - alpha) * recon_criterion(residual, target))
                stage2_loss = (
                    alpha * loss_fuction(residual_stage2, target) +
                    (1 - alpha) * recon_criterion(residual_stage2, target))
                rec_loss = beta * stage1_loss + (1 - beta) * stage2_loss

                loss = gen_adv_loss + lambda_recon * rec_loss

                loss.backward()  # compute gradients
                hsid_optimizer.step()  # update parameters

                gen_epoch_loss += loss.item()

                if cur_step % display_step == 0:
                    if cur_step > 0:
                        print(
                            f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                        )
                        print(
                            f"rec_loss =  {rec_loss.item()}, gen_adv_loss = {gen_adv_loss.item()}"
                        )
                    else:
                        print("Pretrained initial state")

                tb_writer.add_scalar("MSE loss", loss.item(), cur_step)
                cur_step += 1

            tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

            # Per-epoch checkpoint.
            torch.save(
                {
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                    'disc': disc.state_dict(),
                    'disc_opt': disc_opt.state_dict()
                }, f"checkpoints/two_stage_hsid_dense_gan_{epoch}.pth")

            # Validation: rebuild the denoised cube band by band from the
            # stage-2 output.
            net.eval()
            with torch.no_grad():
                for batch_idx, (noisy_test, cubic_test,
                                label_test) in enumerate(test_dataloader):
                    noisy_test = noisy_test.float().to(DEVICE)
                    label_test = label_test.float().to(DEVICE)
                    cubic_test = cubic_test.float().to(DEVICE)

                    residual, residual_stage2 = net(noisy_test, cubic_test)
                    denoised_band = noisy_test + residual_stage2

                    denoised_hsi[:, :, batch_idx] = np.squeeze(
                        denoised_band.cpu().numpy().astype(np.float32))

                    # Log the images of one fixed band (loader index 49).
                    if batch_idx == 49:
                        for tag, image in (('restored', denoised_band),
                                           ('residual', residual),
                                           ('residual_stage2',
                                            residual_stage2),
                                           ('label', label_test),
                                           ('noisy', noisy_test)):
                            tb_writer.add_image(f"images/{epoch}_{tag}",
                                                torch.squeeze(image, 0),
                                                1,
                                                dataformats='CHW')

            psnr = PSNR(denoised_hsi, test_label_hsi)
            ssim = SSIM(denoised_hsi, test_label_hsi)
            sam = SAM(denoised_hsi, test_label_hsi)

            print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                  format(psnr, ssim, sam))
            # Logged per epoch so the best-performing epoch can be identified.
            tb_writer.add_scalars("validation metrics", {
                'average PSNR': psnr,
                'average SSIM': ssim,
                'avarage SAM': sam
            }, epoch)

            # Keep the checkpoint with the highest validation PSNR.
            if psnr > best_psnr:
                best_psnr = psnr
                best_epoch = epoch
                best_iter = cur_step
                torch.save(
                    {
                        'epoch': epoch,
                        'gen': net.state_dict(),
                        'gen_opt': hsid_optimizer.state_dict(),
                        'disc': disc.state_dict(),
                        'disc_opt': disc_opt.state_dict()
                    }, f"checkpoints/two_stage_hsid_dense_gan_best.pth")

            print(
                "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
                % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

            print(
                "------------------------------------------------------------------"
            )
            print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
                  format(epoch,
                         time.time() - epoch_start_time, gen_epoch_loss,
                         scheduler.get_last_lr()[0]))
            print(
                "------------------------------------------------------------------"
            )

            # Always keep the most recent state as well (used for resuming).
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                    'disc': disc.state_dict(),
                    'disc_opt': disc_opt.state_dict()
                }, os.path.join('./checkpoints', "model_latest.pth"))
    finally:
        # Flush and close the TensorBoard writer even if training aborts.
        tb_writer.close()
def predict_lowlight_hsid_origin():
    """Evaluate a trained HSIRDNECA_Denoise model on the level-25 test set.

    The best checkpoint under ``./checkpoints/hsirnd_denoise_l1loss`` is
    loaded, every sample of ``./data/denoise/test/level25`` is denoised band
    by band, and PSNR/SSIM/SAM are printed after each sample. The denoised
    cube of the last sample is written to
    ``./data/denoise/testresult/result.mat`` under the key ``denoised``.

    Note: the PSNR list is not reset between samples, so the printed PSNR is
    the running mean over all bands processed so far, while SSIM and SAM refer
    to the current sample only.
    """
    # Load the model.
    hsid = HSIRDNECA_Denoise(K)
    hsid = nn.DataParallel(hsid).to(DEVICE)

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'

    # Map the checkpoint onto DEVICE rather than a hard-coded 'cuda:0' so the
    # function also runs on CPU-only hosts or with a different GPU selected.
    checkpoint = torch.load(
        save_model_path + '/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth',
        map_location=DEVICE)
    hsid.load_state_dict(checkpoint['gen'])

    # Load the test data.
    test_data_dir = './data/denoise/test/level25'
    test_set = HsiTrainDataset(test_data_dir)
    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Make sure the output directory exists (no exists()/makedirs() race).
    test_result_output_path = './data/denoise/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising:
    #   * allocate a numpy array for the denoised cube,
    #   * for each band fetch its K adjacent bands via
    #     get_adjacent_spectral_bands,
    #   * let the network predict the residual and add it to the noisy band,
    #   * finally store the denoised cube as a .mat file.
    hsid.eval()
    psnr_list = []
    for noisy, label in test_dataloader:
        noisy = noisy.type(torch.FloatTensor)
        # The label is only needed as a numpy reference, so it stays on the
        # CPU instead of being copied to DEVICE and back.
        label = label.type(torch.FloatTensor)

        _, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):  # process one band at a time
                # (batch, width, height) -> (batch, 1, width, height)
                current_noisy_band = noisy[:, :, :, i]
                current_noisy_band = current_noisy_band[:, None]

                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                # Move the band axis in front of the two spatial axes, then
                # add a singleton axis at position 1 for the network input.
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(1)

                residual = hsid(current_noisy_band,
                                adj_spectral_bands_unsqueezed)
                denoised_band = residual + current_noisy_band
                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, i] = denoised_band_numpy

                label_band_numpy = label[:, :, :, i].numpy().astype(
                    np.float32)
                label_band_numpy = np.squeeze(label_band_numpy)

                psnr = PSNR(denoised_band_numpy, label_band_numpy)
                psnr_list.append(psnr)

        mpsnr = np.mean(psnr_list)

        # SSIM and SAM are called with the band axis first.
        denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
        test_label_hsi_trans = np.squeeze(label.numpy().astype(
            np.float32)).transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

        # Report the metrics.
        print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(mpsnr, mssim, sam))

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})
def predict_lowlight_residual():
    """Denoise the low-light test data with a trained ENCAM model, tile by tile.

    Loads ``./checkpoints/encam_best_08_27.pth``, denoises every sample of
    ``../HSID/data/test_lowlight/origin/`` band by band, writes the last
    denoised cube to ``./data/testresult/result.mat`` (key ``denoised``) and
    prints PSNR/SSIM/SAM against the ``label`` array of
    ``soup_bigcorn_orange_1ms.mat``.

    Each band is split at the midpoint of both spatial axes into four tiles
    that are denoised independently and stitched back together.
    """
    # Load the model.
    encam = ENCAM()
    encam = encam.to(DEVICE)

    encam.eval()
    # Map the checkpoint onto DEVICE rather than a hard-coded 'cuda:0' so the
    # function also runs on CPU-only hosts or with a different GPU selected.
    encam.load_state_dict(
        torch.load('./checkpoints/encam_best_08_27.pth',
                   map_location=DEVICE)['gen'])

    # Load the reference cube and the test data.
    mat_src_path = '../HSID/data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']

    test_data_dir = '../HSID/data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Make sure the output directory exists.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising:
    #   * allocate a numpy array for the denoised cube,
    #   * for each band fetch its K adjacent bands via
    #     get_adjacent_spectral_bands,
    #   * let the network predict the residual and add it to the noisy band,
    #   * finally store the denoised cube as a .mat file.
    for noisy, label in test_dataloader:
        noisy = noisy.type(torch.FloatTensor)

        _, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):  # process one band at a time
                # (batch, width, height) -> (batch, 1, width, height)
                current_noisy_band = noisy[:, :, :, i]
                current_noisy_band = current_noisy_band[:, None]

                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                # Band axis in front of the spatial axes, plus a singleton
                # axis at position 1: (batch, 1, bands, width, height).
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands = torch.unsqueeze(adj_spectral_bands, 1)
                adj_spectral_bands = adj_spectral_bands.to(DEVICE)
                print('adj_spectral_bands : ', adj_spectral_bands.shape)
                print('adj_spectral_bands shape[4] =',
                      adj_spectral_bands.shape[4])

                # Split the band and its neighbours into four spatial tiles.
                # The previous code sliced the top-left tile four times, so
                # the stitched result was that tile repeated; each tile now
                # uses its own row/column range. The second slice of each
                # axis runs to the end, so odd sizes are fully covered.
                half_rows = current_noisy_band.shape[2] // 2
                half_cols = current_noisy_band.shape[3] // 2
                row_ranges = (slice(0, half_rows), slice(half_rows, None))
                col_ranges = (slice(0, half_cols), slice(half_cols, None))

                stitched_rows = []
                for rows in row_ranges:
                    tiles = []
                    for cols in col_ranges:
                        band_tile = current_noisy_band[:, :, rows, cols]
                        adj_tile = adj_spectral_bands[:, :, :, rows, cols]
                        residual_tile = encam(band_tile, adj_tile)
                        tiles.append(band_tile + residual_tile)
                    # Join the left and right tile along the last axis.
                    stitched_rows.append(torch.cat(tiles, dim=3))
                # Join the upper and lower half along the first spatial axis.
                denoised_band = torch.cat(stitched_rows, dim=2)

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, i] = denoised_band_numpy

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    # Report the metrics.
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
def train_model_residual_lowlight_rdn():
    """Train HSIRDNECA_Denoise with a residual target and validate each epoch.

    Every epoch calls ``datagenerator`` on the cube loaded from
    ``./data/denoise/train_washington8.npy``, trains the network so that its
    output matches ``batch_x - batch_x_noise`` (presumably clean minus noisy
    band; confirm against DenoisingDataset), saves a per-epoch checkpoint and
    evaluates the model on ``./data/denoise/test/``. Whenever the mean PSNR
    improves, the weights are also stored as the ``..._best.pth`` checkpoint.
    """
    device = DEVICE

    # Prepare the training cube (axes are reversed after loading).
    train = np.load('./data/denoise/train_washington8.npy')
    train = train.transpose((2, 1, 0))

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'
    # makedirs also creates a missing './checkpoints' parent; os.mkdir did not.
    os.makedirs(save_model_path, exist_ok=True)

    # Build the model.
    net = HSIRDNECA_Denoise(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # Build the optimizer and the learning-rate schedule.
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[200, 400], gamma=0.5)

    cur_step = 0  # number of optimizer updates performed so far

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 600

    mpsnr_list = []
    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        # get_last_lr() is the supported way to read the current rate;
        # get_lr() is only meant to be called from inside scheduler.step().
        current_lr = scheduler.get_last_lr()[0]
        print('epoch = ', epoch, 'lr={:.6f}'.format(current_lr))
        print(scheduler.get_last_lr())

        gen_epoch_loss = 0

        net.train()

        channels = 191  # 191 channels
        data_patches, data_cubic_patches = datagenerator(train, channels)

        # Move the last axis of both patch arrays to position 1.
        data_patches = torch.from_numpy(data_patches.transpose((0, 3, 1, 2)))
        data_cubic_patches = torch.from_numpy(
            data_cubic_patches.transpose((0, 4, 1, 2, 3)))

        DDataset = DenoisingDataset(data_patches, data_cubic_patches, SIGMA)

        print('yes')
        DLoader = DataLoader(dataset=DDataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True)

        for step, x_y in enumerate(DLoader):
            batch_x_noise, batch_y_noise, batch_x = x_y[0], x_y[1], x_y[2]

            batch_x_noise = batch_x_noise.to(device)
            batch_y_noise = batch_y_noise.to(device)
            batch_x = batch_x.to(device)

            hsid_optimizer.zero_grad()

            residual = net(batch_x_noise, batch_y_noise)
            loss = recon_criterion(residual, batch_x - batch_x_noise)
            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            # Accumulate the epoch loss and the global step; both used to stay
            # at 0, which made the end-of-epoch log lines meaningless.
            gen_epoch_loss += loss.item()
            cur_step += 1

            if step % 10 == 0:
                # Epochs already start at 1, so log `epoch` itself.
                print('%4d %4d / %4d loss = %2.8f' %
                      (epoch, step, data_patches.size(0) // BATCH_SIZE,
                       loss.item() / BATCH_SIZE))

        # Advance the schedule only after this epoch's optimizer updates, as
        # PyTorch requires; stepping first skips the first schedule value.
        scheduler.step()

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            },
            f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_{epoch}.pth"
        )

        # Validation.
        net.eval()

        # Load the test data.
        test_data_dir = './data/denoise/test/'
        test_set = HsiTrainDataset(test_data_dir)

        test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

        # Make sure the output directory exists.
        test_result_output_path = './data/denoise/testresult/'
        os.makedirs(test_result_output_path, exist_ok=True)

        # Band-by-band denoising: for each band fetch its K adjacent bands via
        # get_adjacent_spectral_bands, predict the residual and add it to the
        # noisy band.
        # psnr_list spans all test samples of this epoch, so after the last
        # sample `mpsnr` is the mean over every band of every sample.
        psnr_list = []
        for noisy, label in test_dataloader:
            noisy = noisy.type(torch.FloatTensor)
            # The label is only used as a numpy reference; keep it on the CPU.
            label = label.type(torch.FloatTensor)

            _, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))

            noisy = noisy.to(DEVICE)

            with torch.no_grad():
                for i in range(band_num):  # process one band at a time
                    # (batch, width, height) -> (batch, 1, width, height)
                    current_noisy_band = noisy[:, :, :, i]
                    current_noisy_band = current_noisy_band[:, None]

                    adj_spectral_bands = get_adjacent_spectral_bands(
                        noisy, K, i)
                    # Band axis in front of the spatial axes, then a singleton
                    # axis at position 1 for the network input.
                    adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                    adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(
                        1)

                    residual = net(current_noisy_band,
                                   adj_spectral_bands_unsqueezed)
                    denoised_band = residual + current_noisy_band
                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_band_numpy = np.squeeze(denoised_band_numpy)

                    denoised_hsi[:, :, i] = denoised_band_numpy

                    label_band_numpy = label[:, :, :, i].numpy().astype(
                        np.float32)
                    label_band_numpy = np.squeeze(label_band_numpy)

                    psnr = PSNR(denoised_band_numpy, label_band_numpy)
                    psnr_list.append(psnr)

            mpsnr = np.mean(psnr_list)
            mpsnr_list.append(mpsnr)

            # SSIM and SAM are called with the band axis first.
            denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
            test_label_hsi_trans = np.squeeze(label.numpy().astype(
                np.float32)).transpose(2, 0, 1)
            mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
            sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

            # Report the metrics.
            print(
                "=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                format(mpsnr, mssim, sam))

        # Keep the best model seen so far.
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth"
            )

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        # Log the rate this epoch actually trained with, not the initial one.
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     current_lr))
        print(
            "------------------------------------------------------------------"
        )
def train_model():
    """Train HSID_1x3 on the cubic training set and evaluate after each epoch.

    Trains on ``./data/train_cubic/``, logs the loss to TensorBoard through
    the module-level ``tb_writer``, saves ``checkpoints/hsid_{epoch}.pth``
    after every epoch and prints PSNR/SSIM/SAM of the denoised
    ``./data/test_level25/`` data against
    ``./data/origin/test_washington.npy``.
    """
    device = DEVICE

    # Prepare the training data.
    train_set = HsiCubicTrainDataset('./data/train_cubic/')
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the reference cube for evaluation.
    test_label_hsi = np.load('./data/origin/test_washington.npy')

    # Load the noisy test data.
    test_data_dir = './data/test_level25/'
    test_set = HsiTrainDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Build the model.
    net = HSID_1x3(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # Build the optimizer and the learning-rate schedule.
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer,
                            milestones=[15, 30, 45],
                            gamma=0.25)

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0  # one step == one processed batch

    for epoch in range(NUM_EPOCHS):

        gen_epoch_loss = 0

        net.train()
        for noisy, cubic, label in train_loader:

            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            denoised_img = net(noisy, cubic)
            loss = loss_fuction(denoised_img, label)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/hsid_{epoch}.pth")

        # Evaluation.
        net.eval()
        for noisy, label in test_dataloader:
            noisy = noisy.type(torch.FloatTensor)

            _, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))

            noisy = noisy.to(DEVICE)

            with torch.no_grad():
                for i in range(band_num):  # process one band at a time
                    # (batch, width, height) -> (batch, 1, width, height)
                    current_noisy_band = noisy[:, :, :, i]
                    current_noisy_band = current_noisy_band[:, None]

                    # shape: batch_size, width, height, band_num
                    adj_spectral_bands = get_adjacent_spectral_bands(
                        noisy, K, i)
                    # Move the band axis to position 1 while keeping the
                    # spatial axes in the same order as current_noisy_band.
                    # torch.transpose(x, 3, 1) also swapped width and height,
                    # so the neighbours were spatially transposed relative to
                    # the band being denoised.
                    adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                    denoised_band = net(current_noisy_band, adj_spectral_bands)

                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_band_numpy = np.squeeze(denoised_band_numpy)

                    denoised_hsi[:, :, i] = denoised_band_numpy

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # Report the metrics.
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))

    tb_writer.close()
def train_model_residual_lowlight_twostage_unet():
    """Train TwoStageHSIDWithUNet on low-light data and evaluate each epoch.

    Trains on ``./data/train_lowlight/`` with the sum of
    ``loss_function_with_tvloss`` over both network outputs (target:
    ``label - noisy``), saves ``checkpoints/two_stage_unet_hsid_{epoch}.pth``
    after every epoch, and logs PSNR/SSIM/SAM of the second-stage result on
    ``./data/test_lowlight/cubic/`` to stdout and TensorBoard.
    """
    learning_rate = INIT_LEARNING_RATE * 0.5
    device = DEVICE

    # Prepare the training data.
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the reference cube for evaluation.
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load the test data; each loader item supplies one band of the cube.
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    # Size the output cube from the first test item and the loader length.
    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Build the model.
    net = TwoStageHSIDWithUNet(K)
    init_params(net)
    net = net.to(device)

    # Build the optimizer and the learning-rate schedule.
    hsid_optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0  # one step == one processed batch

    for epoch in range(NUM_EPOCHS):
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        gen_epoch_loss = 0

        net.train()
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()

            # Both stages are supervised with the same residual target.
            residual, residual_stage2 = net(noisy, cubic)
            target = label - noisy
            loss = loss_function_with_tvloss(
                residual, target) + loss_function_with_tvloss(
                    residual_stage2, target)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        # Advance the schedule only after this epoch's optimizer updates, as
        # PyTorch requires; stepping at the top of the loop skipped the first
        # schedule value and applied every milestone one epoch early.
        scheduler.step()

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/two_stage_unet_hsid_{epoch}.pth")

        # Evaluation.
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        _) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():
                # Only the second-stage residual is used for the final result.
                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # Report the metrics.
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        # Logged per epoch so the best-performing epoch can be identified.
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)

    tb_writer.close()
def predict_cubic():
    """Denoise the cubic test set with a trained HSID model and report metrics.

    Loads ``./checkpoints/hsid_70.pth``, runs the network on every item of
    ``./data/test_cubic/`` (each item fills one band of the output cube),
    writes the cube to ``./data/testresult/result.mat`` (key ``denoised``) and
    prints PSNR/SSIM/SAM against ``./data/origin/test_washington.npy``.
    """
    # Load the model.
    hsid = HSID(36)
    hsid = nn.DataParallel(hsid).to(DEVICE)

    # map_location lets a checkpoint saved on another device load on DEVICE.
    hsid.load_state_dict(
        torch.load('./checkpoints/hsid_70.pth', map_location=DEVICE)['gen'])
    # Inference must run in eval mode; otherwise any BatchNorm/Dropout layers
    # the model may contain would keep their training-time behaviour.
    hsid.eval()

    # Load the reference cube and the test data.
    test_label_hsi = np.load('./data/origin/test_washington.npy')

    batch_size = 1
    test_data_dir = './data/test_cubic/'
    test_set = HsiCubicTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

    # Make sure the output directory exists.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Size the output cube from the first test item and the loader length.
    _, _, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    for batch_idx, (noisy, cubic, _) in enumerate(test_dataloader):
        noisy = noisy.type(torch.FloatTensor)
        cubic = cubic.type(torch.FloatTensor)

        noisy = noisy.to(DEVICE)
        cubic = cubic.to(DEVICE)

        with torch.no_grad():
            # The network output is stored directly as the denoised band.
            denoised_band = hsid(noisy, cubic)

            denoised_band_numpy = denoised_band.cpu().numpy().astype(np.float32)
            denoised_band_numpy = np.squeeze(denoised_band_numpy)

            denoised_hsi[:, :, batch_idx] = denoised_band_numpy

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat', {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test_label_hsi)
    ssim = SSIM(denoised_hsi, test_label_hsi)
    sam = SAM(denoised_hsi, test_label_hsi)

    # Report the metrics.
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(psnr, ssim, sam)) 
def predict():
    """Denoise the level-25 test data with a trained HSID model.

    Loads ``./checkpoints/hsid_5.pth``, denoises every sample of
    ``./data/test_level25/`` band by band, writes the last denoised cube to
    ``./data/testresult/result.mat`` (key ``denoised``) and prints
    PSNR/SSIM/SAM against ``./data/origin/test_washington.npy``.
    """
    # Load the model.
    hsid = HSID(36)
    hsid = nn.DataParallel(hsid).to(DEVICE)

    # map_location lets a checkpoint saved on another device load on DEVICE.
    hsid.load_state_dict(
        torch.load('./checkpoints/hsid_5.pth', map_location=DEVICE)['gen'])
    # Inference must run in eval mode; otherwise any BatchNorm/Dropout layers
    # the model may contain would keep their training-time behaviour.
    hsid.eval()

    # Load the reference cube and the test data.
    test = np.load('./data/origin/test_washington.npy')

    test_data_dir = './data/test_level25/'
    test_set = HsiTrainDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Make sure the output directory exists.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising: for each band fetch its K adjacent bands via
    # get_adjacent_spectral_bands, run the network and store the result.
    for noisy, label in test_dataloader:
        noisy = noisy.type(torch.FloatTensor)

        _, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        noisy = noisy.to(DEVICE)

        with torch.no_grad():
            for i in range(band_num):  # process one band at a time
                # (batch, width, height) -> (batch, 1, width, height)
                current_noisy_band = noisy[:, :, :, i]
                current_noisy_band = current_noisy_band[:, None]

                # shape: batch_size, width, height, band_num
                adj_spectral_bands = get_adjacent_spectral_bands(noisy, K, i)
                # Move the band axis to position 1 while keeping the spatial
                # axes in the same order as current_noisy_band.
                # torch.transpose(x, 3, 1) also swapped width and height, so
                # the neighbours were spatially transposed relative to the
                # band being denoised.
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                denoised_band = hsid(current_noisy_band, adj_spectral_bands)

                denoised_band_numpy = denoised_band.cpu().numpy().astype(np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, i] = denoised_band_numpy

    # savemat expects a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat', {'denoised': denoised_hsi})

    psnr = PSNR(denoised_hsi, test)
    ssim = SSIM(denoised_hsi, test)
    sam = SAM(denoised_hsi, test)
    # Report the metrics.
    print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(psnr, ssim, sam)) 
def predict_lowlight_residual():
    """Denoise the low-light test cube band by band with ENCAM and print metrics.

    Steps performed:

    1. Load the ENCAM weights stored under the ``'gen'`` key of
       ``./checkpoints/encam_best.pth``.
    2. For every cube yielded by the test loader, halve its spatial
       resolution, and for each band predict a residual from that band and
       the bands returned by ``get_adjacent_spectral_bands(..., K, i)``.
    3. Add the residual to the noisy band, upsample back to the original
       spatial size and store the result.
    4. Save the denoised cube to ``./data/testresult/result.mat`` under the
       key ``'denoised'`` and print PSNR / SSIM / SAM against the label cube.

    NOTE: ``denoised_hsi`` is re-allocated for every batch, so when the loader
    yields more than one cube only the last one is saved and evaluated.
    """
    # Load the model.
    encam = ENCAM()
    encam = encam.to(DEVICE)
    encam.eval()
    # Map the checkpoint onto DEVICE instead of a hard-coded 'cuda:0' so the
    # weights can be loaded on whichever device the model actually lives on.
    checkpoint = torch.load('./checkpoints/encam_best.pth',
                            map_location=DEVICE)
    encam.load_state_dict(checkpoint['gen'])

    # Reference (label) cube used for the metrics.
    mat_src_path = '../HSID/data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label = scio.loadmat(mat_src_path)['label']

    test_data_dir = '../HSID/data/test_lowlight/origin/'
    test_set = HsiLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # Directory the result is written to.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising.
    for batch_idx, (noisy, _) in enumerate(test_dataloader):
        noisy = noisy.type(torch.FloatTensor)

        # The loader yields channel-last cubes: (batch, width, height, bands).
        batch_size, width, height, band_num = noisy.shape
        denoised_hsi = np.zeros((width, height, band_num))

        # Downsample the cube; interpolate expects channel-first input.
        noisy_down = F.interpolate(noisy.permute(0, 3, 1, 2),
                                   scale_factor=0.5,
                                   mode='bilinear')
        noisy_down = noisy_down.to(DEVICE)

        # get_adjacent_spectral_bands is fed a channel-last tensor; build that
        # view once instead of permuting back and forth for every band.
        noisy_down_channel_last = noisy_down.permute(0, 2, 3, 1)

        with torch.no_grad():
            for i in range(band_num):
                # (batch, 1, w, h): the band to denoise with a channel axis.
                current_noisy_band = noisy_down[:, i, :, :][:, None]

                adj_spectral_bands = get_adjacent_spectral_bands(
                    noisy_down_channel_last, K, i)
                # Back to channel-first, then add the extra axis ENCAM takes.
                adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                adj_spectral_bands = torch.unsqueeze(adj_spectral_bands, 1)
                adj_spectral_bands = adj_spectral_bands.to(DEVICE)

                residual = encam(current_noisy_band, adj_spectral_bands)
                denoised_band = current_noisy_band + residual

                # Upsample to the original spatial size. Passing the target
                # size (rather than scale_factor=2) gives the same result for
                # even sizes and also restores odd sizes correctly.
                denoised_band = F.interpolate(denoised_band,
                                              size=(width, height),
                                              mode='bilinear')

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_hsi[:, :, i] = np.squeeze(denoised_band_numpy)

    # savemat takes a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    # Compute and report the metrics.
    psnr = PSNR(denoised_hsi, test_label)
    ssim = SSIM(denoised_hsi, test_label)
    sam = SAM(denoised_hsi, test_label)
    print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(
        psnr, ssim, sam))
def predict_lowlight_hsid_origin():
    """Denoise one outdoor low-light cube with HSIRDNECA and print metrics.

    The model (wrapped in ``nn.DataParallel``) is restored from the ``'gen'``
    entry of the best checkpoint in
    ``./checkpoints/hsirnd_ablation_convlayernum_4``.  The test loader yields
    one ``(noisy band, adjacent-band cube, label)`` triple per spectral band;
    for each band the network predicts a residual that is added to the noisy
    band.  The denoised cube is saved to ``./data/testresult/result.mat``
    under the key ``'denoised'``.  PSNR is averaged over the bands, SSIM and
    SAM are computed on the whole cube with the band axis moved to the front.
    """
    # Load the model.
    hsid = HSIRDNECA(24)
    hsid = nn.DataParallel(hsid).to(DEVICE)

    save_model_path = './checkpoints/hsirnd_ablation_convlayernum_4'
    # Map the checkpoint onto DEVICE instead of a hard-coded 'cuda:0' so the
    # weights can be loaded on whichever device the model actually lives on.
    checkpoint = torch.load(
        save_model_path +
        '/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth',
        map_location=DEVICE)
    hsid.load_state_dict(checkpoint['gen'])

    # Reference (label) cube used for the metrics.
    mat_src_path = './data/lowlight_origin_outdoor/test/15ms/007_2_2021-01-20_024.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label_normalized_hsi']

    # Test data: one loader item per spectral band.
    test_data_dir = './data/test_lowli_outdoor_k12/007_2_2021-01-20_024/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    # Spatial size is read from the first noisy band; the number of bands is
    # the number of loader items.
    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Directory the result is written to.
    test_result_output_path = './data/testresult/'
    os.makedirs(test_result_output_path, exist_ok=True)

    # Band-by-band denoising.
    hsid.eval()
    psnr_list = []

    for batch_idx, (noisy_test, cubic_test, _) in enumerate(test_dataloader):
        noisy_test = noisy_test.type(torch.FloatTensor).to(DEVICE)
        cubic_test = cubic_test.type(torch.FloatTensor).to(DEVICE)

        with torch.no_grad():
            residual = hsid(noisy_test, cubic_test)
            denoised_band = noisy_test + residual

        denoised_band_numpy = np.squeeze(
            denoised_band.cpu().numpy().astype(np.float32))
        denoised_hsi[:, :, batch_idx] = denoised_band_numpy

        # Per-band PSNR against the matching band of the reference cube.
        test_label_current_band = test_label_hsi[:, :, batch_idx]
        psnr_list.append(PSNR(denoised_band_numpy, test_label_current_band))

    # savemat takes a dict whose values are numpy arrays.
    scio.savemat(test_result_output_path + 'result.mat',
                 {'denoised': denoised_hsi})

    # Mean PSNR over bands; SSIM and SAM on band-first (bands, w, h) cubes.
    mpsnr = np.mean(psnr_list)
    denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
    test_label_hsi_trans = test_label_hsi.transpose(2, 0, 1)
    mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
    sam = SAM(denoised_hsi_trans, test_label_hsi_trans)
    print("=====averPSNR:{:.4f}=====averSSIM:{:.4f}=====averSAM:{:.4f}".format(
        mpsnr, mssim, sam))
# Example no. 15
# 0
def train_model_residual_lowlight():
    """Train ENCAM to predict the noise residual of low-light HSI bands.

    Every epoch:

    1. runs one pass over the training loader, minimising
       ``loss_fuction(residual, label - noisy)`` with Adam;
    2. advances the ``MultiStepLR`` schedule (milestones 15/30/45, gamma 0.1);
    3. writes ``checkpoints/encam_<epoch>.pth``;
    4. denoises the test cube at half resolution, upsamples the result back
       and logs PSNR / SSIM / SAM against the reference cube;
    5. stores the weights as ``checkpoints/encam_best.pth`` whenever the PSNR
       improves on the best value seen so far.

    Side effects: rebinds the module-level ``tb_writer`` and logs scalars
    through it; the writer is closed when the function exits.
    """
    global tb_writer

    device = DEVICE

    # Training data.
    train_set = HsiCubicTrainDataset('../HSID/data/train_lowlight/')
    print('total training example:', len(train_set))
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Reference (label) cube used for the validation metrics.
    mat_src_path = '../HSID/data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Validation data: one loader item per spectral band.
    test_data_dir = '../HSID/data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=1,
                                 shuffle=False)

    # Spatial size is read from the first noisy band; the number of bands is
    # the number of loader items.
    _, _, width, height = next(iter(test_dataloader))[0].shape
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Model, optimizer and learning-rate schedule.
    net = ENCAM()
    net = net.to(device)
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[15, 30, 45], gamma=0.1)

    # torch.save does not create missing directories.
    os.makedirs('checkpoints', exist_ok=True)

    tb_writer = get_summary_writer(log_dir='logs')

    cur_step = 0
    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    try:
        for epoch in range(NUM_EPOCHS):
            # Learning rate that will be used during this epoch.
            print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
            gen_epoch_loss = 0

            net.train()
            for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
                noisy = noisy.to(device)
                label = label.to(device)
                cubic = cubic.to(device)

                hsid_optimizer.zero_grad()

                # The network predicts the residual (label - noisy).
                residual = net(noisy, cubic)
                loss = loss_fuction(residual, label - noisy)

                loss.backward()  # compute gradients
                hsid_optimizer.step()  # update parameters

                gen_epoch_loss += loss.item()

                if cur_step % display_step == 0:
                    if cur_step > 0:
                        print(
                            f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                        )
                    else:
                        print("Pretrained initial state")

                tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

                # One step == one processed batch.
                cur_step += 1

            tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

            # Advance the schedule only after this epoch's optimizer updates;
            # stepping it first makes every milestone fire one epoch early.
            scheduler.step()

            torch.save(
                {
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/encam_{epoch}.pth")

            # Validation.
            net.eval()
            with torch.no_grad():
                for batch_idx, (noisy_test, cubic_test,
                                _) in enumerate(test_dataloader):
                    noisy_test = noisy_test.type(torch.FloatTensor).to(DEVICE)
                    cubic_test = cubic_test.type(torch.FloatTensor).to(DEVICE)

                    # Downsample the band and its adjacent-band cube.
                    noisy_test_down = F.interpolate(noisy_test,
                                                    scale_factor=0.5,
                                                    mode='bilinear')
                    # Drop the leading axis so bilinear interpolate gets a
                    # 4-D tensor, then restore it (assumes cubic_test is 5-D
                    # with a leading axis of size 1 -- TODO confirm).
                    cubic_test_down = F.interpolate(torch.squeeze(
                        cubic_test, 0),
                                                    scale_factor=0.5,
                                                    mode='bilinear')
                    cubic_test_down = torch.unsqueeze(cubic_test_down, 0)

                    residual = net(noisy_test_down, cubic_test_down)
                    denoised_band = noisy_test_down + residual

                    # Upsample to the original spatial size. Passing the
                    # target size (rather than scale_factor=2) gives the same
                    # result for even sizes and also restores odd sizes.
                    denoised_band = F.interpolate(denoised_band,
                                                  size=(width, height),
                                                  mode='bilinear')

                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_hsi[:, :, batch_idx] = np.squeeze(
                        denoised_band_numpy)

            psnr = PSNR(denoised_hsi, test_label_hsi)
            ssim = SSIM(denoised_hsi, test_label_hsi)
            sam = SAM(denoised_hsi, test_label_hsi)

            # Report the metrics; the per-epoch log shows which epoch is best.
            print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                  format(psnr, ssim, sam))
            tb_writer.add_scalars("validation metrics", {
                'average PSNR': psnr,
                'average SSIM': ssim,
                'avarage SAM': sam
            }, epoch)

            # Keep the best model (by PSNR).
            if psnr > best_psnr:
                best_psnr = psnr
                best_epoch = epoch
                best_iter = cur_step
                torch.save(
                    {
                        'gen': net.state_dict(),
                        'gen_opt': hsid_optimizer.state_dict(),
                    }, "checkpoints/encam_best.pth")

            print(
                "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
                % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))
    finally:
        # Close the writer even if training is interrupted by an exception.
        tb_writer.close()