Ejemplo n.º 1
0
def main(args):
    """Sample class-conditional images from a trained VQ-VAE + PixelCNN prior.

    Logs a grid of real test images to TensorBoard ('original'). It then
    loads the VQ-VAE (``args.model``) and a 10-class GatedPixelCNN prior.
    For each of ``args.num_epochs`` steps it samples 16 latent maps with
    labels 0..9 cycled, decodes them and logs the grid as
    'Full sampling result'.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``dataset``, ``output_folder``, ``data_folder``, ``model``,
            ``hidden_size_vae``, ``hidden_size_prior``, ``num_layers``,
            ``k``, ``device``, ``num_epochs``. Optionally ``prior`` (path to
            the prior checkpoint); when absent or empty, the previous
            hard-coded ``'models/mnist-10-prior/prior.pt'`` is used.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar10'`` or ``'miniimagenet'``.
    """
    if args.dataset in ('mnist', 'fashion-mnist'):
        num_channels = 1
    elif args.dataset in ('cifar10', 'miniimagenet'):
        num_channels = 3
    else:
        raise ValueError(
            'Unsupported dataset {0!r}; expected one of mnist, fashion-mnist, '
            'cifar10, miniimagenet'.format(args.dataset))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))

    # Map pixels to [-1, 1]: every grid below is rendered with
    # make_grid(range=(-1, 1)), so values outside that interval get clipped.
    normalize = transforms.Normalize((0.5,) * num_channels,
                                     (0.5,) * num_channels)

    if args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            normalize,
        ])
        # The train split is instantiated (download=True) as in the original
        # script; only the test split is read below.
        MiniImagenet(args.data_folder,
                     train=True,
                     download=True,
                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
    else:
        dataset_cls = {
            'mnist': datasets.MNIST,
            'fashion-mnist': datasets.FashionMNIST,
            'cifar10': datasets.CIFAR10,
        }[args.dataset]
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        # Only the train constructor passes download=True (as before), so it
        # runs first; presumably it fetches the files the test split reads.
        dataset_cls(args.data_folder,
                    train=True,
                    download=True,
                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets GPU-saved checkpoints load on the chosen device.
        model.load_state_dict(torch.load(f, map_location=args.device))
    model.eval()

    prior = GatedPixelCNN(args.k,
                          args.hidden_size_prior,
                          args.num_layers,
                          n_classes=10).to(args.device)
    prior_path = (getattr(args, 'prior', None)
                  or 'models/mnist-10-prior/prior.pt')
    with open(prior_path, 'rb') as f:
        prior.load_state_dict(torch.load(f, map_location=args.device))
    prior.eval()

    # Labels 0..9 cycled over the 16 samples; built once, on the model device.
    labels = torch.tensor([i % 10 for i in range(16)], device=args.device)

    # Pure inference: no autograd graph needed while sampling/decoding.
    with torch.no_grad():
        for epoch in range(args.num_epochs):
            latents = prior.generate(label=labels,
                                     batch_size=16,
                                     sample_index=20)
            reconstruction = model.decode(latents)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('Full sampling result', grid, epoch + 1)

    writer.close()
Ejemplo n.º 2
0
def main(args):
    """Generate a labelled synthetic dataset from a VQ-VAE and its PixelCNN prior.

    For each class index ``y``, the class-conditional prior is sampled with
    three samplers: ``prior.generate``, ``prior.generate_dst`` and
    ``prior.generate_sparsemax``. Each batch of ``sample_size`` latent maps
    is decoded by the VQ-VAE. Every decoded image is saved as
    ``./data/<output_folder>/<subdir>/<yy>_<iii>.jpg``, where ``subdir`` is
    ``dataset_softmax``, ``dataset_dst`` or ``dataset_sparsemax``
    respectively.

    A CSV with header ``filename,label`` listing the ``dataset_softmax``
    images is written to ``./VQVAE/models/<output_folder>/<csv_filename>``.
    The largest ``num_dst`` / ``num_sparsemax`` values returned by the two
    alternative samplers are pickled as a two-element list to
    ``max_num_<dataset>_<sample_size>_correct_dst.pkl`` in the working
    directory.

    Args:
        args: parsed command-line namespace. Fields read here: ``dataset``,
            ``output_folder``, ``data_folder``, ``model``, ``prior``,
            ``hidden_size_vae``, ``hidden_size_prior``, ``num_layers``,
            ``k``, ``device``.

    Raises:
        ValueError: if ``args.dataset`` is not ``'miniimagenet'`` or
            ``'cifar10'``, the only datasets with a sampling configuration
            (latent shape, sample size, CSV name).
    """
    import os

    # Fail fast instead of hitting an UnboundLocalError after loading models.
    if args.dataset not in ('miniimagenet', 'cifar10'):
        raise ValueError(
            'Unsupported dataset {0!r}; expected miniimagenet or '
            'cifar10'.format(args.dataset))

    # NOTE(review): readable_labels is not referenced anywhere below; kept as
    # a reference table of class names.
    if args.dataset == 'miniimagenet':
        readable_labels = {
            38: 'organ',
            42: 'prayer_rug',
            31: 'file',
            61: 'cliff',
            58: 'consomme',
            59: 'hotdog',
            21: 'aircraft_carrier',
            14: 'French_bulldog',
            28: 'cocktail_shaker',
            63: 'ear',
            3: 'green_mamba',
            4: 'harvestman',
            17: 'Arctic_fox',
            32: 'fire_screen',
            11: 'komondor',
            43: 'reel',
            18: 'ladybug',
            45: 'snorkel',
            24: 'beer_bottle',
            36: 'lipstick',
            5: 'toucan',
            0: 'house_finch',
            16: 'miniature_poodle',
            50: 'tile_roof',
            15: 'Newfoundland',
            46: 'solar_dish',
            10: 'Gordon_setter',
            7: 'dugong',
            52: 'unicycle',
            20: 'rock_beauty',
            48: 'stage',
            22: 'ashcan',
            34: 'hair_slide',
            30: 'dome',
            13: 'Tibetan_mastiff',
            53: 'upright',
            62: 'bolete',
            2: 'triceratops',
            40: 'pencil_box',
            26: 'chime',
            47: 'spider_web',
            51: 'tobacco_shop',
            60: 'orange',
            49: 'tank',
            8: 'Walker_hound',
            23: 'barrel',
            6: 'jellyfish',
            33: 'frying_pan',
            9: 'Saluki',
            37: 'oboe',
            1: 'robin',
            19: 'three-toed_sloth',
            39: 'parallel_bars',
            55: 'worm_fence',
            27: 'clog',
            41: 'photocopier',
            25: 'carousel',
            29: 'dishrag',
            57: 'street_sign',
            35: 'holster',
            12: 'boxer',
            56: 'yawl',
            54: 'wok',
            44: 'slot'
        }
    else:
        readable_labels = {
            0: 'airplane',
            1: 'automobile',
            2: 'bird',
            3: 'cat',
            4: 'deer',
            5: 'dog',
            6: 'frog',
            7: 'horse',
            8: 'ship',
            9: 'truck'
        }

    # Named tb_writer so it is not shadowed by the CSV writer further down.
    tb_writer = SummaryWriter('./VQVAE/logs/{0}'.format(args.output_folder))

    if args.dataset == 'cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_dataset = datasets.CIFAR10(args.data_folder,
                                         train=True,
                                         download=True,
                                         transform=transform)
        test_dataset = datasets.CIFAR10(args.data_folder,
                                        train=False,
                                        transform=transform)
    else:  # miniimagenet
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
    num_channels = 3

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    tb_writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets GPU-saved checkpoints load on the chosen device.
        model.load_state_dict(torch.load(f, map_location=args.device))
    model.eval()

    # Per-dataset sampling configuration.
    if args.dataset == 'miniimagenet':
        n_classes = len(train_dataset._label_encoder)
        shape = (32, 32)
        csv_filename = 'miniimagenet_generated.csv'
        sample_size = 25
        # Invert the name -> index encoder once (first name wins, matching
        # the original list(...).index lookup) instead of per image.
        index_to_name = {}
        for name, index in train_dataset._label_encoder.items():
            index_to_name.setdefault(index, name)
    else:  # cifar10
        n_classes = 10
        shape = (8, 8)
        csv_filename = 'cifar10_generated.csv'
        sample_size = 1000
        index_to_name = None  # CSV label is the numeric class index
    print("number of training classes:", n_classes)

    prior = GatedPixelCNN(args.k,
                          args.hidden_size_prior,
                          args.num_layers,
                          n_classes=n_classes).to(args.device)
    with open(args.prior, 'rb') as f:
        prior.load_state_dict(torch.load(f, map_location=args.device))
    prior.eval()

    # save_image does not create directories, so make sure they exist.
    for subdir in ('dataset_softmax', 'dataset_dst', 'dataset_sparsemax'):
        os.makedirs('./data/{0}/{1}'.format(args.output_folder, subdir),
                    exist_ok=True)
    csv_dir = './VQVAE/models/{0}'.format(args.output_folder)
    os.makedirs(csv_dir, exist_ok=True)

    def save_batch(images, subdir, y):
        """Save each image of a decoded batch under ``subdir``; return file names."""
        # One device->host copy per batch (the original copied per image).
        images = images.cpu()
        names = []
        for im in range(images.shape[0]):
            name = '{0}_{1}.jpg'.format(str(y).zfill(2), str(im).zfill(3))
            save_image(images[im],
                       './data/{0}/{1}/{2}'.format(args.output_folder, subdir,
                                                   name),
                       range=(-1, 1),
                       normalize=True)
            names.append(name)
        return names

    # maximum number of kept dimensions
    max_num_dst = 0
    max_num_sparsemax = 0

    # newline='' as required by the csv module to avoid blank rows on Windows.
    with torch.no_grad(), open('{0}/{1}'.format(csv_dir, csv_filename),
                               'w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(['filename', 'label'])

        for y in tqdm(range(n_classes)):
            label = torch.tensor([y], device=args.device)
            if index_to_name is not None:
                label_str = str(index_to_name[y])
            else:
                label_str = str(y)

            z = prior.generate(label=label,
                               shape=shape,
                               batch_size=sample_size)
            names = save_batch(model.decode(z), 'dataset_softmax', y)
            csv_writer.writerows([name, label_str] for name in names)

            z, num_dst = prior.generate_dst(label=label,
                                            shape=shape,
                                            batch_size=sample_size)
            if num_dst > max_num_dst:
                max_num_dst = num_dst
            save_batch(model.decode(z), 'dataset_dst', y)

            z, num_sparsemax = prior.generate_sparsemax(
                label=label, shape=shape, batch_size=sample_size)
            if num_sparsemax > max_num_sparsemax:
                max_num_sparsemax = num_sparsemax
            save_batch(model.decode(z), 'dataset_sparsemax', y)

    tb_writer.close()

    pkl_filename = ("max_num_" + str(args.dataset) + "_" + str(sample_size) +
                    "_correct_dst.pkl")
    with open(pkl_filename, "wb") as pkl_file:
        pkl.dump([max_num_dst, max_num_sparsemax], pkl_file)