Example #1
0
def get_model(model, hidden_size, k, num_channels, resolution, num_classes):
    """Build an ``ImgnetClassifier`` on top of an (optionally pretrained) autoencoder.

    Args:
        model: Autoencoder type: ``"vqvae"``, ``"vae"`` or ``"aae"``. Any other
            value builds the classifier without an autoencoder (``None``).
        hidden_size: Latent size of the autoencoder. It also selects the
            checkpoint path; when ``hidden_size > 0`` the autoencoder weights
            are restored from that checkpoint.
        k: Codebook size, used only by the VQ-VAE.
        num_channels: Number of input image channels.
        resolution: Input resolution forwarded to ``ImgnetClassifier``.
        num_classes: Currently unused; kept for interface compatibility.

    Returns:
        The ``ImgnetClassifier`` wrapping the autoencoder.
    """
    ckpt_dir = None
    if model == "vqvae":
        # The 256-dim VQ-VAE sits at the root of the ImageNet checkpoint
        # folder; every other size has its own ``hs_<size>`` sub-folder.
        if hidden_size == 256:
            ckpt_dir = "models/imagenet/best.pt"
        else:
            ckpt_dir = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        autoencoder = VectorQuantizedVAE(num_channels, hidden_size, k)
    elif model == "vae":
        ckpt_dir = f"models/imagenet_hs_128_{hidden_size}_vae.pt"
        autoencoder = VAE(num_channels, hidden_size, hidden_size)
    elif model == "aae":
        ckpt_dir = f"models/aae/imagenet_hs_32_{hidden_size}/best.pt"
        autoencoder = AAE(32, num_channels, hidden_size)
    else:
        autoencoder = None

    imgnetclassifier = ImgnetClassifier(autoencoder, hidden_size, resolution)

    # An unknown model type has neither a checkpoint path nor a module to
    # load weights into, so only restore weights for a real autoencoder.
    if autoencoder is not None and hidden_size > 0:
        ckpt = torch.load(ckpt_dir)
        autoencoder.load_state_dict(ckpt)

    return imgnetclassifier
Example #2
0
def get_model(model,
              hidden_size,
              num_channels,
              resolution,
              enc_type,
              dec_type,
              num_classes,
              k=512):
    """Build an image classifier, optionally on top of a pretrained autoencoder.

    Args:
        model: ``"vqvae"``, ``"vae"`` or ``"acai"`` wrap the corresponding
            autoencoder in an ``ImgClassifier``; ``"supervised"`` builds a
            ``SupervisedImgClassifier`` instead. Any other value builds an
            ``ImgClassifier`` with no autoencoder.
        hidden_size: Latent size; when ``> 0`` the autoencoder weights are
            restored from the checkpoint.
        num_channels: Number of input image channels.
        resolution: Input image resolution.
        enc_type: Encoder variant forwarded to the model constructors.
        dec_type: Decoder variant forwarded to the autoencoder constructors.
        num_classes: Number of output classes of the classifier.
        k: Codebook size for the VQ-VAE.

    Returns:
        The constructed classifier.

    Note:
        The checkpoint path reads ``recons_loss``, ``train_dataset`` and
        ``img_res`` from a name ``args`` that must be defined at module level
        when this function is called.
    """
    ckpt_dir = (f"models/{model}_{args.recons_loss}/{args.train_dataset}/"
                f"depth_{enc_type}_{dec_type}_hs_{args.img_res}_{hidden_size}/best.pt")

    # Keep the autoencoder in its own variable so ``model`` still holds the
    # type *name* below; rebinding it made the "supervised" branch unreachable.
    if model == "vqvae":
        autoencoder = VectorQuantizedVAE(num_channels, hidden_size, k, enc_type,
                                         dec_type)
    elif model == "vae":
        autoencoder = VAE(num_channels, hidden_size, enc_type, dec_type)
    elif model == "acai":
        autoencoder = ACAI(resolution, num_channels, hidden_size, enc_type,
                           dec_type)
    else:
        autoencoder = None

    if model != "supervised":
        imgclassifier = ImgClassifier(autoencoder, hidden_size, resolution,
                                      num_classes)
    else:
        imgclassifier = SupervisedImgClassifier(hidden_size, enc_type,
                                                resolution, num_classes)

    # Nothing to restore when no autoencoder was built.
    if autoencoder is not None and hidden_size > 0:
        ckpt = torch.load(ckpt_dir)
        autoencoder.load_state_dict(ckpt["model"])

    return imgclassifier
