Ejemplo n.º 1
0
def test_output():
    """Colorize every test image and save each result as a JPEG.

    Loads the trained ``ColorNet`` weights from ``colornet_params.pkl`` and
    runs the model over the images under ``DATA_DIR`` in loader order (no
    shuffling). Each colorized image is written to ``OUTPUT_DIR/<k>.jpg``,
    where ``k`` counts up from 0 across the whole run.

    Relies on module-level names defined elsewhere in the file: ``USE_CUDA``,
    ``ColorNet``, ``ValImageFolder``, ``transforms``, ``Variable``,
    ``lab2rgb``, ``plt``.
    """
    import os  # local import keeps this block self-contained

    BATCH_SIZE = 1
    DATA_DIR = "../CSC2515_data/cifar/test/"
    OUTPUT_DIR = "../CSC2515_output/colorimg/"
    # transforms.Scale is the legacy torchvision name for Resize. Scale(32)
    # resizes the shorter side to 32, so RandomCrop(32) only picks a random
    # window when an input image is not square.
    scale_transform = transforms.Compose([
        transforms.Scale(32),
        transforms.RandomCrop(32),
        transforms.ToTensor()
    ])

    test_set = ValImageFolder(root=DATA_DIR, transform=scale_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=1)

    color_model = ColorNet()
    # Map the checkpoint onto the CPU when CUDA is not used, so weights saved
    # from GPU tensors still load.
    color_model.load_state_dict(
        torch.load('colornet_params.pkl',
                   map_location=None if USE_CUDA else 'cpu'))
    if USE_CUDA:
        color_model.cuda()
    color_model.eval()

    # plt.imsave does not create missing directories.
    if not os.path.isdir(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)

    i_color = 0
    for data, _ in test_loader:
        # data[0] is the grayscale input and data[1] the scaled input for the
        # model's second argument; unsqueeze(1) adds the channel dimension.
        gray_img = data[0].unsqueeze(1).float()
        scale_img = data[1].unsqueeze(1).float()
        # Tensors are (N, C, H, W): dim 2 is height, dim 3 is width.
        height, width = gray_img.size()[2], gray_img.size()[3]
        if USE_CUDA:
            gray_img, scale_img = gray_img.cuda(), scale_img.cuda()

        # volatile=True requests inference mode on the legacy Variable API;
        # PyTorch >= 0.4 ignores it in favour of torch.no_grad().
        gray_img, scale_img = Variable(gray_img,
                                       volatile=True), Variable(scale_img)
        _, output = color_model(gray_img, scale_img)
        # Crop the prediction to the input size and stack it after the
        # grayscale channel, then move channels last for image conversion.
        color_img = torch.cat((gray_img, output[:, :, 0:height, 0:width]), 1)
        color_img = color_img.data.cpu().numpy().transpose((0, 2, 3, 1))
        for img in color_img:
            # Rescale to Lab ranges: L * 100, a/b * 255 - 128 (assumes the
            # network works with values in [0, 1] -- confirm).
            img[:, :, 0:1] = img[:, :, 0:1] * 100
            img[:, :, 1:3] = img[:, :, 1:3] * 255 - 128
            rgb = lab2rgb(img.astype(np.float64))
            plt.imsave(os.path.join(OUTPUT_DIR, str(i_color) + '.jpg'), rgb)
            i_color += 1
