def main():
    """Train FMNet with DataParallel, logging to TensorBoard and checkpointing.

    Hyper-parameters come from the module-level ``opt``; logs, checkpoints and
    the final weights are written into the module-level ``outf`` directory.

    Every 20 epochs ``checkpoint_<epoch>.pth`` is saved with the epoch index,
    the (unwrapped) model weights, the optimizer state and the LR-scheduler
    state.  When resuming (``load = True``) all of these are restored so the
    MultiStepLR decay continues from the resumed epoch rather than restarting
    its milestone count at 0.  After the last epoch the unwrapped model
    weights are saved to ``model.pth``.
    """
    gpu_num = len(opt.gpus.split(','))
    device_ids = list(range(gpu_num))

    print("loading dataset ...")
    train_dataset = HyperDataset(mode='train')
    test_dataset = HyperDataset(mode='test')
    # Global batch = per-GPU batch * number of GPUs; DataParallel splits it.
    batchSize = opt.batchSize_per_gpu * gpu_num
    train_loader = udata.DataLoader(train_dataset, batch_size=batchSize,
                                    shuffle=True, num_workers=0)
    print('train dataset num : {}'.format(len(train_dataset)))

    criterion = nn.L1Loss()
    epoch = 0
    net = FMNet(bNum=opt.bNum, nblocks=opt.nblocks, input_features=31,
                num_features=64, out_features=31)
    optimizer = optim.Adam(net.parameters(), lr=opt.lr, weight_decay=1e-6,
                           betas=(0.9, 0.999))
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.milestones,
                                                  opt.gamma)

    # Manual resume switch.  Note: training writes ``checkpoint_<epoch>.pth``;
    # rename or point ``checkpoint_file_name`` at the file to resume from.
    load = False
    if load:
        checkpoint_file_name = 'checkpoint.pth'
        checkpoint = torch.load(os.path.join(outf, checkpoint_file_name),
                                map_location=torch.device('cuda:0'))
        epoch = checkpoint['epoch'] + 1
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints carry no scheduler state.  The restored
            # optimizer already holds the decayed LR, so only the scheduler's
            # step counter must be realigned with the resumed epoch;
            # otherwise milestones are counted from 0 again and the LR is
            # decayed a second time.
            lr_scheduler.last_epoch = epoch
        print('Successfully load checkpoint {} ... '.format(
            os.path.join(outf, checkpoint_file_name)))

    model = nn.DataParallel(net, device_ids=device_ids,
                            output_device=device_ids[0])
    model.cuda()
    criterion.cuda()

    writer = SummaryWriter(outf)
    while epoch < opt.epochs:
        start = time.time()
        print("epoch {} learning rate {}".format(
            epoch, optimizer.param_groups[0]['lr']))
        train(model, criterion, optimizer, train_loader, epoch, writer)
        lr_scheduler.step()
        train_dataset.shuffle()
        test(model, test_dataset, epoch, writer)
        if (epoch + 1) % 20 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': lr_scheduler.state_dict(),
            }, os.path.join(outf, 'checkpoint_{}.pth'.format(epoch)))
        end = time.time()
        print('epoch {} cost {} hour '.format(
            epoch, str((end - start) / (60 * 60))))
        epoch += 1
    writer.close()
    torch.save(model.module.state_dict(), os.path.join(outf, 'model.pth'))