Example #3
0
def get_recons(loader, hidden_size=256):
    """Reconstruct the first batch of ``loader`` with a pretrained VQ-VAE.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches; only the images
            of the first batch are used.
        hidden_size: VQ-VAE latent size. Sizes below 256 load
            ``./models/imagenet/hs_<size>/best.pt``; otherwise
            ``./models/imagenet/best.pt`` is loaded.

    Returns:
        A grid (8 images per row) of the reconstructions, on the CPU.
    """
    from types import SimpleNamespace

    model = VectorQuantizedVAE(3, hidden_size, 512).to(DEVICE)

    if hidden_size < 256:
        ckpt_path = "./models/imagenet/hs_{}/best.pt".format(hidden_size)
    else:
        ckpt_path = "./models/imagenet/best.pt"
    # map_location lets a checkpoint saved on another device load onto DEVICE.
    model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    # Inference only: eval mode (e.g. any normalization layers use their
    # running statistics rather than per-batch ones).
    model.eval()

    # Minimal stand-in for the CLI ``args`` namespace; only ``device`` is set.
    args = SimpleNamespace(device=DEVICE)

    images, _ = next(iter(loader))
    reconstruction = generate_samples(images, model, args)
    return make_grid(reconstruction.cpu(), nrow=8)
Example #4
0
def main(args):
    """Train an ``AffineClassifier`` on top of a frozen, pretrained VQ-VAE encoder.

    The ``TripletWhaleDataset`` is randomly split 60/20/20 into train, test
    and validation subsets. After every epoch the model is evaluated on the
    validation subset; the weights with the lowest validation loss are written
    to ``./models/<output_folder>/best.pt`` and every epoch's weights to
    ``model_<epoch>.pt`` in the same folder. TensorBoard logs go to
    ``./logs/<output_folder>``.
    """
    from torch.utils.data import random_split

    save_filename = "./models/{0}".format(args.output_folder)

    # Define the train, valid & test datasets
    all_dataset = TripletWhaleDataset(args.data_folder, min_instance_count=10)
    total = len(all_dataset)
    train_s = int(total * 0.6)
    test_s = int(total * 0.2)
    valid_s = total - train_s - test_s  # remainder, so the sizes sum to total
    train_dataset, test_dataset, valid_dataset = random_split(
        all_dataset, (train_s, test_s, valid_s)
    )
    num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    try:
        # Fixed images for Tensorboard; only the first field of a batch is used.
        fixed_images, _, _ = next(iter(test_loader))[:3]
        fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
        writer.add_image("original", fixed_grid, 0)
        n_classes = len(all_dataset.index_to_class)

        vae = VectorQuantizedVAE(num_channels, args.hidden_size, args.k)
        vae.load_state_dict(torch.load(args.model_path))
        encoder = vae.encoder.eval().to(args.device)
        model = AffineClassifier(all_dataset, encoder, n_classes).to(args.device).train()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        best_loss = -1.0
        for epoch in range(args.num_epochs):
            # Anomaly detection reports the op producing NaN gradients, at the
            # cost of slower training.
            with torch.autograd.set_detect_anomaly(True):
                train(train_loader, model, optimizer, args, writer)
            loss = test(valid_loader, model, args, writer)

            if epoch == 0 or loss < best_loss:
                best_loss = loss
                with open("{0}/best.pt".format(save_filename), "wb") as f:
                    torch.save(model.state_dict(), f)
            with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
                torch.save(model.state_dict(), f)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
