예제 #1
0
파일: test.py 프로젝트: alexandru-dinu/cae
def test(cfg: Namespace) -> None:
    """Reconstruct every dataset image with a trained CAE and save comparisons.

    For each image the model is run on a 6 x 10 grid of 3 x 128 x 128
    patches, the reconstructed patches are stitched back into one
    3 x 768 x 1280 image, and that image is saved next to the original under
    ``<exp_dir>/out/test_<batch_idx>.png``.

    Args:
        cfg: Run configuration. The fields read here are ``checkpoint``,
            ``device`` ("cpu" or "cuda"), ``exp_name``, ``dataset_path`` and
            ``shuffle``; ``cfg.to_file`` is called to dump the configuration.

    Raises:
        AssertionError: If no checkpoint is given, or the device is neither
            "cpu" nor an available "cuda".
    """
    assert cfg.checkpoint not in [None, ""]
    assert cfg.device == "cpu" or (cfg.device == "cuda"
                                   and T.cuda.is_available())

    exp_dir = ROOT_EXP_DIR / cfg.exp_name
    os.makedirs(exp_dir / "out", exist_ok=True)
    cfg.to_file(exp_dir / "test_config.json")
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only host instead of failing while deserializing CUDA tensors.
    model.load_state_dict(T.load(cfg.checkpoint, map_location=cfg.device))
    model.eval()
    if cfg.device == "cuda":
        model.cuda()
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(dataset=ImageFolder720p(cfg.dataset_path),
                            batch_size=1,
                            shuffle=cfg.shuffle)
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    for batch_idx, data in enumerate(dataloader, start=1):
        img, patches, _ = data
        if cfg.device == "cuda":
            patches = patches.cuda()

        out = T.zeros(6, 10, 3, 128, 128)
        avg_loss = 0

        # Pure inference: no autograd graph is needed for any of the patches.
        with T.no_grad():
            for i in range(6):
                for j in range(10):
                    # `patches` already lives on the configured device, so
                    # the slice must not be forced onto the GPU again.
                    x = patches[:, :, i, j, :, :]
                    y = model(x)
                    out[i, j] = y.data

                    loss = loss_criterion(y, x)
                    # 60 patches per image -> running mean of the patch losses
                    avg_loss += (1 / 60) * loss.item()

        logger.debug("[%5d/%5d] avg_loss: %f", batch_idx, len(dataloader),
                     avg_loss)

        # save output
        # (6, 10, 3, 128, 128) -> (6, 128, 10, 128, 3) -> (768, 1280, 3)
        # -> (3, 768, 1280): stitch the patch grid back into a single image.
        out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)

        # Original and reconstruction side by side along the width axis.
        y = T.cat((img[0], out), dim=2)
        save_imgs(
            imgs=y.unsqueeze(0),
            to_size=(3, 768, 2 * 1280),
            name=exp_dir / f"out/test_{batch_idx}.png",
        )
예제 #2
0
def test(cfg: Namespace) -> None:
    """Debug variant of the CAE test loop that prints tensor shapes.

    Loads the checkpoint, iterates over the dataset one image at a time and
    runs the model on each patch of the 6 x 10 grid, printing the shapes of
    the patch tensor, of each input patch and of each model output. The
    visible body ends after the forward pass; nothing is saved here.

    Args:
        cfg: Run configuration. The fields read here are ``checkpoint``,
            ``device`` ("cpu" or "cuda"), ``exp_name``, ``dataset_path`` and
            ``shuffle``; ``cfg.to_file`` is called to dump the configuration.

    Raises:
        AssertionError: If no checkpoint is given, or the device is neither
            "cpu" nor an available "cuda".
    """
    assert cfg.checkpoint not in [None, ""]
    assert cfg.device == "cpu" or (cfg.device == "cuda"
                                   and T.cuda.is_available())

    exp_dir = ROOT_EXP_DIR / cfg.exp_name
    os.makedirs(exp_dir / "out", exist_ok=True)
    cfg.to_file(exp_dir / "test_config.json")
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only host instead of failing while deserializing CUDA tensors.
    model.load_state_dict(T.load(cfg.checkpoint, map_location=cfg.device))
    model.eval()
    if cfg.device == "cuda":
        model.cuda()
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(dataset=ImageFolder720p(cfg.dataset_path),
                            batch_size=1,
                            shuffle=cfg.shuffle)
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    # NOTE(review): batch numbering starts at 2 here; looks like a leftover
    # debugging tweak -- confirm before relying on batch_idx.
    for batch_idx, data in enumerate(dataloader, start=2):
        img, patches, _ = data
        print('the patches shape is:', patches.shape)
        if cfg.device == "cuda":
            patches = patches.cuda()

        out = T.zeros(6, 10, 3, 128, 128)
        avg_loss = 0
        foo = []
        for i in range(6):
            for j in range(10):
                # `patches` already lives on the configured device, so the
                # slice must not be forced onto the GPU again.
                x = patches[:, :, i, j, :, :]
                print('the x shape is:', x.shape)
                y = model(x)
                print('hellyy', y.shape)
