Example No. 1
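# Example 1: multi-GPU training loop for FMNet. Assumes module-level `opt`
# (parsed command-line options) and `outf` (output directory), plus the usual
# imports: os, time, torch, torch.nn as nn, torch.optim as optim,
# torch.utils.data as udata, and a SummaryWriter (e.g. from tensorboardX or
# torch.utils.tensorboard).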
def main():
    gpu_num = len(opt.gpus.split(','))
    device_ids = list(range(gpu_num))
    print("loading dataset ...")
    train_dataset = HyperDataset(mode='train')
    test_dataset = HyperDataset(mode='test')
    batchSize = opt.batchSize_per_gpu * gpu_num
    train_loader = udata.DataLoader(train_dataset, batch_size=batchSize, shuffle=True, num_workers=0)
    print('train dataset num : {}'.format(len(train_dataset)))
    criterion = nn.L1Loss()

    epoch = 0
    net = FMNet(bNum=opt.bNum, nblocks=opt.nblocks, input_features=31, num_features=64, out_features=31)
    optimizer = optim.Adam(net.parameters(), lr=opt.lr, weight_decay=1e-6, betas=(0.9, 0.999))
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.milestones, opt.gamma)
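    # Optionally resume from a checkpoint. Note the scheduler state is not
    # saved or restored, so MultiStepLR counts milestones from the resume
    # point rather than from epoch 0.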
    load = False
    if load:
        checkpoint_file_name = 'checkpoint.pth'
        checkpoint = torch.load(os.path.join(outf, checkpoint_file_name), map_location=torch.device('cuda:0'))
        epoch = checkpoint['epoch'] + 1
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print('Successfully load checkpoint {} ... '.format(os.path.join(outf, checkpoint_file_name)))

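    # replicate the model across the visible GPUs; outputs gather on device 0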
    model = nn.DataParallel(net, device_ids=device_ids, output_device=device_ids[0])
    model.cuda()
    criterion.cuda()
    
    writer = SummaryWriter(outf)
    while epoch < opt.epochs:
        start = time.time()
        print("epoch {} learning rate {}".format(epoch, optimizer.param_groups[0]['lr']))
        
        train(model, criterion, optimizer, train_loader, epoch, writer)
        lr_scheduler.step()
        train_dataset.shuffle()

        test(model, test_dataset, epoch, writer)

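        # snapshot every 20 epochs; model.module.state_dict() strips the
        # DataParallel 'module.' prefix so the weights reload into a bare net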
        if (epoch+1) % 20 == 0:
            torch.save({
                'epoch' : epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(outf, 'checkpoint_{}.pth'.format(epoch)))

        end = time.time()
        print('epoch {} cost {:.2f} hours'.format(epoch, (end - start) / 3600))
        epoch += 1
    torch.save(model.module.state_dict(), os.path.join(outf, 'model.pth'))
Example No. 2
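# Example 2: trains the FC-DenseNet-style reference model (Model_Ref) to map
# RGB images to 31-band hyperspectral cubes, logging RMSE/SAM metrics to
# TensorBoard. Assumes `opt` is parsed from the command line and that the
# batch_* metric helpers and plot_spectrum are defined elsewhere in the repo.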
def main():
    ## network architecture
    rgb_features = 3
    pre_features = 64
    hyper_features = 31
    growth_rate = 16
    negative_slope = 0.2
    ## optimization
    beta1 = 0.9
    beta2 = 0.999
    ## load dataset
    print("\nloading dataset ...\n")
    trainDataset = HyperDataset(crop_size=64, mode='train')
    trainLoader = udata.DataLoader(trainDataset,
                                   batch_size=opt.batchSize,
                                   shuffle=True,
                                   num_workers=4)
    testDataset = HyperDataset(crop_size=1280, mode='test')
    ## build model
    print("\nbuilding models ...\n")
    net = Model_Ref(input_features=rgb_features,
                    pre_features=pre_features,
                    output_features=hyper_features,
                    db_growth_rate=growth_rate,
                    negative_slope=negative_slope,
                    p_drop=opt.pdrop)
    net.apply(weights_init_kaimingUniform)
    criterion = nn.MSELoss()
    # move to GPU
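    # hard-coded for an 8-GPU machine; trim the list to match your hardware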
    device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()
    # optimizers
    optimizer = optim.Adam(model.parameters(),
                           lr=opt.lr,
                           weight_decay=1e-6,
                           betas=(beta1, beta2))
    ## begin training
    step = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        # set learning rate
        if epoch < opt.milestone:
            current_lr = opt.lr
        else:
            current_lr = opt.lr / 10.
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))
        # run for one epoch
        for i, data in enumerate(trainLoader, 0):
            model.train()
            optimizer.zero_grad()
            real_hyper, real_rgb = data
            H = real_hyper.size(2)
            W = real_hyper.size(3)
            real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
            # train
            fake_hyper = model(real_rgb)
            loss = criterion(fake_hyper, real_hyper)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                # recompute metrics in eval mode, without gradients
                model.eval()
                with torch.no_grad():
                    fake_hyper = model(real_rgb)
                RMSE = batch_RMSE(real_hyper, fake_hyper)
                RMSE_G = batch_RMSE_G(real_hyper, fake_hyper)
                rRMSE = batch_rRMSE(real_hyper, fake_hyper)
                rRMSE_G = batch_rRMSE_G(real_hyper, fake_hyper)
                SAM = batch_SAM(real_hyper, fake_hyper)
                print(
                    "[epoch %d][%d/%d] RMSE: %.4f RMSE_G: %.4f rRMSE: %.4f rRMSE_G: %.4f SAM: %.4f"
                    % (epoch, i, len(trainLoader), RMSE.item(), RMSE_G.item(),
                       rRMSE.item(), rRMSE_G.item(), SAM.item()))
                # Log the scalar values
                writer.add_scalar('RMSE', RMSE.item(), step)
                writer.add_scalar('RMSE_G', RMSE_G.item(), step)
                writer.add_scalar('rRMSE', rRMSE.item(), step)
                writer.add_scalar('rRMSE_G', rRMSE_G.item(), step)
                writer.add_scalar('SAM', SAM.item(), step)
            if i == 0 or i == 1000:
                # validate on the full test set at the start of the epoch and
                # once mid-epoch
                model.eval()
                print("\ncomputing results on validation set ...\n")
                num = len(testDataset)
                average_RMSE = 0.
                average_RMSE_G = 0.
                average_rRMSE = 0.
                average_rRMSE_G = 0.
                average_SAM = 0.
                for k in range(num):
                    # data
                    real_hyper, real_rgb = testDataset[k]
                    real_hyper = torch.unsqueeze(real_hyper, 0)
                    real_rgb = torch.unsqueeze(real_rgb, 0)
                    real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
                    # forward
                    with torch.no_grad():
                        fake_hyper = model(real_rgb)
                    # metrics
                    RMSE = batch_RMSE(real_hyper, fake_hyper)
                    RMSE_G = batch_RMSE_G(real_hyper, fake_hyper)
                    rRMSE = batch_rRMSE(real_hyper, fake_hyper)
                    rRMSE_G = batch_rRMSE_G(real_hyper, fake_hyper)
                    SAM = batch_SAM(real_hyper, fake_hyper)
                    average_RMSE += RMSE.item()
                    average_RMSE_G += RMSE_G.item()
                    average_rRMSE += rRMSE.item()
                    average_rRMSE_G += rRMSE_G.item()
                    average_SAM += SAM.item()
                writer.add_scalar('RMSE_val', average_RMSE / num, step)
                writer.add_scalar('RMSE_G_val', average_RMSE_G / num, step)
                writer.add_scalar('rRMSE_val', average_rRMSE / num, step)
                writer.add_scalar('rRMSE_G_val', average_rRMSE_G / num, step)
                writer.add_scalar('SAM_val', average_SAM / num, step)
                print(
                    "[epoch %d][%d/%d] validation:\nRMSE: %.4f RMSE_G: %.4f rRMSE: %.4f rRMSE_G: %.4f SAM: %.4f\n"
                    % (epoch, i, len(trainLoader), average_RMSE / num,
                       average_RMSE_G / num, average_rRMSE / num,
                       average_rRMSE_G / num, average_SAM / num))
            step += 1
        ## the end of each epoch
        model.eval()
        # plot spectrum
        print("\nplotting spectrum ...\n")
        with torch.no_grad():
            fake_hyper = model(real_rgb)
        real_spectrum = real_hyper.data.cpu().numpy()[0, :, H // 2, W // 2]
        fake_spectrum = fake_hyper.data.cpu().numpy()[0, :, H // 2, W // 2]
        I_spectrum = plot_spectrum(real_spectrum, fake_spectrum)
        writer.add_image('spectrum', torch.Tensor(I_spectrum), epoch)
        # images
        print("\nadding images ...\n")
        I_rgb = utils.make_grid(real_rgb.data[0:16, :, :, :].clamp(0., 1.),
                                nrow=4,
                                normalize=True,
                                scale_each=True)
        I_real = utils.make_grid(real_hyper.data[0:16,
                                                 0:3, :, :].clamp(0., 1.),
                                 nrow=4,
                                 normalize=True,
                                 scale_each=True)
        I_fake = utils.make_grid(fake_hyper.data[0:16,
                                                 0:3, :, :].clamp(0., 1.),
                                 nrow=4,
                                 normalize=True,
                                 scale_each=True)
        writer.add_image('rgb', I_rgb, epoch)
        writer.add_image('real', I_real, epoch)
        writer.add_image('fake', I_fake, epoch)
        # save model
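        # state_dict() here keeps the DataParallel 'module.' prefix, so the
        # evaluation script below also wraps its net in DataParallel first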
        if epoch >= opt.epochs - 10:
            torch.save(model.state_dict(),
                       os.path.join(opt.outf, 'net_%03d.pth' % epoch))
Example No. 3
def main():
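    ## Evaluation script: loads a trained checkpoint for either 'Model' or the
    ## FC-DenseNet reference 'Ref', computes test-set metrics, and dumps the
    ## inputs and predictions to .mat files.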
    ## network architecture
    rgb_features = 3
    pre_features = 64
    hyper_features = 31
    growth_rate = 16
    negative_slope = 0.2
    ## load data
    print("loading dataset ...\n")
    testDataset = HyperDataset(crop_size=opt.size, mode='test')
    ## build model
    print("building models ...\n")
    if opt.model == 'Model':
        print("Our model, Dropout rate 0.%d\n" % opt.dropout)
        net = Model(input_features=rgb_features,
                    output_features=hyper_features,
                    negative_slope=negative_slope,
                    p_drop=0)
    elif opt.model == 'Ref':
        print("Reference model FC-DenseNet, Dropout rate 0.%d\n" % opt.dropout)
        net = Model_Ref(input_features=rgb_features,
                        pre_features=pre_features,
                        output_features=hyper_features,
                        db_growth_rate=growth_rate,
                        negative_slope=negative_slope,
                        p_drop=0)
    else:
        raise ValueError("Invalid model name: %s" % opt.model)
    # move to GPU
    device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    model.load_state_dict(torch.load(os.path.join(opt.logs, 'net_%s_%02d.pth'%(opt.model,opt.dropout) )))
    model.eval()
    ## testing
    num = len(testDataset)
    average_RMSE = 0.
    average_RMSE_G = 0.
    average_rRMSE = 0.
    average_rRMSE_G = 0.
    average_SAM  = 0.
    criterion = nn.MSELoss()
    criterion.cuda()
    # recording results
    im_rgb = dict()
    im_hyper = dict()
    im_hyper_fake = dict()
    for i in range(num):
        # data
        real_hyper, real_rgb = testDataset[i]
        real_hyper = torch.unsqueeze(real_hyper, 0)
        real_rgb = torch.unsqueeze(real_rgb, 0)
        H = real_hyper.size(2)
        W = real_hyper.size(3)
        real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
        # forward
        with torch.no_grad():
            fake_hyper = model(real_rgb)
        # metrics
        RMSE = batch_RMSE(real_hyper, fake_hyper)
        RMSE_G = batch_RMSE_G(real_hyper, fake_hyper)
        rRMSE = batch_rRMSE(real_hyper, fake_hyper)
        rRMSE_G = batch_rRMSE_G(real_hyper, fake_hyper)
        SAM = batch_SAM(real_hyper, fake_hyper)
        average_RMSE    += RMSE.item()
        average_RMSE_G  += RMSE_G.item()
        average_rRMSE   += rRMSE.item()
        average_rRMSE_G += rRMSE_G.item()
        average_SAM     += SAM.item()
        print("[%d/%d] RMSE: %.4f RMSE_G: %.4f rRMSE: %.4f rRMSE_G: %.4f SAM: %.4f"
            % (i+1, num, RMSE.item(), RMSE_G.item(), rRMSE.item(), rRMSE_G.item(), SAM.item()))
        # stash arrays for the .mat files written below
        print("recording images ...\n")
        I_rgb = real_rgb.data.cpu().numpy().squeeze()
        I_hyper = real_hyper.data.cpu().numpy().squeeze()
        I_hyper_fake = fake_hyper.data.cpu().numpy().squeeze()
        im_rgb['rgb_%d'%i] = I_rgb
        im_hyper['hyper_%d'%i] = I_hyper
        im_hyper_fake['hyper_fake_%d'%i] = I_hyper_fake

    print("\naverage RMSE: %.4f" % (average_RMSE/num))
    print("average RMSE_G: %.4f" % (average_RMSE_G/num))
    print("\naverage rRMSE: %.4f" % (average_rRMSE/num))
    print("average rRMSE_G: %.4f" % (average_rRMSE_G/num))
    print("\naverage SAM: %.4f" % (average_SAM/num))

    print("\nsaving matlab files ...\n")
    scio.savemat(os.path.join(opt.logs, 'im_rgb.mat'), im_rgb)
    scio.savemat(os.path.join(opt.logs, 'im_hyper.mat'), im_hyper)
    scio.savemat(os.path.join(opt.logs, 'im_hyper_fake.mat'), im_hyper_fake)
Example No. 4
def main():
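    ## Regress a linear 31-to-3 hyperspectral-to-RGB projection (an
    ## approximation of the CIE colour-matching matrix) and append the learned
    ## weights to regressed_CIE.csv after every epoch.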
    f = open('regressed_CIE.csv', 'w', newline='')
    csv_writer = csv.writer(f)  # distinct name; `writer` below is the TensorBoard SummaryWriter

    print("\nloading dataset ...\n")
    trainDataset = HyperDataset(crop_size=64, mode='train')
    trainLoader = udata.DataLoader(trainDataset,
                                   batch_size=opt.batchSize,
                                   shuffle=True,
                                   num_workers=4)
    testDataset = HyperDataset(crop_size=1024, mode='test')

    print("\nbuilding models ...\n")
    net = nn.Linear(in_features=31, out_features=3, bias=False)
    criterion = nn.MSELoss()

    device_ids = [0, 1]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()
    model.load_state_dict(torch.load('net.pth'))

    # optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)

    step = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        for i, data in enumerate(trainLoader, 0):
            model.train()
            optimizer.zero_grad()
            real_hyper, real_rgb = data
            real_hyper = real_hyper.permute((0, 2, 3, 1))
            real_rgb = real_rgb.permute((0, 2, 3, 1))
            H = real_hyper.size(1)
            W = real_hyper.size(2)
            real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
            # train
            fake_rgb = model(real_hyper)
            loss = criterion(fake_rgb, real_rgb)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                model.eval()
                with torch.no_grad():
                    fake_rgb = model(real_hyper)
                loss_train = criterion(fake_rgb, real_rgb).item()
                writer.add_scalar('Loss_train', loss_train, step)
                print("[epoch %d][%d/%d] Loss: %.4f" %
                      (epoch, i, len(trainLoader), loss_train))
            step += 1
        # validate
        num = len(testDataset)
        avg_loss = 0
        for k in range(num):
            # data
            real_hyper, real_rgb = testDataset[k]
            real_hyper = torch.unsqueeze(real_hyper, 0).permute((0, 2, 3, 1))
            real_rgb = torch.unsqueeze(real_rgb, 0).permute((0, 2, 3, 1))
            real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
            # forward
            with torch.no_grad():
                fake_rgb = model(real_hyper)
            avg_loss += criterion(fake_rgb, real_rgb).item()
        writer.add_scalar('Loss_val', avg_loss / num, epoch)
        print("[epoch %d] Validation Loss: %.4f" % (epoch, avg_loss / num))

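        # append the learned weight matrix (transposed to 31 rows x 3 columns)
        # to the CSV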
        for param in model.parameters():
            csv_writer.writerows(param.data.cpu().numpy().T)
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
    f.close()
Example No. 5
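# Example 5: inference with a trained FMNet. Walks the test set, reports PSNR
# and average runtime, and writes each predicted hypercube (plus the model's
# optional auxiliary weights, `weis_out`) to a .mat file. Assumes batch_PSNR
# and get_testfile_list are defined elsewhere in the repository.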
def main(model_path, need_weisout=False):

    if not os.path.exists(os.path.join(model_path, 'result')):
        os.mkdir(os.path.join(model_path, 'result'))
    if not os.path.exists(os.path.join(model_path, 'mask')):
        os.mkdir(os.path.join(model_path, 'mask'))

    print('load model path : {}'.format(model_path))
    model_name = 'model.pth'
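    # bNum and nblocks are parsed from the experiment directory name, which is
    # assumed to follow a scheme like <prefix>_<tag>_<bNum>_<nblocks>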
    name = model_path.split('/')[-1].split('_')
    bNum = int(name[2])
    nblocks = int(name[3])
    model = FMNet(bNum=bNum,
                  nblocks=nblocks,
                  input_features=31,
                  num_features=64,
                  out_features=31)
    print(bNum, nblocks)

    model.load_state_dict(
        torch.load(os.path.join(model_path, model_name), map_location='cpu'))
    print('model FMNet has been loaded!')

    model.eval()
    model.cuda()

    testDataset = HyperDataset(mode='test')
    num = len(testDataset)
    print('test img num : {}'.format(num))

    test_rgb_filename_list = get_testfile_list()
    print('test_rgb_filename_list len : {}'.format(
        len(test_rgb_filename_list)))

    psnr_sum = 0
    all_time = 0
    for i in range(num):

        file_name = test_rgb_filename_list[i].split('/')[-1]
        key = int(file_name.split('.')[0].split('_')[-1])
        print(file_name, key)

        real_hyper, _, real_rgb = testDataset.get_data_by_key(str(key))
        real_hyper, real_rgb = torch.unsqueeze(real_hyper, 0), torch.unsqueeze(
            real_rgb, 0)
        real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()

        # print('test img [{}/{}], input rgb shape : {}, hyper shape : {}'.format(i+1, num, real_rgb.shape, real_hyper.shape))
        # forward
        with torch.no_grad():
            start = time.time()
            fake_hyper = model(real_rgb)
            all_time += (time.time() - start)

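        # some FMNet configurations return (prediction, auxiliary weights)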
        if isinstance(fake_hyper, tuple):
            fake_hyper, weis_out = fake_hyper
            if need_weisout:
                weis_out = np.squeeze(weis_out[0, :, 0, :, :].cpu().numpy())
            else:
                weis_out = None
        else:
            weis_out = None

        psnr = batch_PSNR(real_hyper, fake_hyper).item()
        print('test img [{}/{}], fake hyper shape : {}, psnr : {}'.format(
            i + 1, num, fake_hyper.shape, psnr))
        psnr_sum += psnr
        fake_hyper_mat = fake_hyper[0, :, :, :].cpu().numpy()
        mat_path = os.path.join(model_path, 'result',
                                file_name.split('.')[0] + '.mat')
        if weis_out is None:
            scio.savemat(mat_path, {'rad': fake_hyper_mat})
            print('successfully saved fake hyper!')
        else:
            scio.savemat(mat_path, {'rad': fake_hyper_mat, 'weis_out': weis_out})
            print('successfully saved fake hyper and weis_out!')
        print()

    print('average psnr : {}'.format(psnr_sum / num))
    print('average test time : {}'.format(all_time / num))