def main(args):
    """Train the generator of an ``AffineCropper`` on a frozen VQ-VAE encoder.

    Only ``args.dataset == 'whales'`` is supported: the ``WhaleDataset`` is
    randomly split 60/20/20 into train, test and validation subsets. Crops of
    a fixed test batch are logged to TensorBoard before training and after
    every epoch. The generator weights with the lowest validation loss are
    written to ``./models/<output_folder>/best.pt`` and every epoch's weights
    to ``model_<epoch>.pt`` in the same folder.

    Raises:
        ValueError: if ``args.dataset`` is not ``'whales'``.
    """
    from torch.utils.data import random_split

    # No other dataset is wired up; fail fast instead of hitting a NameError on
    # the undefined datasets further down.
    if args.dataset != 'whales':
        raise ValueError('Unsupported dataset: {0!r}'.format(args.dataset))

    save_filename = './models/{0}'.format(args.output_folder)

    # Define the train, valid & test datasets
    all_dataset = WhaleDataset(args.data_folder, image_transformation=BASIC_IMAGE_T)
    total = len(all_dataset)
    train_s = int(total * .6)
    test_s = int(total * .2)
    valid_s = total - train_s - test_s  # remainder, so the sizes sum to total
    train_dataset, test_dataset, valid_dataset = random_split(
        all_dataset, (train_s, test_s, valid_s))
    num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
        batch_size=args.batch_size, shuffle=False, drop_last=True,
        num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
        batch_size=16, shuffle=True)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    try:
        # Fixed images for Tensorboard
        fixed_images, _, fixed_ids = next(iter(test_loader))
        fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('original', fixed_grid, 0)

        vae = VectorQuantizedVAE(num_channels, args.hidden_size, args.k)
        vae.load_state_dict(torch.load(args.model_path))
        encoder = vae.encoder.eval().to(args.device)
        affine_resample = AffineCropper(all_dataset, args.hidden_size, encoder).to(args.device)
        generator = affine_resample.generator.train()
        optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr)

        def log_crops(step):
            # Log the crops of the fixed test batch at the given global step.
            crops = generate_samples(fixed_images, fixed_ids, affine_resample, args)
            grid = make_grid(crops.cpu(), nrow=8, range=(-1, 1), normalize=True)
            writer.add_image('cropped', grid, step)

        # Generate the samples first once
        log_crops(0)

        best_loss = -1.
        for epoch in range(args.num_epochs):
            with torch.autograd.set_detect_anomaly(True):
                train(train_loader, encoder, affine_resample, optimizer, args, writer)
            loss = test(valid_loader, encoder, affine_resample, args, writer)
            log_crops(epoch + 1)

            if epoch == 0 or loss < best_loss:
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(generator.state_dict(), f)
            with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
                torch.save(generator.state_dict(), f)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
Example #6
0
def get_model(ae_type, hidden_size, k, num_channels,
              ckpt_dir="models/imagenet/hs_32_4/best.pt"):
    """Build an ``ImgnetClassifier`` around a VQ-VAE restored from a checkpoint.

    Args:
        ae_type: Currently unused; a VQ-VAE is always built.
        hidden_size: VQ-VAE latent size, also passed to ``ImgnetClassifier``.
        k: VQ-VAE codebook size.
        num_channels: Number of input image channels.
        ckpt_dir: Path of the VQ-VAE state dict. It does not depend on
            ``hidden_size``, so pass a matching checkpoint when using a size
            other than the one the default file was trained with.

    Returns:
        The ``ImgnetClassifier`` wrapping the restored VQ-VAE.
    """
    model = VectorQuantizedVAE(num_channels, hidden_size, k)
    imgnetclassifier = ImgnetClassifier(model, hidden_size)

    model.load_state_dict(torch.load(ckpt_dir))

    return imgnetclassifier