예제 #3
0
파일: test.py 프로젝트: Iamanorange/cae
def test(cfg):
	"""Reconstruct every dataset image with a trained CAE and save comparisons.

	The model always runs on the GPU. For each image it is applied to a
	6 x 10 grid of 3 x 128 x 128 patches; the reconstructed patches are
	stitched into one 3 x 768 x 1280 image which is saved next to the
	original as ``./test/<exp_name>/test_<bi>.png``.

	Args:
		cfg: Mapping with the keys ``exp_name``, ``chkpt``, ``dataset_path``
			and ``shuffle``.
	"""
	os.makedirs(f"./test/{cfg['exp_name']}", exist_ok=True)

	model = CAE().cuda()

	model.load_state_dict(torch.load(cfg['chkpt']))
	model.eval()
	# The path used to be passed as a stray positional argument with no
	# placeholder in the message, so it never appeared in the log.
	logger.info(f"Loaded model from {cfg['chkpt']}")

	dataset = ImageFolder720p(cfg['dataset_path'])
	dataloader = DataLoader(dataset, batch_size=1, shuffle=cfg['shuffle'])
	logger.info(f"Done setup dataloader: {len(dataloader)}")

	mse_loss = nn.MSELoss()

	for bi, (img, patches, path) in enumerate(dataloader):

		out = torch.zeros(6, 10, 3, 128, 128)
		avg_loss = 0

		# Pure inference: no autograd graph is needed for any of the patches.
		with torch.no_grad():
			for i in range(6):
				for j in range(10):
					# Plain tensors replace the deprecated Variable wrapper.
					x = patches[:, :, i, j, :, :].cuda()
					y = model(x)
					out[i, j] = y.data

					loss = mse_loss(y, x)
					# 60 patches per image -> running mean of the patch losses
					avg_loss += (1 / 60) * loss.item()

		logger.debug('[%5d/%5d] avg_loss: %f' % (bi, len(dataloader), avg_loss))

		# save output
		# (6, 10, 3, 128, 128) -> (6, 128, 10, 128, 3) -> (768, 1280, 3)
		# -> (3, 768, 1280): stitch the patch grid back into a single image.
		out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)

		# Original and reconstruction side by side along the width axis.
		y = torch.cat((img[0], out), dim=2)
		save_imgs(imgs=y.unsqueeze(0), to_size=(3, 768, 2 * 1280), name=f"./test/{cfg['exp_name']}/test_{bi}.png")
예제 #4
0
def imgEncoding(name,patches,checkpoint,interFolder):
    """Encode the patch grid of each image with a trained CAE and save it.

    For every entry of ``patches`` the model is run on each cell of a 6 x 10
    patch grid; the 32 x 32 x 32 outputs are collected into one
    (6, 10, 32, 32, 32) tensor that is written to
    ``<interFolder>/<name[idx]>.pt``.

    Args:
        name: Indexable collection of identifiers, aligned with ``patches``;
            each one becomes the stem of the saved file.
        patches: Iterable of per-image patch tensors, indexed here as
            ``[:, i, j, :, :]`` with ``i`` in 0..5 and ``j`` in 0..9.
        checkpoint: Path of the saved model state dict.
        interFolder: Directory receiving the ``.pt`` files.

    Returns:
        The loaded model, in eval mode, on the selected device.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = CAE()
    # map_location keeps the CPU fallback usable for checkpoints that were
    # saved from a GPU.
    encoder.load_state_dict(torch.load(checkpoint, map_location=device))
    encoder.eval()

    encoder.to(device)

    for idx, patcher in enumerate(patches):
        logging.info(f'Doing for {name[idx]}')
        out = torch.zeros(6, 10, 32, 32, 32)
        # Pure inference: skip building autograd graphs for the 60 passes.
        with torch.no_grad():
            for i in range(6):
                for j in range(10):
                    # [None, ...] adds the leading batch dimension.
                    x = patcher[None, :, i, j, :, :].to(device)
                    y = encoder(x)
                    out[i, j] = y.data
        torch.save(out, os.path.join(interFolder, str(name[idx]) + '.pt'))

    logging.info('Images are encoded')

    return encoder
예제 #5
0
# Checkpoint to restore, and the directory that is created for outputs.
path = '../experiments/training/checkpoint/model_50.pth'
exp_dir = 'output/'
os.makedirs(exp_dir, exist_ok=True)

# Batches of 16 in dataset order, loaded with 4 workers.
dataset = custom_single()
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=False, num_workers=4)

# Build the network, restore its weights, switch to eval mode, move to GPU.
model = CAE()
state_dict = torch.load(path)
model.load_state_dict(state_dict)
model.eval()
model = model.cuda()

# for batch_idx, data in enumerate(dataloader, start=1):
#     patches, _ = data
#     patches = patches.float().cuda()
#     break 

# out = T.zeros(33, 32, 3, 128, 128)
# all_patches = dataset.patches

# all_patches = all_patches.reshape(33,32,3,128,128)
# for i in range(33):