コード例 #1
0
ファイル: test.py プロジェクト: alexandru-dinu/cae
def test(cfg: Namespace) -> None:
    """Reconstruct every test image with a trained CAE and save the result.

    For each image yielded by the dataloader, each of its 6 x 10 patches is
    passed through the model. The mean-squared reconstruction error is
    averaged over the 60 patches and logged, and the reassembled
    reconstruction is written side by side with the input image to
    ``<exp_dir>/out/test_<batch_idx>.png``.

    Args:
        cfg: Experiment configuration. The attributes read here are
            ``checkpoint``, ``device`` (``"cpu"`` or ``"cuda"``),
            ``exp_name``, ``dataset_path`` and ``shuffle``. ``cfg.to_file``
            is called to dump the configuration into the experiment
            directory.

    Raises:
        AssertionError: If no checkpoint is given, or if the requested
            device is neither ``"cpu"`` nor an available ``"cuda"``.
    """
    assert cfg.checkpoint not in [None, ""], "cfg.checkpoint must be set"
    assert cfg.device == "cpu" or (
        cfg.device == "cuda" and T.cuda.is_available()
    ), f"unsupported or unavailable device: {cfg.device}"

    exp_dir = ROOT_EXP_DIR / cfg.exp_name
    os.makedirs(exp_dir / "out", exist_ok=True)
    cfg.to_file(exp_dir / "test_config.json")
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    # Deserialize onto the CPU first so that a checkpoint written on a GPU
    # can still be restored on a CPU-only host; the model is moved afterwards.
    model.load_state_dict(T.load(cfg.checkpoint, map_location="cpu"))
    model.eval()
    model.to(cfg.device)
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(dataset=ImageFolder720p(cfg.dataset_path),
                            batch_size=1,
                            shuffle=cfg.shuffle)
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    # Pure inference: no autograd graph is needed for any forward pass below.
    with T.no_grad():
        for batch_idx, data in enumerate(dataloader, start=1):
            img, patches, _ = data
            # Honour the configured device instead of forcing CUDA per patch.
            patches = patches.to(cfg.device)

            # Reconstructions are collected on the CPU, one slot per patch.
            out = T.zeros(6, 10, 3, 128, 128)
            avg_loss = 0

            for i in range(6):
                for j in range(10):
                    x = patches[:, :, i, j, :, :]
                    y = model(x)
                    out[i, j] = y

                    loss = loss_criterion(y, x)
                    avg_loss += (1 / 60) * loss.item()

            logger.debug("[%5d/%5d] avg_loss: %f", batch_idx, len(dataloader),
                         avg_loss)

            # save output: stitch the 6 x 10 grid of 128 x 128 patches back
            # into one 768 x 1280 image and restore channel-first layout.
            out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3)
            out = out.permute(2, 0, 1)

            # Input on the left, reconstruction on the right.
            y = T.cat((img[0], out), dim=2)
            save_imgs(
                imgs=y.unsqueeze(0),
                to_size=(3, 768, 2 * 1280),
                name=exp_dir / f"out/test_{batch_idx}.png",
            )