def main():
    """Train the reference model ``Model_Ref`` to map RGB crops to 31-band hyperspectral images.

    Hyper-parameters come from the module-level ``opt``.  Per epoch:

    * the LR is ``opt.lr`` before ``opt.milestone`` and ``opt.lr / 10`` after;
    * every 10 iterations batch metrics (RMSE, RMSE_G, rRMSE, rRMSE_G, SAM)
      are printed and logged to TensorBoard;
    * at iterations 0 and 1000 the whole test set is evaluated and averaged;
    * at the end of the epoch a centre-pixel spectrum plot and image grids of
      the last training batch are logged;
    * during the final 10 epochs the DataParallel state dict is saved as
      ``net_<epoch>.pth`` in ``opt.outf``.
    """
    ## network architecture
    rgb_features = 3
    pre_features = 64
    hyper_features = 31
    growth_rate = 16
    negative_slope = 0.2
    ## optimization
    beta1 = 0.9
    beta2 = 0.999

    ## load dataset
    print("\nloading dataset ...\n")
    trainDataset = HyperDataset(crop_size=64, mode='train')
    trainLoader = udata.DataLoader(trainDataset, batch_size=opt.batchSize,
                                   shuffle=True, num_workers=4)
    testDataset = HyperDataset(crop_size=1280, mode='test')

    ## build model
    print("\nbuilding models ...\n")
    net = Model_Ref(input_features=rgb_features, pre_features=pre_features,
                    output_features=hyper_features, db_growth_rate=growth_rate,
                    negative_slope=negative_slope, p_drop=opt.pdrop)
    net.apply(weights_init_kaimingUniform)
    criterion = nn.MSELoss()

    # move to GPU
    device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()

    # optimizers
    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=1e-6,
                           betas=(beta1, beta2))

    ## begin training
    step = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        # set learning rate: single step decay by 10x at opt.milestone
        current_lr = opt.lr if epoch < opt.milestone else opt.lr / 10.
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))

        # run for one epoch
        for i, data in enumerate(trainLoader, 0):
            model.train()
            # The optimizer owns exactly model.parameters(), so this clears
            # every gradient the backward pass below will populate.
            optimizer.zero_grad()
            real_hyper, real_rgb = data
            H = real_hyper.size(2)
            W = real_hyper.size(3)
            real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
            # train
            fake_hyper = model(real_rgb)
            loss = criterion(fake_hyper, real_hyper)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                # Re-run the batch in eval mode with the updated weights for logging.
                model.eval()
                with torch.no_grad():
                    fake_hyper = model(real_rgb)
                RMSE = batch_RMSE(real_hyper, fake_hyper)
                RMSE_G = batch_RMSE_G(real_hyper, fake_hyper)
                rRMSE = batch_rRMSE(real_hyper, fake_hyper)
                rRMSE_G = batch_rRMSE_G(real_hyper, fake_hyper)
                SAM = batch_SAM(real_hyper, fake_hyper)
                print(
                    "[epoch %d][%d/%d] RMSE: %.4f RMSE_G: %.4f rRMSE: %.4f rRMSE_G: %.4f SAM: %.4f"
                    % (epoch, i, len(trainLoader), RMSE.item(), RMSE_G.item(),
                       rRMSE.item(), rRMSE_G.item(), SAM.item()))
                # Log the scalar values
                writer.add_scalar('RMSE', RMSE.item(), step)
                writer.add_scalar('RMSE_G', RMSE_G.item(), step)
                writer.add_scalar('rRMSE', rRMSE.item(), step)
                writer.add_scalar('rRMSE_G', rRMSE_G.item(), step)
                writer.add_scalar('SAM', SAM.item(), step)

            if i in (0, 1000):
                # validate
                model.eval()
                print("\ncomputing results on validation set ...\n")
                num = len(testDataset)
                average_RMSE = 0.
                average_RMSE_G = 0.
                average_rRMSE = 0.
                average_rRMSE_G = 0.
                average_SAM = 0.
                for k in range(num):
                    # Use separate names: rebinding real_hyper/real_rgb here
                    # would replace the training batch (whose H, W are used)
                    # consumed by the end-of-epoch spectrum/image logging.
                    val_hyper, val_rgb = testDataset[k]
                    val_hyper = torch.unsqueeze(val_hyper, 0).cuda()
                    val_rgb = torch.unsqueeze(val_rgb, 0).cuda()
                    with torch.no_grad():
                        val_fake = model(val_rgb)
                    average_RMSE += batch_RMSE(val_hyper, val_fake).item()
                    average_RMSE_G += batch_RMSE_G(val_hyper, val_fake).item()
                    average_rRMSE += batch_rRMSE(val_hyper, val_fake).item()
                    average_rRMSE_G += batch_rRMSE_G(val_hyper, val_fake).item()
                    average_SAM += batch_SAM(val_hyper, val_fake).item()
                writer.add_scalar('RMSE_val', average_RMSE / num, step)
                writer.add_scalar('RMSE_G_val', average_RMSE_G / num, step)
                writer.add_scalar('rRMSE_val', average_rRMSE / num, step)
                writer.add_scalar('rRMSE_G_val', average_rRMSE_G / num, step)
                writer.add_scalar('SAM_val', average_SAM / num, step)
                print(
                    "[epoch %d][%d/%d] validation:\nRMSE: %.4f RMSE_G: %.4f rRMSE: %.4f rRMSE_G: %.4f SAM: %.4f\n"
                    % (epoch, i, len(trainLoader), average_RMSE / num,
                       average_RMSE_G / num, average_rRMSE / num,
                       average_rRMSE_G / num, average_SAM / num))
            step += 1

        ## the end of each epoch
        model.eval()
        # plot spectrum at the centre pixel of the last training batch
        print("\nplotting spectrum ...\n")
        with torch.no_grad():
            fake_hyper = model(real_rgb)
        real_spectrum = real_hyper.data.cpu().numpy()[0, :, H // 2, W // 2]
        fake_spectrum = fake_hyper.data.cpu().numpy()[0, :, H // 2, W // 2]
        I_spectrum = plot_spectrum(real_spectrum, fake_spectrum)
        writer.add_image('spectrum', torch.Tensor(I_spectrum), epoch)
        # images: up to 16 samples, first 3 bands of the hyperspectral cubes
        print("\nadding images ...\n")
        I_rgb = utils.make_grid(real_rgb.data[0:16, :, :, :].clamp(0., 1.),
                                nrow=4, normalize=True, scale_each=True)
        I_real = utils.make_grid(real_hyper.data[0:16, 0:3, :, :].clamp(0., 1.),
                                 nrow=4, normalize=True, scale_each=True)
        I_fake = utils.make_grid(fake_hyper.data[0:16, 0:3, :, :].clamp(0., 1.),
                                 nrow=4, normalize=True, scale_each=True)
        writer.add_image('rgb', I_rgb, epoch)
        writer.add_image('real', I_real, epoch)
        writer.add_image('fake', I_fake, epoch)
        # save model (state dict of the DataParallel wrapper) for the last 10 epochs
        if epoch >= opt.epochs - 10:
            torch.save(model.state_dict(),
                       os.path.join(opt.outf, 'net_%03d.pth' % epoch))
    writer.close()
def main():
    """Evaluate a trained model on the test set and export results as MATLAB files.

    ``opt.model`` selects the architecture (``'Model'`` or ``'Ref'``); the
    weights ``net_<model>_<dropout:02d>.pth`` are loaded from ``opt.logs``.
    Per-image and average RMSE, RMSE_G, rRMSE, rRMSE_G and SAM are printed,
    and the RGB inputs, ground-truth and reconstructed hyperspectral images
    are written to ``im_rgb.mat``, ``im_hyper.mat`` and ``im_hyper_fake.mat``
    in ``opt.logs``.

    Raises:
        ValueError: if ``opt.model`` is not a known model name.
    """
    ## network architecture
    rgb_features = 3
    pre_features = 64
    hyper_features = 31
    growth_rate = 16
    negative_slope = 0.2

    ## load data
    print("loading dataset ...\n")
    testDataset = HyperDataset(crop_size=opt.size, mode='test')

    ## build model (dropout disabled at test time)
    print("building models ...\n")
    if opt.model == 'Model':
        print("Our model, Dropout rate 0.%d\n" % opt.dropout)
        net = Model(
            input_features=rgb_features,
            output_features=hyper_features,
            negative_slope=negative_slope,
            p_drop=0,
        )
    elif opt.model == 'Ref':
        print("Reference model FC-DenseNet, Dropout rate 0.%d\n" % opt.dropout)
        net = Model_Ref(
            input_features=rgb_features,
            pre_features=pre_features,
            output_features=hyper_features,
            db_growth_rate=growth_rate,
            negative_slope=negative_slope,
            p_drop=0,
        )
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError("Invalid model name!", opt.model)

    # move to GPU; weights were saved from a DataParallel wrapper, so load
    # them into one as well.
    device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    model.load_state_dict(torch.load(
        os.path.join(opt.logs, 'net_%s_%02d.pth' % (opt.model, opt.dropout))))
    model.eval()

    ## testing
    num = len(testDataset)
    average_RMSE = 0.
    average_RMSE_G = 0.
    average_rRMSE = 0.
    average_rRMSE_G = 0.
    average_SAM = 0.
    # recording results, keyed per test image index
    im_rgb = {}
    im_hyper = {}
    im_hyper_fake = {}
    for i in range(num):
        # data: add a batch dimension of 1
        real_hyper, real_rgb = testDataset[i]
        real_hyper = torch.unsqueeze(real_hyper, 0).cuda()
        real_rgb = torch.unsqueeze(real_rgb, 0).cuda()
        # forward
        with torch.no_grad():
            fake_hyper = model(real_rgb)
        # metrics
        RMSE = batch_RMSE(real_hyper, fake_hyper)
        RMSE_G = batch_RMSE_G(real_hyper, fake_hyper)
        rRMSE = batch_rRMSE(real_hyper, fake_hyper)
        rRMSE_G = batch_rRMSE_G(real_hyper, fake_hyper)
        SAM = batch_SAM(real_hyper, fake_hyper)
        average_RMSE += RMSE.item()
        average_RMSE_G += RMSE_G.item()
        average_rRMSE += rRMSE.item()
        average_rRMSE_G += rRMSE_G.item()
        average_SAM += SAM.item()
        print("[%d/%d] RMSE: %.4f RMSE_G: %.4f rRMSE: %.4f rRMSE_G: %.4f SAM: %.4f"
              % (i + 1, num, RMSE.item(), RMSE_G.item(), rRMSE.item(),
                 rRMSE_G.item(), SAM.item()))
        # images (batch dimension squeezed away)
        print("adding images ...\n")
        im_rgb['rgb_%d' % i] = real_rgb.data.cpu().numpy().squeeze()
        im_hyper['hyper_%d' % i] = real_hyper.data.cpu().numpy().squeeze()
        im_hyper_fake['hyper_fake_%d' % i] = fake_hyper.data.cpu().numpy().squeeze()

    print("\naverage RMSE: %.4f" % (average_RMSE / num))
    print("average RMSE_G: %.4f" % (average_RMSE_G / num))
    print("\naverage rRMSE: %.4f" % (average_rRMSE / num))
    print("average rRMSE_G: %.4f" % (average_rRMSE_G / num))
    print("\naverage SAM: %.4f" % (average_SAM / num))

    print("\nsaving matlab files ...\n")
    scio.savemat(os.path.join(opt.logs, 'im_rgb.mat'), im_rgb)
    scio.savemat(os.path.join(opt.logs, 'im_hyper.mat'), im_hyper)
    scio.savemat(os.path.join(opt.logs, 'im_hyper_fake.mat'), im_hyper_fake)
def main():
    """Fit a bias-free linear map from 31 spectral bands to 3 RGB channels.

    Starts from the weights in ``./net.pth``, trains with SGD on the training
    crops, logs train/validation MSE to TensorBoard in ``opt.outf``, then
    exports the learned weight matrix to ``regressed_CIE.csv`` (transposed:
    one row per spectral band, one column per RGB channel) and saves the
    updated weights to ``<opt.outf>/net.pth``.
    """
    print("\nloading dataset ...\n")
    trainDataset = HyperDataset(crop_size=64, mode='train')
    trainLoader = udata.DataLoader(trainDataset, batch_size=opt.batchSize,
                                   shuffle=True, num_workers=4)
    testDataset = HyperDataset(crop_size=1024, mode='test')

    print("\nbuilding models ...\n")
    net = nn.Linear(in_features=31, out_features=3, bias=False)
    criterion = nn.MSELoss()
    device_ids = [0, 1]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()
    model.load_state_dict(torch.load('net.pth'))
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)

    step = 0
    # Distinct name from the CSV writer used at export time; previously both
    # were called ``writer`` and the TensorBoard one shadowed the CSV one.
    tb_writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        for i, data in enumerate(trainLoader, 0):
            model.train()
            optimizer.zero_grad()
            real_hyper, real_rgb = data
            # Move dim 1 (channels, presumably NCHW) to the last axis, which
            # is the axis nn.Linear maps from 31 to 3 features.
            real_hyper = real_hyper.permute((0, 2, 3, 1)).cuda()
            real_rgb = real_rgb.permute((0, 2, 3, 1)).cuda()
            # train
            fake_rgb = model(real_hyper)
            loss = criterion(fake_rgb, real_rgb)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                model.eval()
                with torch.no_grad():
                    fake_rgb = model(real_hyper)
                loss_train = criterion(fake_rgb, real_rgb).item()
                tb_writer.add_scalar('Loss_train', loss_train, step)
                print("[epoch %d][%d/%d] Loss: %.4f"
                      % (epoch, i, len(trainLoader), loss_train))
            step += 1

        # validate
        model.eval()
        num = len(testDataset)
        avg_loss = 0
        for k in range(num):
            real_hyper, real_rgb = testDataset[k]
            real_hyper = torch.unsqueeze(real_hyper, 0).permute((0, 2, 3, 1)).cuda()
            real_rgb = torch.unsqueeze(real_rgb, 0).permute((0, 2, 3, 1)).cuda()
            with torch.no_grad():
                fake_rgb = model(real_hyper)
            avg_loss += criterion(fake_rgb, real_rgb).item()
        # global_step is the epoch (was mistakenly the loss value itself).
        tb_writer.add_scalar('Loss_val', avg_loss / num, epoch)
        print("[epoch %d] Validation Loss: %.4f" % (epoch, avg_loss / num))
    tb_writer.close()

    # Export the weight matrix; nn.Linear stores it as (out=3, in=31), so the
    # transpose yields one CSV row per spectral band.
    with open('regressed_CIE.csv', 'w', newline='') as f:
        csv_writer = csv.writer(f)
        for param in model.parameters():
            csv_writer.writerows(param.data.cpu().numpy().T)
    torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
def main(model_path, need_weisout=False):
    """Run a trained FMNet over the test set, report PSNR/latency and save outputs.

    Args:
        model_path: run directory containing ``model.pth``.  Its base name is
            split on ``'_'``; fields 2 and 3 are parsed as ``bNum`` and
            ``nblocks``.  A trailing path separator is tolerated.
        need_weisout: when the model returns a ``(hyper, weis_out)`` tuple,
            store the slice ``weis_out[0, :, 0, :, :]`` (squeezed) in each
            ``.mat`` file.

    Side effects: creates ``<model_path>/result`` and ``<model_path>/mask``
    if missing and writes one ``.mat`` (key ``'rad'``, plus ``'weis_out'``
    when the model returns a tuple) per test image into ``result``.
    """
    result_dir = os.path.join(model_path, 'result')
    for sub_dir in (result_dir, os.path.join(model_path, 'mask')):
        if not os.path.exists(sub_dir):
            os.mkdir(sub_dir)
    print('load model path : {}'.format(model_path))
    model_name = 'model.pth'
    # normpath strips a trailing separator, which would otherwise make the
    # base name empty and the field lookups below raise IndexError.
    name = os.path.basename(os.path.normpath(model_path)).split('_')
    bNum = int(name[2])
    nblocks = int(name[3])
    model = FMNet(bNum=bNum, nblocks=nblocks, input_features=31,
                  num_features=64, out_features=31)
    print(bNum, nblocks)
    model.load_state_dict(
        torch.load(os.path.join(model_path, model_name), map_location='cpu'))
    print('model MNet has load !')
    model.eval()
    model.cuda()

    testDataset = HyperDataset(mode='test')
    num = len(testDataset)
    print('test img num : {}'.format(num))
    test_rgb_filename_list = get_testfile_list()
    print('test_rgb_filename_list len : {}'.format(len(test_rgb_filename_list)))

    psnr_sum = 0
    all_time = 0
    for i in range(num):
        file_name = test_rgb_filename_list[i].split('/')[-1]
        out_stem = file_name.split('.')[0]
        # The numeric suffix after the last '_' is the dataset lookup key.
        key = int(out_stem.split('_')[-1])
        print(file_name, key)
        real_hyper, _, real_rgb = testDataset.get_data_by_key(str(key))
        real_hyper = torch.unsqueeze(real_hyper, 0).cuda()
        real_rgb = torch.unsqueeze(real_rgb, 0).cuda()

        # forward.  CUDA kernels launch asynchronously, so synchronize on both
        # sides of the timed region; otherwise only launch overhead is measured.
        with torch.no_grad():
            torch.cuda.synchronize()
            start = time.perf_counter()
            fake_hyper = model(real_rgb)
            torch.cuda.synchronize()
            all_time += time.perf_counter() - start

        if isinstance(fake_hyper, tuple):
            fake_hyper, weis_out = fake_hyper
            if need_weisout:
                weis_out = np.squeeze(weis_out[0, :, 0, :, :].cpu().numpy())
            else:
                # NOTE(review): False (not None) means 'weis_out': False is
                # still written to the .mat file; confirm this is intended.
                weis_out = False
        else:
            weis_out = None

        psnr = batch_PSNR(real_hyper, fake_hyper).item()
        print('test img [{}/{}], fake hyper shape : {}, psnr : {}'.format(
            i + 1, num, fake_hyper.shape, psnr))
        psnr_sum += psnr

        mat_contents = {'rad': fake_hyper[0, :, :, :].cpu().numpy()}
        if weis_out is not None:
            mat_contents['weis_out'] = weis_out
        scio.savemat(os.path.join(result_dir, out_stem + '.mat'), mat_contents)
        if weis_out is None:
            print('sucessfully save fake hyper !!!')
        else:
            print('sucessfully save fake hyper and weis_out !!!')
        print()

    print('average psnr : {}'.format(psnr_sum / num))
    print('average test time : {}'.format(all_time / num))