# Exemplo n.º 1
# 0
def test(model, test_dataset, test_file_num, epoch, writer=None):
    """Evaluate `model` on up to `test_file_num` items of `test_dataset`.

    Prints per-image and average RMSE/PSNR/SSIM and, when `writer` is
    provided, logs the averages under the given `epoch` step.

    Args:
        model: network whose `forward(O)` produces a reconstruction.
        test_dataset: indexable dataset yielding dicts with 'O', 'B', 'M'.
        test_file_num: number of images to evaluate; None means the whole set.
        epoch: epoch index used for printing and TensorBoard logging.
        writer: optional SummaryWriter-like object.
    """
    model.eval()
    if test_file_num is None:
        test_file_num = len(test_dataset)
    totals = {'rmse': 0, 'psnr': 0, 'ssim': 0}
    with torch.no_grad():
        for idx in range(test_file_num):
            sample = test_dataset[idx]
            obs = torch.Tensor(sample['O']).unsqueeze(0).cuda()
            target = torch.Tensor(sample['B']).unsqueeze(0).cuda()
            # NOTE(review): the mask is moved to GPU but never used below.
            _mask = torch.Tensor(sample['M']).unsqueeze(0).cuda()
            recon = torch.clamp(model.forward(obs), 0., 1.)
            rmse = batch_RMSE_G(target, recon)
            psnr = batch_PSNR(target, recon)
            ssim = get_ssim(target, recon)
            print('epoch {}, test img {}, rmse {}, psnr {}, ssim {}'.format(
                epoch, idx, rmse, psnr, ssim))
            totals['rmse'] += rmse
            totals['psnr'] += psnr
            totals['ssim'] += ssim

    print('\nepoch {}, avg RMSE {}, avg PSNR {}, avg SSIM {}\n'.format(
        epoch, totals['rmse'] / test_file_num, totals['psnr'] / test_file_num,
        totals['ssim'] / test_file_num))
    if writer is not None:
        writer.add_scalar('test_RMSE', totals['rmse'] / test_file_num, epoch)
        writer.add_scalar('test_PSNR', totals['psnr'] / test_file_num, epoch)
        writer.add_scalar('test_SSIM', totals['ssim'] / test_file_num, epoch)
# Exemplo n.º 2
# 0
def generate_result(model, outf, test_dataset, mat=False, ouput_img=True):
    """Run `model` over the whole test set and persist the results.

    Prints per-image RMSE/PSNR/SSIM, optionally writes reconstructed images
    via `generate_imgs`, and saves all metrics plus a final average row to
    `<outf>/results.csv`.

    NOTE(review): `mat` is accepted but never used, and `ouput_img` looks
    like a typo for `output_img`; both are kept to preserve the public
    call signature.
    """
    model.eval()
    n_files = len(test_dataset)
    img_dir = os.path.join(outf, 'output_imgs')
    if ouput_img and not os.path.exists(img_dir):
        os.makedirs(img_dir)
    metrics = []
    with torch.no_grad():
        for idx in range(n_files):
            sample = test_dataset[idx]
            obs = torch.Tensor(sample['O']).unsqueeze(0).cuda()
            target = torch.Tensor(sample['B']).unsqueeze(0).cuda()
            # NOTE(review): the mask is moved to GPU but never used below.
            _mask = torch.Tensor(sample['M']).unsqueeze(0).cuda()
            recon = torch.clamp(model.forward(obs), 0., 1.)
            rmse = batch_RMSE_G(target, recon)
            psnr = batch_PSNR(target, recon)
            ssim = get_ssim(target, recon)
            print('test img {}, rmse {}, psnr {}, ssim {}'.format(idx, rmse, psnr, ssim))
            metrics.append([rmse, psnr, ssim])
            if ouput_img:
                generate_imgs(obs, target, recon, outf, idx, psnr)

    metrics = np.array(metrics)
    # Append one extra row holding the per-metric averages.
    avg_row = np.sum(metrics, axis=0) / n_files
    print(metrics.shape, avg_row.shape)
    print('average results : {}'.format(avg_row))
    metrics = np.append(metrics, [avg_row], axis=0)

    excel = pd.DataFrame(data=metrics, columns=['rmse', 'psnr', 'ssim'])
    excel.to_csv(os.path.join(outf, 'results.csv'))