Example #7
0
def get_model(ae_type, hidden_size, k, num_channels):
    """Build an ``ImgnetClassifier`` around a pretrained autoencoder.

    Args:
        ae_type: ``"vqvae"`` or ``"vae"``.
        hidden_size: Latent size. For the VQ-VAE it also selects the
            checkpoint (``hidden_size == 256`` uses ``models/imagenet/best.pt``,
            other sizes ``models/imagenet/hs_<size>/best.pt``).
        k: VQ-VAE codebook size.
        num_channels: Number of input image channels.

    Returns:
        The ``ImgnetClassifier`` wrapping the restored autoencoder.

    Raises:
        ValueError: if ``ae_type`` is neither ``"vqvae"`` nor ``"vae"``.
    """
    if ae_type == "vqvae":
        if hidden_size == 256:
            ckpt_dir = "models/imagenet/best.pt"
        else:
            ckpt_dir = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
        imgnetclassifier = ImgnetClassifier(model, hidden_size)
    elif ae_type == "vae":
        ckpt_dir = "models/imagenet_vae.pt"
        model = VAE(num_channels, hidden_size, 4096)
        # NOTE(review): the classifier gets a fixed size of 4 here rather than
        # ``hidden_size`` — confirm this matches the VAE's latent shape.
        imgnetclassifier = ImgnetClassifier(model, 4)
    else:
        # Without this, an unknown type failed with an UnboundLocalError.
        raise ValueError("Unknown ae_type: {!r}".format(ae_type))

    model.load_state_dict(torch.load(ckpt_dir))

    return imgnetclassifier
Example #8
0
def main(args):
    """Reconstruct a batch with a pretrained VQ-VAE and save it as a PNG grid.

    With ``args.input == "MNIST"`` the batch is 100 shuffled MNIST test images
    normalized to [-1, 1]; otherwise it is 100 standard-normal noise images of
    shape (1, 28, 28). The VQ-VAE is restored from ``args.model`` and the
    result is saved to ``./generatedImages/<args.filename>.png``. Wall-clock
    start and end times are printed.
    """
    print("Start Time =", datetime.now().strftime("%H:%M:%S"))

    if args.input == "MNIST":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5])])

        # Define the train & test dataSets
        test_set = datasets.MNIST("MNIST",
                                  train=False,
                                  download=True,
                                  transform=transform)

        # Define the data loaders
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=100,
                                                  shuffle=True,
                                                  pin_memory=True)

        real_batch, _ = next(iter(test_loader))
    else:
        real_batch = torch.normal(mean=0, std=1, size=(100, 1, 28, 28))

    model = VectorQuantizedVAE(1, args.hidden_size_vae, args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(f, map_location=args.device))
    # Sampling only: switch to eval mode after restoring the weights.
    model.eval()

    generated = generate_samples(real_batch, model, args)
    # NOTE(review): save_image is called without normalize, so values outside
    # [0, 1] are clipped; if the decoder outputs [-1, 1] like its inputs,
    # consider normalizing — confirm the decoder's output range.
    save_image(make_grid(generated, nrow=10),
               './generatedImages/{0}.png'.format(args.filename))

    print("End Time =", datetime.now().strftime("%H:%M:%S"))