Ejemplo n.º 2
0
def test_trainset():
    """Display one random batch as rows of gray, colorized and original images.

    Draws a single shuffled batch from ``DATA_DIR`` and colorizes it with the
    trained ``ColorNet``. It then shows a 3-row matplotlib grid with one
    column per image:

    - row 1: the grayscale inputs;
    - row 2: the model output converted to RGB;
    - row 3: ``data[2]`` (the original images).

    Relies on module-level names defined elsewhere in the file: ``USE_CUDA``,
    ``ColorNet``, ``ValImageFolder``, ``transforms``, ``Variable``,
    ``lab2rgb``, ``plt``.
    """
    BATCH_SIZE = 5
    DATA_DIR = "../CSC2515_data/cifar/test/"
    # transforms.Scale is the legacy torchvision name for Resize.
    scale_transform = transforms.Compose([
        transforms.Scale(32),
        transforms.RandomCrop(32),
        transforms.ToTensor()
    ])

    test_set = ValImageFolder(root=DATA_DIR, transform=scale_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True,
                                              num_workers=1)

    color_model = ColorNet()
    # Map the checkpoint onto the CPU when CUDA is not used, so weights saved
    # from GPU tensors still load.
    color_model.load_state_dict(
        torch.load('colornet_params.pkl',
                   map_location=None if USE_CUDA else 'cpu'))
    if USE_CUDA:
        color_model.cuda()
    color_model.eval()

    # next(iter(...)) works on every Python/PyTorch version; the iterator's
    # .next() method was a Python 2 alias that recent PyTorch removed.
    data, _ = next(iter(test_loader))
    gray_img = data[0].unsqueeze(1).float()

    # One column per image actually in the batch, so the rows stay aligned
    # even if the dataset holds fewer than BATCH_SIZE images.
    n_rows, n_cols = 3, gray_img.size()[0]
    fig = plt.figure()
    panel = 1

    # Row 1: grayscale inputs.
    for img in gray_img:
        pic = img.squeeze().numpy().astype(np.float64)
        fig.add_subplot(n_rows, n_cols, panel)
        panel += 1
        plt.imshow(pic, cmap='gray')

    # Tensors are (N, C, H, W): dim 2 is height, dim 3 is width.
    height, width = gray_img.size()[2], gray_img.size()[3]
    scale_img = data[1].unsqueeze(1).float()
    if USE_CUDA:
        gray_img, scale_img = gray_img.cuda(), scale_img.cuda()

    # volatile=True requests inference mode on the legacy Variable API;
    # PyTorch >= 0.4 ignores it in favour of torch.no_grad().
    gray_img, scale_img = Variable(gray_img,
                                   volatile=True), Variable(scale_img)
    _, output = color_model(gray_img, scale_img)
    color_img = torch.cat((gray_img, output[:, :, 0:height, 0:width]), 1)
    color_img = color_img.data.cpu().numpy().transpose((0, 2, 3, 1))

    # Row 2: colorized outputs.
    for img in color_img:
        # Rescale to Lab ranges: L * 100, a/b * 255 - 128 (assumes the
        # network works with values in [0, 1] -- confirm).
        img[:, :, 0:1] = img[:, :, 0:1] * 100
        img[:, :, 1:3] = img[:, :, 1:3] * 255 - 128
        rgb = lab2rgb(img.astype(np.float64))
        fig.add_subplot(n_rows, n_cols, panel)
        panel += 1
        plt.imshow(rgb)

    # Row 3: original images.
    # NOTE(review): imshow clips float RGB data to [0, 1]; if data[2] holds
    # 0-255 values these panels will render washed out -- confirm its range.
    original_img = data[2].float().squeeze().numpy()
    for img in original_img:
        fig.add_subplot(n_rows, n_cols, panel)
        panel += 1
        plt.imshow(img.astype(np.float64))

    plt.show()
Ejemplo n.º 3
0
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from skimage.color import lab2rgb
from skimage import io
from colornet import ColorNet
from myimgfolder import ValImageFolder
import numpy as np
import matplotlib.pyplot as plt
from skimage import io

import torch  # used below but not brought in by the import lines above

# Validation images to colorize; the commented path is an alternative input set.
data_dir = "places365_standard/val"
# data_dir = "custom_test"
gamut = np.load('models/custom_layers/pts_in_hull.npy')
have_cuda = torch.cuda.is_available()

val_set = ValImageFolder(data_dir)
val_set_size = len(val_set)
val_loader = torch.utils.data.DataLoader(val_set,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=1)

color_model = torch.nn.DataParallel(ColorNet())
# One load for both cases: without CUDA, map the checkpoint tensors onto the
# CPU so weights saved from a GPU still load; with CUDA keep torch.load's
# default device mapping (map_location=None).
color_model.load_state_dict(
    torch.load('./pretrained/colornet_params.pkl',
               map_location=None if have_cuda else 'cpu'))
if have_cuda:
    color_model.cuda()
Ejemplo n.º 4
0
from myimgfolder import ValImageFolder
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

import torch  # used below but not brought in by the import lines above

# transforms.Scale is the legacy torchvision name for Resize.
# No ToTensor here: the pipeline yields a PIL image, presumably converted by
# ValImageFolder itself -- confirm before adding one.
original_transform = transforms.Compose([
    transforms.Scale(1024),
    transforms.CenterCrop(900),
])
torch.cuda.empty_cache()
data_dir = "images256"
have_cuda = torch.cuda.is_available()

val_set = ValImageFolder(data_dir, original_transform)
val_set_size = len(val_set)
val_loader = torch.utils.data.DataLoader(val_set,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)

color_model = ColorNet()
# Without CUDA, map checkpoint tensors onto the CPU so weights saved from a
# GPU still load; with CUDA keep torch.load's default (map_location=None).
color_model.load_state_dict(
    torch.load('colornet_params.pkl',
               map_location=None if have_cuda else 'cpu'))
if have_cuda:
    color_model.cuda()


def val():
    color_model.eval()
    torch.cuda.empty_cache()