# Exemplo n.º 3
# 0
def test(model, test_dataset, epoch, writer):
    """Evaluate `model` on `test_dataset`, printing per-image PSNR and
    logging the average PSNR to `writer` at step `epoch`.

    Args:
        model: network whose `forward(rgb)` returns (hyper, aux) tuple.
        test_dataset: iterable yielding (real_hyper, _, real_rgb) tuples.
        epoch: step index for the TensorBoard scalar.
        writer: SummaryWriter-like object.
    """
    test_image_num = len(test_dataset)
    model.eval()
    psnr_sum = 0
    # Fix: run evaluation under no_grad (as the other eval helpers in this
    # file do) so no autograd graph is built for the forward passes.
    with torch.no_grad():
        for i, data in enumerate(test_dataset):
            real_hyper, _, real_rgb = data
            real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
            fake_hyper, _ = model.forward(real_rgb)

            psnr = batch_PSNR(real_hyper, fake_hyper)
            print('test img [{}/{}], psnr {}'.format(i, test_image_num,
                                                     psnr.item()))
            psnr_sum += psnr.item()

    print('total {} test images, avg psnr {}'.format(test_image_num, psnr_sum /
                                                     test_image_num))
    writer.add_scalar('test_psnr', psnr_sum / test_image_num, epoch)
# Exemplo n.º 4
# 0
def train(model, criterion, optimizer, train_loader, epoch, writer):
    """Run one training epoch over `train_loader`.

    Minimizes `criterion(fake_hyper, real_hyper)`; every 10 batches the
    batch PSNR is printed and logged to `writer` at a global step derived
    from `epoch` and the batch index.
    """
    steps_per_epoch = len(train_loader)
    model.train()
    for step, batch in enumerate(train_loader):
        model.zero_grad()
        optimizer.zero_grad()
        real_hyper, _, real_rgb = batch
        real_hyper = real_hyper.cuda()
        real_rgb = real_rgb.cuda()
        fake_hyper, _ = model.forward(real_rgb)
        loss = criterion(fake_hyper, real_hyper)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            # Detach so PSNR computation does not keep the graph alive.
            psnr = batch_PSNR(real_hyper, fake_hyper.detach())
            print("[epoch {}][{}/{}] psnr: {}".format(epoch, step,
                                                      len(train_loader),
                                                      psnr.item()))
            writer.add_scalar('train_psnr', psnr.item(),
                              steps_per_epoch * epoch + step)
# Exemplo n.º 5
# 0
def train_epoch(model,
                optimizer,
                train_loader,
                l1_criterion,
                mask_criterion,
                ssim_criterion,
                epoch,
                writer=None,
                radio=1):
    """Train `model` for one epoch with loss = L1 - radio * SSIM.

    Every 10 steps, metrics are printed and (when `writer` is given) logged
    to TensorBoard at step `epoch * len(train_loader) + step`.

    NOTE(review): `mask_criterion` is accepted but currently unused (the
    mask-loss code is disabled); it is kept to preserve the call signature.
    """
    model.train()
    steps = len(train_loader)
    for step, batch in enumerate(train_loader):
        model.zero_grad()
        optimizer.zero_grad()
        O, B, M = batch['O'].cuda(), batch['B'].cuda(), batch['M'].cuda()
        recon = torch.clamp(model.forward(O), 0., 1.)
        l1_loss = l1_criterion(recon, B)
        ssim_loss = ssim_criterion(recon, B)
        # SSIM is maximized, hence subtracted from the L1 term.
        loss = l1_loss - radio * ssim_loss
        loss.backward()
        optimizer.step()
        if step % 10 != 0:
            continue
        with torch.no_grad():
            rmse = batch_RMSE_G(B, recon)
            psnr = batch_PSNR(B, recon)
            print(
                'epoch {}, [{}/{}], loss {}, PSNR {}, SSIM {}, RMSE {}, '.
                format(epoch, step, steps, loss, psnr, ssim_loss.item(), rmse))
            if writer is None:
                continue
            global_step = epoch * steps + step
            writer.add_scalar('loss', loss.item(), global_step)
            writer.add_scalar('l1_loss', l1_loss.item(), global_step)
            writer.add_scalar('ssim', ssim_loss.item(), global_step)
            writer.add_scalar('RMSE', rmse.item(), global_step)
            writer.add_scalar('PSNR', psnr.item(), global_step)
