Ejemplo n.º 1
0
def check_dataset(dataset, dataroot, augment, download):
    """Resolve a dataset name to its loaded dataset object.

    Args:
        dataset: Dataset name: "cifar10", "svhn", or "awa2".
        dataroot: Directory where the dataset lives / is downloaded to.
        augment: Whether to apply training-time data augmentation.
        download: Whether to download the dataset if missing
            (not used by the "awa2" loader, whose signature has no flag).

    Returns:
        Whatever the matching ``get_*`` loader returns.

    Raises:
        ValueError: If ``dataset`` is not a recognized name.
    """
    if dataset == "cifar10":
        return get_CIFAR10(augment, dataroot, download)
    if dataset == "svhn":
        return get_SVHN(augment, dataroot, download)
    if dataset == "awa2":
        return get_AwA2(augment, dataroot)
    # Previously an unrecognized name was returned unchanged (a plain
    # string), which surfaced as a confusing downstream error. Fail fast.
    raise ValueError(f"Unknown dataset: {dataset!r}")
Ejemplo n.º 2
0
def check_dataset(dataset, dataroot, augment, download):
    """Load a dataset by name.

    Args:
        dataset: Dataset name: 'cifar10' or 'svhn'.
        dataroot: Directory where the dataset lives / is downloaded to.
        augment: Whether to apply training-time data augmentation.
        download: Whether to download the dataset if missing.

    Returns:
        The loader's ``(input_size, num_classes, train_dataset,
        test_dataset)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not a recognized name. (The
            original fell through and crashed with an UnboundLocalError
            on ``input_size`` instead.)
    """
    # The loaders already return the 4-tuple this function returns, so
    # there is no need to unpack and re-pack it.
    if dataset == 'cifar10':
        return get_CIFAR10(augment, dataroot, download)
    if dataset == 'svhn':
        return get_SVHN(augment, dataroot, download)
    raise ValueError(f"Unknown dataset: {dataset!r}")
Ejemplo n.º 3
0
def check_dataset(dataset, dataroot, augment, download):
    """Load a dataset by name.

    Args:
        dataset: Dataset name: "cifar64", "cifar10", or "svhn".
        dataroot: Directory where the dataset lives / is downloaded to.
        augment: Whether to apply training-time data augmentation.
        download: Whether to download the dataset if missing.

    Returns:
        The loader's ``(input_size, num_classes, train_dataset,
        test_dataset)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not a recognized name. (The
            original fell through and crashed with an UnboundLocalError
            on ``input_size`` instead.)
    """
    # Each loader already returns the 4-tuple; return it directly
    # instead of unpacking and re-packing.
    if dataset == "cifar64":
        return get_CIFAR64(augment, dataroot, download)
    if dataset == "cifar10":
        return get_CIFAR10(augment, dataroot, download)
    if dataset == "svhn":
        return get_SVHN(augment, dataroot, download)
    raise ValueError(f"Unknown dataset: {dataset!r}")
Ejemplo n.º 4
0
def check_dataset(dataset, dataroot, augment, download):
    """Load a dataset by name.

    Args:
        dataset: Dataset name: "cifar10", "svhn", or "mnist".
        dataroot: Directory where the dataset lives / is downloaded to.
        augment: Whether to apply training-time data augmentation.
        download: Whether to download the dataset if missing.

    Returns:
        The loader's ``(input_size, num_classes, train_dataset,
        test_dataset)`` tuple.

    Raises:
        ValueError: If ``dataset`` is not a recognized name. (The
            original fell through and crashed with an UnboundLocalError
            on ``input_size`` instead.)
    """
    # Each loader already returns the 4-tuple; return it directly
    # instead of unpacking and re-packing.
    if dataset == "cifar10":
        return get_CIFAR10(augment, dataroot, download)
    if dataset == "svhn":
        return get_SVHN(augment, dataroot, download)
    if dataset == "mnist":
        return get_MNIST(augment, dataroot, download)
    raise ValueError(f"Unknown dataset: {dataset!r}")
Ejemplo n.º 5
0
# Script: restore a pretrained Glow model and the CIFAR-10 / SVHN test
# splits for evaluation.
#
# Fix: the original snippet used `torch` and `json` without importing
# them, so it failed with a NameError at the first use.
import json

import torch

from datasets import get_CIFAR10, get_SVHN, postprocess
from model import Glow
import ipdb  # kept for interactive debugging sessions

device = torch.device("cuda")

output_folder = 'glow/'
model_name = 'glow_affine_coupling.pt'

# Hyper-parameters were saved next to the checkpoint at training time.
with open(output_folder + 'hparams.json') as json_file:
    hparams = json.load(json_file)
# Redirect the loaders at the local copy of the data.
hparams['dataroot'] = '../mutual-information'

# Each loader returns (image_shape, num_classes, train_set, test_set);
# only the test splits are used here, so the train sets are discarded.
image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'],
                                                      hparams['dataroot'],
                                                      hparams['download'])
image_shape, num_classes, _, test_svhn = get_SVHN(hparams['augment'],
                                                  hparams['dataroot'],
                                                  hparams['download'])

model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
             hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'],
             hparams['LU_decomposed'], num_classes, hparams['learn_top'],
             hparams['y_condition'])

model.load_state_dict(torch.load(output_folder + model_name))
# ActNorm layers are normally data-initialized on the first forward
# pass; mark them initialized so the restored parameters are used as-is.
model.set_actnorm_init()

model = model.to(device)
# Script: restore a pretrained Glow checkpoint's hyper-parameters and
# build CIFAR-10 data loaders (SVHN loaders follow below).
#
# Fix: the original snippet used `torch` and `json` without importing
# them, so it failed with a NameError at the first use.
import json

import torch

# Local imports
from model import Glow
from datasets import get_CIFAR10, get_SVHN

device = torch.device("cuda")

output_folder = 'pretrained/'
model_name = 'glow_affine_coupling.pt'

# Hyper-parameters were saved next to the checkpoint at training time.
with open(output_folder + 'hparams.json') as json_file:
    hparams = json.load(json_file)

print(hparams)

# Each loader returns (image_shape, num_classes, train_set, test_set).
image_shape, num_classes, train_cifar, test_cifar = get_CIFAR10(
    augment=False, dataroot=hparams['dataroot'], download=True)
image_shape, num_classes_svhn, train_svhn, test_svhn = get_SVHN(
    augment=False, dataroot=hparams['dataroot'], download=True)

# The data is in the range [-0.5, 0.5]
train_dataloader_cifar = torch.utils.data.DataLoader(train_cifar,
                                                     batch_size=32,
                                                     num_workers=0,
                                                     pin_memory=True)
test_dataloader_cifar = torch.utils.data.DataLoader(test_cifar,
                                                    batch_size=32,
                                                    num_workers=0,
                                                    pin_memory=True)

# The data is in the range [-0.5, 0.5]
train_dataloader_svhn = torch.utils.data.DataLoader(train_svhn,