Example #9
0
def main(args):
    """Train a VQ-VAE, optionally together with a pretrained VGG19 target model.

    Supported ``args.dataset`` values are ``'mnist'``, ``'fashion-mnist'``,
    ``'cifar10'`` (whose test split doubles as the validation split) and
    ``'miniimagenet'``. Images are normalized to [-1, 1]. Every epoch the
    model is trained, evaluated on the validation split, and reconstructions
    of a fixed test batch are logged to TensorBoard. The weights with the
    lowest validation loss go to ``./models/<output_folder>/best.pt`` and each
    epoch's weights to ``model_<epoch>.pt`` in the same folder.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    save_filename = './models/{0}'.format(args.output_folder)

    # torchvision dataset class and channel count for the simple datasets.
    dataset_specs = {
        'mnist': (datasets.MNIST, 1),
        'fashion-mnist': (datasets.FashionMNIST, 1),
        'cifar10': (datasets.CIFAR10, 3),
    }
    if args.dataset in dataset_specs:
        dataset_cls, num_channels = dataset_specs[args.dataset]
        # One mean/std entry per channel: 3-entry statistics fail to broadcast
        # over 1-channel (MNIST / Fashion-MNIST) tensors in recent torchvision.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0!r}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    try:
        # Fixed images for Tensorboard
        fixed_images, _ = next(iter(test_loader))
        fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('original', fixed_grid, 0)

        model = VectorQuantizedVAE(num_channels, args.hidden_size,
                                   args.k).to(args.device)
        if args.ckp != "":
            model.load_state_dict(torch.load(args.ckp))
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        # Without a target checkpoint there is no target model; previously the
        # name was simply undefined and train() raised a NameError.
        # NOTE(review): assumes train() accepts None here — confirm.
        target_model = None
        if args.tmodel != '':
            net = vgg.VGG('VGG19')
            net = net.to(args.device)
            net = torch.nn.DataParallel(net)
            checkpoint = torch.load(args.tmodel)
            net.load_state_dict(checkpoint['net'])
            target_model = net

        def log_reconstructions(step):
            # Log reconstructions of the fixed test batch at the given step.
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, step)

        # Generate the samples first once
        log_reconstructions(0)

        best_loss = -1.
        for epoch in range(args.num_epochs):
            print(epoch)
            train(train_loader, model, target_model, optimizer, args, writer)
            loss, _ = test(valid_loader, model, args, writer)
            print("test loss:", loss)
            log_reconstructions(epoch + 1)

            if epoch == 0 or loss < best_loss:
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(model.state_dict(), f)
            with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                      'wb') as f:
                torch.save(model.state_dict(), f)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
Example #10
0
def main(args):
    """Train a class-conditional ``GatedPixelCNN`` prior over a frozen VQ-VAE.

    Supported ``args.dataset`` values are ``'mnist'``, ``'fashion-mnist'``,
    ``'cifar10'`` (whose test split doubles as the validation split) and
    ``'miniimagenet'``. The VQ-VAE is restored from ``args.model`` and put in
    eval mode; only the prior is optimized. The prior with the lowest
    validation loss is saved to ``./models/<output_folder>/prior.pt``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)

    # torchvision dataset class and channel count for the simple datasets.
    dataset_specs = {
        'mnist': (datasets.MNIST, 1),
        'fashion-mnist': (datasets.FashionMNIST, 1),
        'cifar10': (datasets.CIFAR10, 3),
    }
    if args.dataset in dataset_specs:
        dataset_cls, num_channels = dataset_specs[args.dataset]
        # One mean/std entry per channel: 3-entry statistics fail to broadcast
        # over 1-channel (MNIST / Fashion-MNIST) tensors in recent torchvision.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0!r}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    try:
        # Fixed images for Tensorboard
        fixed_images, _ = next(iter(test_loader))
        fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('original', fixed_grid, 0)

        model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                                   args.k).to(args.device)
        with open(args.model, 'rb') as f:
            state_dict = torch.load(f)
            model.load_state_dict(state_dict)
        model.eval()

        # NOTE(review): n_classes is hard-coded to 32 rather than derived from
        # the dataset's labels; presumably it must be at least the number of
        # classes in the dataset — confirm.
        prior = GatedPixelCNN(args.k,
                              args.hidden_size_prior,
                              args.num_layers,
                              n_classes=32).to(args.device)
        optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)

        best_loss = -1.
        for epoch in range(args.num_epochs):
            print(epoch)
            train(train_loader, model, prior, optimizer, args, writer)
            # The validation loss is not properly computed since
            # the classes in the train and valid splits of Mini-Imagenet
            # do not overlap.
            loss = test(valid_loader, model, prior, args, writer)

            if epoch == 0 or loss < best_loss:
                best_loss = loss
                with open(save_filename, 'wb') as f:
                    torch.save(prior.state_dict(), f)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