# Exemplo n.º 6
# 0
def main(model_path, need_weisout=False):
    """Load an FMNet checkpoint and evaluate it over the test set.

    Each reconstruction is saved as a .mat file under `<model_path>/result`;
    per-image and average PSNR plus the average forward time are printed.

    Args:
        model_path: directory containing 'model.pth'. Its basename must
            encode the architecture as `<..>_<..>_<bNum>_<nblocks>[...]`.
        need_weisout: when True and the model returns a (hyper, weights)
            tuple, the weight maps are stored in the .mat file too.
    """
    # exist_ok avoids the race-prone exists()+mkdir pattern of the original.
    os.makedirs(os.path.join(model_path, 'result'), exist_ok=True)
    os.makedirs(os.path.join(model_path, 'mask'), exist_ok=True)

    print('load model path : {}'.format(model_path))
    model_name = 'model.pth'
    # Architecture hyper-parameters are recovered from the directory name.
    name = model_path.split('/')[-1].split('_')
    bNum = int(name[2])
    nblocks = int(name[3])
    model = FMNet(bNum=bNum,
                  nblocks=nblocks,
                  input_features=31,
                  num_features=64,
                  out_features=31)
    print(bNum, nblocks)

    model.load_state_dict(
        torch.load(os.path.join(model_path, model_name), map_location='cpu'))
    print('model MNet has load !')

    model.eval()
    model.cuda()

    testDataset = HyperDataset(mode='test')
    num = len(testDataset)
    print('test img num : {}'.format(num))

    test_rgb_filename_list = get_testfile_list()
    print('test_rgb_filename_list len : {}'.format(
        len(test_rgb_filename_list)))

    psnr_sum = 0
    all_time = 0
    for i in range(num):

        file_name = test_rgb_filename_list[i].split('/')[-1]
        # The dataset key is the trailing integer of the file stem.
        key = int(file_name.split('.')[0].split('_')[-1])
        print(file_name, key)

        real_hyper, _, real_rgb = testDataset.get_data_by_key(str(key))
        real_hyper, real_rgb = torch.unsqueeze(real_hyper, 0), torch.unsqueeze(
            real_rgb, 0)
        real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()

        # Time only the forward pass, with autograd disabled.
        with torch.no_grad():
            start = time.time()
            fake_hyper = model.forward(real_rgb)
            all_time += (time.time() - start)

        if isinstance(fake_hyper, tuple):
            fake_hyper, weis_out = fake_hyper
            if need_weisout:
                weis_out = weis_out[0, :, 0, :, :].cpu().numpy()
                weis_out = np.squeeze(weis_out)
            else:
                # Bug fix: this was `weis_out = False`, which failed the
                # `is None` check below and saved a bogus 'weis_out': False
                # field into the .mat file.
                weis_out = None
        else:
            weis_out = None

        psnr = batch_PSNR(real_hyper, fake_hyper).item()
        print('test img [{}/{}], fake hyper shape : {}, psnr : {}'.format(
            i + 1, num, fake_hyper.shape, psnr))
        psnr_sum += psnr
        fake_hyper_mat = fake_hyper[0, :, :, :].cpu().numpy()
        # Compute the output path once instead of in both branches.
        mat_path = os.path.join(
            model_path, 'result',
            test_rgb_filename_list[i].split('/')[-1].split('.')[0] + '.mat')
        if weis_out is None:
            scio.savemat(mat_path, {'rad': fake_hyper_mat})
            print('sucessfully save fake hyper !!!')
        else:
            scio.savemat(mat_path, {
                'rad': fake_hyper_mat,
                'weis_out': weis_out
            })
            print('sucessfully save fake hyper and weis_out !!!')
        print()

    print('average psnr : {}'.format(psnr_sum / num))
    print('average test time : {}'.format(all_time / num))