コード例 #2
0
def test(cfg: Namespace) -> None:
    """Debugging variant of the CAE test loop that prints tensor shapes.

    Sets up the experiment directory, the model and the dataloader, then
    feeds each of the 6 x 10 patches of every image through the model while
    printing the shape of the patch tensor, of each model input and of each
    model output.

    NOTE(review): the visible snippet stops inside the innermost patch loop.
    ``out``, ``avg_loss`` and ``foo`` are initialised but not consumed in the
    part shown here -- presumably the remainder of the function uses them;
    confirm against the full file.

    Args:
        cfg: Experiment configuration. The attributes read here are
            ``checkpoint``, ``device`` (``"cpu"`` or ``"cuda"``),
            ``exp_name``, ``dataset_path``, ``shuffle`` and ``batch_every``.
            ``cfg.to_file`` is called to dump the configuration.

    Raises:
        AssertionError: If no checkpoint is given, or if the requested
            device is neither ``"cpu"`` nor an available ``"cuda"``.
    """
    assert cfg.checkpoint not in [None, ""]
    assert cfg.device == "cpu" or (cfg.device == "cuda"
                                   and T.cuda.is_available())

    exp_dir = ROOT_EXP_DIR / cfg.exp_name
    os.makedirs(exp_dir / "out", exist_ok=True)
    cfg.to_file(exp_dir / "test_config.json")
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    # Deserialize onto the CPU first so that a checkpoint written on a GPU
    # can still be restored on a CPU-only host; the model is moved afterwards.
    model.load_state_dict(T.load(cfg.checkpoint, map_location="cpu"))
    model.eval()
    model.to(cfg.device)
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(dataset=ImageFolder720p(cfg.dataset_path),
                            batch_size=1,
                            shuffle=cfg.shuffle)
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    # NOTE(review): counting starts at 2 here, unlike the usual start=1 --
    # confirm whether that offset is intentional.
    for batch_idx, data in enumerate(dataloader, start=2):
        img, patches, _ = data
        print('the patches shape is:', patches.shape)
        # Honour the configured device instead of forcing CUDA per patch.
        patches = patches.to(cfg.device)

        # Placeholder for periodic work every `cfg.batch_every` batches.
        if batch_idx % cfg.batch_every == 0:
            pass

        out = T.zeros(6, 10, 3, 128, 128)
        avg_loss = 0
        foo = []
        for i in range(6):
            for j in range(10):
                x = patches[:, :, i, j, :, :]
                print('the x shape is:', x.shape)
                y = model(x)
                print('hellyy', y.shape)
コード例 #3
0
ファイル: test.py プロジェクト: Iamanorange/cae
def test(cfg):
	"""Reconstruct every test image with a trained CAE on the GPU and save it.

	Each of the 6 x 10 patches of an image is passed through the model, the
	mean-squared reconstruction error is averaged over the 60 patches and
	logged, and the stitched reconstruction is saved next to the input image
	as ``./test/<exp_name>/test_<bi>.png``.

	Args:
		cfg: Mapping with the keys ``exp_name``, ``chkpt``, ``dataset_path``
			and ``shuffle``.
	"""
	os.makedirs(f"./test/{cfg['exp_name']}", exist_ok=True)

	model = CAE().cuda()

	model.load_state_dict(torch.load(cfg['chkpt']))
	model.eval()
	# The path is interpolated into the message itself; passing it as a
	# separate positional argument without a placeholder never rendered it.
	logger.info(f"Loaded model from {cfg['chkpt']}")

	dataset = ImageFolder720p(cfg['dataset_path'])
	dataloader = DataLoader(dataset, batch_size=1, shuffle=cfg['shuffle'])
	logger.info(f"Done setup dataloader: {len(dataloader)}")

	mse_loss = nn.MSELoss()

	# Pure inference: no autograd graph is needed for any forward pass below.
	with torch.no_grad():
		for bi, (img, patches, _) in enumerate(dataloader):

			# Reconstructions are collected on the CPU, one slot per patch.
			out = torch.zeros(6, 10, 3, 128, 128)
			avg_loss = 0

			for i in range(6):
				for j in range(10):
					x = patches[:, :, i, j, :, :].cuda()
					y = model(x)
					out[i, j] = y

					loss = mse_loss(y, x)
					avg_loss += (1 / 60) * loss.item()

			logger.debug('[%5d/%5d] avg_loss: %f' % (bi, len(dataloader), avg_loss))

			# save output: stitch the 6 x 10 grid of 128 x 128 patches back
			# into one 768 x 1280 image and restore channel-first layout.
			out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3)
			out = out.permute(2, 0, 1)

			# Input on the left, reconstruction on the right.
			y = torch.cat((img[0], out), dim=2)
			save_imgs(imgs=y.unsqueeze(0), to_size=(3, 768, 2 * 1280), name=f"./test/{cfg['exp_name']}/test_{bi}.png")