Example #11
0
def main(args):
    """Sample from a pretrained class-conditional prior and log the decodings.

    The requested dataset is built only to log a fixed batch of test images.
    The VQ-VAE is restored from ``args.model`` and a 10-class
    ``GatedPixelCNN`` prior from ``models/mnist-10-prior/prior.pt``. Then, for
    each of ``args.num_epochs`` iterations, 16 prior samples (labels
    0..9 followed by 0..5) are decoded. They are logged to TensorBoard under
    ``'Full sampling result'``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    dataset_specs = {
        'mnist': (datasets.MNIST, 1),
        'fashion-mnist': (datasets.FashionMNIST, 1),
        'cifar10': (datasets.CIFAR10, 3),
    }
    if args.dataset in dataset_specs:
        dataset_cls, num_channels = dataset_specs[args.dataset]
        # Single-entry statistics broadcast over every image channel.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        # Only the test split is used below; the train split is still built
        # with download=True.
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Every split is constructed with download=True, although only the
        # test split is used below.
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0!r}'.format(args.dataset))

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    try:
        # Fixed images for Tensorboard
        fixed_images, _ = next(iter(test_loader))
        fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('original', fixed_grid, 0)

        model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                                   args.k).to(args.device)
        with open(args.model, 'rb') as f:
            # map_location keeps loading device-agnostic (honours args.device).
            model.load_state_dict(torch.load(f, map_location=args.device))
        model.eval()

        prior = GatedPixelCNN(args.k,
                              args.hidden_size_prior,
                              args.num_layers,
                              n_classes=10).to(args.device)
        prior.load_state_dict(
            torch.load('models/mnist-10-prior/prior.pt', map_location=args.device))

        # Class labels for the 16 samples; identical every iteration, so built
        # once and placed on args.device (not a hard-coded CUDA device).
        labels = torch.tensor([i % 10 for i in range(16)]).to(args.device)
        for epoch in range(args.num_epochs):
            # Pure sampling: no autograd graph is needed.
            with torch.no_grad():
                reconstruction = model.decode(
                    prior.generate(label=labels, batch_size=16, sample_index=20))
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('Full sampling result', grid, epoch + 1)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
Example #12
0
def _load_generation_datasets(args):
    """Build the train/test datasets used by ``main`` for ``args.dataset``.

    Only 'cifar10' and 'miniimagenet' are handled; the caller validates this.
    The train split is needed for its label encoder (miniimagenet) and the
    test split supplies a reference image for TensorBoard.

    Returns:
        Tuple ``(train_dataset, test_dataset, num_channels)``.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    if args.dataset == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        train_dataset = datasets.CIFAR10(args.data_folder,
                                         train=True,
                                         download=True,
                                         transform=transform)
        test_dataset = datasets.CIFAR10(args.data_folder,
                                        train=False,
                                        transform=transform)
        return train_dataset, test_dataset, 3

    # miniimagenet
    transform = transforms.Compose([
        transforms.RandomResizedCrop(128),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = MiniImagenet(args.data_folder,
                                 train=True,
                                 download=True,
                                 transform=transform)
    test_dataset = MiniImagenet(args.data_folder,
                                test=True,
                                download=True,
                                transform=transform)
    return train_dataset, test_dataset, 3


def _save_generated_samples(images, output_folder, subdir, y):
    """Save every image of a decoded batch under ./data/<output_folder>/<subdir>/.

    Files are named ``<y zero-padded to 2>_<index zero-padded to 3>.jpg`` and
    written with ``save_image(..., range=(-1, 1), normalize=True)``.

    Returns:
        The file names written (without directory), in batch order.
    """
    import os

    out_dir = './data/{0}/{1}'.format(output_folder, subdir)
    os.makedirs(out_dir, exist_ok=True)
    # A single device->host copy for the whole batch instead of one per image.
    images = images.cpu()
    filenames = []
    for im in range(images.shape[0]):
        filename = '{0}_{1}.jpg'.format(str(y).zfill(2), str(im).zfill(3))
        save_image(images[im],
                   '{0}/{1}'.format(out_dir, filename),
                   range=(-1, 1),
                   normalize=True)
        filenames.append(filename)
    return filenames


def main(args):
    """Sample class-conditional images from a trained VQ-VAE + GatedPixelCNN prior.

    For every class index ``y`` in ``range(n_classes)``, ``sample_size`` latent
    grids are drawn from the prior with ``generate``, ``generate_dst`` and
    ``generate_sparsemax``. Each set is decoded by the VQ-VAE and saved to
    ``./data/<output_folder>/dataset_{softmax,dst,sparsemax}/``.

    A ``filename,label`` CSV describing the ``generate`` (softmax) samples is
    written to ``./VQVAE/models/<output_folder>/``. The largest ``num_dst`` and
    ``num_sparsemax`` values returned by the prior are pickled to
    ``max_num_<dataset>_<sample_size>_correct_dst.pkl``.

    Args:
        args: parsed command-line namespace. Reads ``dataset``,
            ``data_folder``, ``output_folder``, ``device``, ``model``,
            ``prior``, ``k``, ``hidden_size_vae``, ``hidden_size_prior`` and
            ``num_layers``.

    Raises:
        ValueError: if ``args.dataset`` is not 'miniimagenet' or 'cifar10',
            the only datasets with a generation configuration.
    """
    import os

    if args.dataset not in ('miniimagenet', 'cifar10'):
        # Fail before downloading data / loading checkpoints: there is no
        # latent shape / class count / sample size defined for other datasets.
        raise ValueError(
            "unsupported dataset {0!r}: expected 'miniimagenet' or 'cifar10'"
            .format(args.dataset))

    train_dataset, test_dataset, num_channels = _load_generation_datasets(args)

    # Log one (randomly chosen) test image to TensorBoard for reference.
    tb_writer = SummaryWriter('./VQVAE/logs/{0}'.format(args.output_folder))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=True)
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    tb_writer.add_image('original', fixed_grid, 0)
    tb_writer.close()

    # NOTE: torch.load unpickles; only point --model/--prior at trusted files.
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        model.load_state_dict(torch.load(f))
    model.eval()

    if args.dataset == 'miniimagenet':
        label_encoder = train_dataset._label_encoder
        n_classes = len(label_encoder)
        shape = (32, 32)
        csv_filename = 'miniimagenet_generated.csv'
        sample_size = 25
        # Inverse of the label encoder (class index -> class name), built once.
        # setdefault keeps the first name per index, which is what a
        # list(values).index(y) lookup would return.
        class_names = {}
        for name, index in label_encoder.items():
            class_names.setdefault(index, name)
    else:  # cifar10
        n_classes = 10
        shape = (8, 8)
        csv_filename = 'cifar10_generated.csv'
        sample_size = 1000
        class_names = {y: y for y in range(n_classes)}
    print("number of training classes:", n_classes)

    prior = GatedPixelCNN(args.k,
                          args.hidden_size_prior,
                          args.num_layers,
                          n_classes=n_classes).to(args.device)
    with open(args.prior, 'rb') as f:
        prior.load_state_dict(torch.load(f))
    prior.eval()

    # Largest values of the second item returned by generate_dst /
    # generate_sparsemax across all classes.
    max_num_dst = 0
    max_num_sparsemax = 0

    csv_dir = './VQVAE/models/{0}'.format(args.output_folder)
    os.makedirs(csv_dir, exist_ok=True)
    # newline='' as required by the csv module to avoid blank rows on Windows.
    with torch.no_grad(), open('{0}/{1}'.format(csv_dir, csv_filename),
                               'w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(['filename', 'label'])

        for y in tqdm(range(n_classes)):
            label = torch.tensor([y]).to(args.device)

            # Standard sampling; only these samples are listed in the CSV.
            z = prior.generate(label=label,
                               shape=shape,
                               batch_size=sample_size)
            filenames = _save_generated_samples(model.decode(z),
                                                args.output_folder,
                                                'dataset_softmax', y)
            label_str = str(class_names[y])
            csv_writer.writerows([name, label_str] for name in filenames)

            z, num_dst = prior.generate_dst(label=label,
                                            shape=shape,
                                            batch_size=sample_size)
            _save_generated_samples(model.decode(z), args.output_folder,
                                    'dataset_dst', y)
            max_num_dst = max(max_num_dst, num_dst)

            z, num_sparsemax = prior.generate_sparsemax(
                label=label, shape=shape, batch_size=sample_size)
            _save_generated_samples(model.decode(z), args.output_folder,
                                    'dataset_sparsemax', y)
            max_num_sparsemax = max(max_num_sparsemax, num_sparsemax)

    pkl_path = ("max_num_" + str(args.dataset) + "_" + str(sample_size) +
                "_correct_dst.pkl")
    # Context manager so the pickle is flushed and the handle closed.
    with open(pkl_path, "wb") as pkl_file:
        pkl.dump([max_num_dst, max_num_sparsemax], pkl_file)
    download=True,
    transform=preproc_transform,
),
                                           batch_size=BATCH_SIZE,
                                           shuffle=False,
                                           num_workers=NUM_WORKERS,
                                           pin_memory=True)
# Evaluation split of the torchvision dataset class named by DATASET.
# getattr(datasets, DATASET) performs the same attribute lookup the former
# eval('datasets.' + DATASET) did, but cannot execute arbitrary code.
test_loader = torch.utils.data.DataLoader(
    getattr(datasets, DATASET)('../data/{}/'.format(DATASET),
                               train=False,
                               transform=preproc_transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True)

# Pretrained VQ-VAE restored from models/<DATASET>_vqvae.pt and put in eval
# mode. NOTE: torch.load unpickles; only load trusted checkpoints.
autoencoder = VectorQuantizedVAE(INPUT_DIM, VAE_DIM, K).to(DEVICE)
autoencoder.load_state_dict(torch.load('models/{}_vqvae.pt'.format(DATASET)))
autoencoder.eval()

# GatedPixelCNN over K discrete codes, trained with cross-entropy using Adam
# (AMSGrad variant).
model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    train_loss = []
    for batch_idx, (x, label) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)
        label = label.to(DEVICE)

        # Get the latent codes for image x
Example #14
0
def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'clevr':
        transform = transforms.Compose([
            # transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        import socket
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = "/projects/katefgroup/datasets/clevr_veggies/"
            dataset_name = '/home/mprabhud/dataset/clevr_veggies'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name,mod = args.modname\
            , train=True, transform=transform,object_level= args.object_level)
        valid_dataset = Clevr(dataset_name,mod = args.modname,\
         valid=True,transform=transform,object_level= args.object_level)
        test_dataset = Clevr(dataset_name,mod = args.modname,\
         test=True, transform=transform,object_level= args.object_level)
        num_channels = 3
    elif args.dataset == 'carla':
        if args.use_depth:
            transform = transforms.Compose([
                # transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5, 0.5))
            ])
            num_channels = 4
        else:
            transform = transforms.Compose([
                # transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            num_channels = 3
        import socket
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = '/home/shamitl/datasets/carla'
            dataset_name = "/projects/katefgroup/datasets/carla/"
            dataset_name = '/home/mprabhud/dataset/carla'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name,mod = args.modname\
            , train=True, transform=transform,object_level= args.object_level,use_depth=args.use_depth)
        valid_dataset = Clevr(dataset_name,mod = args.modname,\
         valid=True,transform=transform,object_level= args.object_level,use_depth=args.use_depth)
        test_dataset = Clevr(dataset_name,mod = args.modname,\
         test=True, transform=transform,object_level= args.object_level,use_depth=args.use_depth)

    # elif args.dataset == 'miniimagenet':
    #     transform = transforms.Compose([
    #         transforms.RandomResizedCrop(128),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #     ])
    #     # Define the train, valid & test datasets
    #     train_dataset = MiniImagenet(args.data_folder, train=True,
    #         download=True, transform=transform)
    #     # valid_dataset = MiniImagenet(args.data_folder, valid=True,
    #     #     download=True, transform=transform)
    #     # test_dataset = MiniImagenet(args.data_folder, test=True,
    #     #     download=True, transform=transform)
    #     # num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)
    fixed_images, _ = next(iter(train_loader))
    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.object_level, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.load_model is not "":
        with open(args.load_model, 'rb') as f:
            state_dict = torch.load(f)
            model.load_state_dict(state_dict)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    # st()
    best_loss = -1.
    for epoch in range(args.num_epochs):
        if not args.test_mode:
            train(train_loader, model, optimizer, args, writer, epoch)
            # st()
            loss, _ = test_old(valid_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
            # st()
            with open('{0}/recent.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
            if (epoch == 0) or (loss < best_loss):
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(model.state_dict(), f)
            # else:
            #     print("nothing")
        else:
            test(train_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)