コード例 #4
0
def imgEncoding(name,patches,checkpoint,interFolder):
    """Run every patch of every image through a trained CAE and save the outputs.

    For each entry of ``patches`` the 6 x 10 grid of patches is fed through
    the model one patch at a time; the outputs are gathered in a
    ``(6, 10, 32, 32, 32)`` tensor that is written to
    ``<interFolder>/<name[idx]>.pt``.

    Args:
        name: Indexable collection of identifiers, one per entry of
            ``patches``; used in the log message and as the file stem.
        patches: Iterable of per-image patch tensors, indexed here as
            ``patcher[:, i, j, :, :]`` for ``i`` in 0..5 and ``j`` in 0..9.
        checkpoint: Location of the saved model state dict, handed to
            ``torch.load``.
        interFolder: Directory that receives the ``.pt`` files.

    Returns:
        The loaded model, in eval mode, on CUDA when available and on the
        CPU otherwise.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encoder = CAE()
    # map_location keeps the CPU fallback usable for checkpoints that were
    # written on a GPU machine.
    encoder.load_state_dict(torch.load(checkpoint, map_location=device))
    encoder.eval()

    encoder.to(device)

    # Pure inference: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for idx, patcher in enumerate(patches):
            logging.info(f'Doing for {name[idx]}')
            out = torch.zeros(6, 10, 32, 32, 32)
            for i in range(6):
                for j in range(10):
                    # `None` adds the leading batch dimension of size one.
                    x = patcher[None, :, i, j, :, :].to(device)
                    out[i, j] = encoder(x)
            torch.save(out, os.path.join(interFolder, str(name[idx]) + '.pt'))

    logging.info('Images are encoded')

    return encoder
コード例 #5
0
def train(cfg):
    """Train a CAE patch by patch, periodically saving previews and checkpoints.

    Every batch is split into its 6 x 10 grid of patches and one optimizer
    step (Adam on the MSE reconstruction loss) is taken per patch position.
    The per-batch loss averaged over the 60 patches and its cumulative mean
    over the current epoch are logged. Every ``out_every`` batches a
    reconstruction of the first image of the batch is saved to
    ``out/<exp_name>/``; checkpoints go to ``checkpoints/<exp_name>/``.

    Args:
        cfg: Mapping with the keys ``exp_name``, ``load``, ``chkpt``,
            ``dataset_path``, ``batch_size``, ``shuffle``, ``num_workers``,
            ``learning_rate``, ``resume_epoch``, ``num_epochs``,
            ``out_every`` and ``save_every``.
    """
    os.makedirs(f"out/{cfg['exp_name']}", exist_ok=True)
    os.makedirs(f"checkpoints/{cfg['exp_name']}", exist_ok=True)

    # dump config for current experiment
    with open(f"checkpoints/{cfg['exp_name']}/setup.cfg", "wt") as f:
        for k, v in cfg.items():
            f.write("%15s: %s\n" % (k, v))

    model = CAE().cuda()

    if cfg['load']:
        model.load_state_dict(torch.load(cfg['chkpt']))
        # The path is interpolated into the message itself; passing it as a
        # separate positional argument without a placeholder never rendered it.
        logger.info(f"Loaded model from {cfg['chkpt']}")

    model.train()
    logger.info("Done setup model")

    dataset = ImageFolder720p(cfg['dataset_path'])
    dataloader = DataLoader(dataset,
                            batch_size=cfg['batch_size'],
                            shuffle=cfg['shuffle'],
                            num_workers=cfg['num_workers'])
    logger.info(
        f"Done setup dataloader: {len(dataloader)} batches of size {cfg['batch_size']}"
    )

    mse_loss = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg['learning_rate'],
                                 weight_decay=1e-5)

    # Cumulative mean of `avg_loss` over the batches of the current epoch.
    ra = 0

    for ei in range(cfg['resume_epoch'], cfg['num_epochs']):
        for bi, (img, patches, _) in enumerate(dataloader):

            avg_loss = 0
            for i in range(6):
                for j in range(10):
                    x = patches[:, :, i, j, :, :].cuda()
                    y = model(x)
                    loss = mse_loss(y, x)

                    avg_loss += (1 / 60) * loss.item()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # Restart the mean on the first batch, otherwise fold in the
            # newest batch loss.
            if bi == 0:
                ra = avg_loss
            else:
                ra = ra * bi / (bi + 1) + avg_loss / (bi + 1)

            logger.debug('[%3d/%3d][%5d/%5d] avg_loss: %f, ra: %f' %
                         (ei + 1, cfg['num_epochs'], bi + 1, len(dataloader),
                          avg_loss, ra))

            # save img
            if bi % cfg['out_every'] == 0:
                out = torch.zeros(6, 10, 3, 128, 128)
                # The preview is not trained on, so skip gradient tracking.
                with torch.no_grad():
                    for i in range(6):
                        for j in range(10):
                            x = patches[0, :, i, j, :, :].unsqueeze(0).cuda()
                            out[i, j] = model(x).cpu()

                # Stitch the 6 x 10 grid of 128 x 128 patches back into one
                # 768 x 1280 image and restore channel-first layout.
                out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3)
                out = out.permute(2, 0, 1)

                # Input on the left, reconstruction on the right.
                y = torch.cat((img[0], out), dim=2).unsqueeze(0)
                save_imgs(imgs=y,
                          to_size=(3, 768, 2 * 1280),
                          name=f"out/{cfg['exp_name']}/out_{ei}_{bi}.png")

            # save model after every `save_every`-th batch of the epoch
            if bi % cfg['save_every'] == cfg['save_every'] - 1:
                torch.save(
                    model.state_dict(),
                    f"checkpoints/{cfg['exp_name']}/model_{ei}_{bi}.state")

    # save final model
    torch.save(model.state_dict(),
               f"checkpoints/{cfg['exp_name']}/model_final.state")
コード例 #6
0
import os
import yaml
import argparse
from pathlib import Path

import numpy as np
import torch as T
import torch.nn as nn
from torch.utils.data import DataLoader

from data_loader import ImageFolder720p
from utils import save_imgs
import matplotlib.pylab as plt

# from bagoftools.namespace import Namespace
# from bagoftools.logger import Logger

from models.cae_32x32x32_zero_pad_bin import CAE

# Decode a previously saved latent tensor with a trained CAE and display the
# first decoded image.

# Build the checkpoint path from components: the original literal mixed "/"
# and "\" separators, which only resolves on Windows.
checkpoint_path = Path("..") / "checkpoint" / "model_yt_small_final.state"

model = CAE()
# Load onto the CPU so the weights live on the same device as `encoded`,
# regardless of where the checkpoint was written.
model.load_state_dict(T.load(checkpoint_path, map_location=T.device('cpu')))
# Inference only: switch off any training-time layer behaviour.
model.eval()

encoded = T.load('something.pt', map_location=T.device('cpu'))
print(encoded.shape)

with T.no_grad():
    out = model.decode(encoded)
print(out.shape)

# imshow wants channels last, hence the permute from (C, H, W) to (H, W, C).
plt.imshow(out[0].detach().permute(1,2,0))
plt.show()
コード例 #7
0
ファイル: infer.py プロジェクト: abhyantrika/tile_decoder
# Inference setup: output directory, batched loader and trained model on GPU.
os.makedirs(exp_dir, exist_ok=True)

dataset = custom_single()
dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)

# Instantiate the network in eval mode, restore its weights, then move it
# to the GPU.
model = CAE()
model.eval()

state_dict = torch.load(path)
model.load_state_dict(state_dict)
model = model.cuda()

# for batch_idx, data in enumerate(dataloader, start=1):
#     patches, _ = data
#     patches = patches.float().cuda()
#     break 

# out = T.zeros(33, 32, 3, 128, 128)
# all_patches = dataset.patches

# all_patches = all_patches.reshape(33,32,3,128,128)
# for i in range(33):
#     for j in range(32):
#         x = all_patches[i,j,...].unsqueeze(0).cuda().float()
#         out[i, j] = model(x.float()).cpu().data