Ejemplo n.º 1
0
def get_omniglot_loaders(arguments):
    """Create shuffled train and test ``DataLoader``s over Omniglot.

    The train loader reads the ``background=True`` split and applies
    ``RandomAffine(10)`` before tensor conversion; the test loader reads the
    ``background=False`` split without augmentation. Both splits are
    downloaded into ``DATASET_PATH`` when missing and normalised with the
    same constants.

    Args:
        arguments: object exposing ``preload_all_data`` and ``batch_size``
            attributes.

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        NotImplementedError: if ``arguments.preload_all_data`` is truthy.
    """
    if arguments.preload_all_data:
        raise NotImplementedError

    def build_loader(background, steps):
        # Both loaders share every setting except the split and the pipeline.
        split = datasets.Omniglot(DATASET_PATH,
                                  background=background,
                                  download=True,
                                  transform=transforms.Compose(steps))
        return torch.utils.data.DataLoader(split,
                                           batch_size=arguments.batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=NUM_WORKERS)

    train_loader = build_loader(True, [
        transforms.RandomAffine(10),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    test_loader = build_loader(False, [
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    return train_loader, test_loader
Ejemplo n.º 2
0
def get_inverted_omniglot_loaders(arguments, mean=(0.5,), std=(0.5,)):
    """Create train/test ``DataLoader``s over colour-inverted Omniglot images.

    Each image is converted to a tensor, inverted with ``1 - x`` and then
    normalised with ``1 - mean`` and ``1 - std``. The train loader
    (``background=True``) is shuffled; the test loader (``background=False``)
    is not.

    Args:
        arguments: mapping with ``'preload_all_data'`` and ``'batch_size'``
            keys.
        mean: number or sequence of numbers; ``1 - mean`` is used as the
            normalisation mean.
        std: number or sequence of numbers; ``1 - std`` is used as the
            normalisation divisor.

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        NotImplementedError: if ``arguments['preload_all_data']`` is truthy.
    """
    print("Using mean", mean)
    # (1-0.92206,), (0.08426,)
    if arguments['preload_all_data']:
        raise NotImplementedError

    def one_minus(values):
        # ``1 - (0.5,)`` raises TypeError, so the tuple defaults (and any
        # other plain sequence) are complemented element by element. Values
        # that already support the subtraction are passed through unchanged.
        try:
            return 1 - values
        except TypeError:
            return tuple(1 - value for value in values)

    # NOTE(review): inverting pixel values does not change their standard
    # deviation, so ``1 - std`` looks suspicious. It is kept as-is so callers
    # that pass scalars get the same numbers as before -- confirm the intent.
    inverted_mean = one_minus(mean)
    inverted_std = one_minus(std)

    def build_loader(background, shuffle):
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 1 - x),
            transforms.Normalize(inverted_mean, inverted_std)
        ])
        split = datasets.Omniglot(DATASET_PATH, background=background,
                                  download=True, transform=pipeline)
        return torch.utils.data.DataLoader(
            split,
            batch_size=arguments['batch_size'],
            shuffle=shuffle,
            pin_memory=True,
            num_workers=NUM_WORKERS
        )

    train_loader = build_loader(background=True, shuffle=True)
    test_loader = build_loader(background=False, shuffle=False)

    return train_loader, test_loader
Ejemplo n.º 3
0
def main():
    """Build the Omniglot loaders and either train or test the controller.

    Creates a ``Controller``, wraps the Omniglot background split in a
    training loader and the evaluation split in a test loader, then calls
    ``train`` when ``FLAGS.TRAIN`` is truthy and ``test`` otherwise.
    """
    controller = Controller()

    # Worker/pinned-memory settings are only applied when CUDA is requested.
    loader_kwargs = {"num_workers": 1, "pin_memory": True} if FLAGS.CUDA else {}

    def build_loader(background):
        # Omniglot selects its split with ``background`` (it has no ``train``
        # argument), and each split is fetched separately, so both loaders
        # request a download.
        split = datasets.Omniglot(
            "data",
            background=background,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # transforms.Normalize(mean, std) should add for better performance, + other transforms
            ]))
        return DataLoader(split,
                          batch_size=FLAGS.BATCH_SIZE,
                          shuffle=True,
                          **loader_kwargs)

    train_loader = build_loader(background=True)
    test_loader = build_loader(background=False)

    if FLAGS.TRAIN:
        train(controller, train_loader)
    else:
        test(controller, test_loader)
Ejemplo n.º 4
0
    def _get_dataset(self):
        """Populate the train/test datasets and shape metadata for ``self.dataset_name``.

        Sets ``self.input_shape``, ``self.num_classes``, ``self.train_dataset``
        and ``self.test_dataset`` for ``'omniglot'``, ``'cifar100'`` or
        ``'cifar10'``. Only the CIFAR-100 branch consults ``self.augment`` and
        normalises its inputs.

        Raises:
            Exception: if ``self.dataset_name`` is none of the supported names.
        """
        print("Loading {} dataset from {}.".format(self.dataset_name, self.datadir))
        tensor_only = tfs.Compose([tfs.ToTensor()])
        name = self.dataset_name

        if name not in ('omniglot', 'cifar100', 'cifar10'):
            raise Exception("{} dataset not found!".format(self.dataset_name))

        if name == 'omniglot':
            self.input_shape, self.num_classes = (1, 105, 105), 1623
            shared = dict(target_transform=None,
                          download=True,
                          transform=tensor_only)
            self.train_dataset = datasets.Omniglot(self.datadir,
                                                   background=True,
                                                   **shared)
            self.test_dataset = datasets.Omniglot(self.datadir,
                                                  background=False,
                                                  **shared)
            return

        if name == 'cifar100':
            self.input_shape, self.num_classes = (3, 32, 32), 100
            dataset_cls = datasets.CIFAR100

            extra_steps = []
            if self.augment:
                print("Using augmentation on train dataset.")
                extra_steps = [tfs.RandomCrop(32, padding=4),
                               tfs.RandomHorizontalFlip()]
            base_steps = [tfs.ToTensor(),
                          tfs.Normalize(mean=[0.507, 0.487, 0.441],
                                        std=[0.267, 0.256, 0.276])]
            train_transforms = tfs.Compose(extra_steps + base_steps)
            test_transforms = tfs.Compose(base_steps)
        else:
            # 'cifar10': plain tensor conversion for both splits.
            self.input_shape, self.num_classes = (3, 32, 32), 10
            dataset_cls = datasets.CIFAR10
            train_transforms = test_transforms = tensor_only

        self.train_dataset = dataset_cls(self.datadir,
                                         train=True,
                                         download=True,
                                         transform=train_transforms)
        self.test_dataset = dataset_cls(self.datadir,
                                        train=False,
                                        download=True,
                                        transform=test_transforms)
Ejemplo n.º 5
0
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1):
        """Set up the Omniglot background and evaluation datasets and delegate to the base loader.

        Args:
            data_dir: root directory handed to ``datasets.Omniglot``.
            batch_size: forwarded to the parent initializer.
            shuffle: forwarded to the parent initializer.
            validation_split: forwarded to the parent initializer.
            num_workers: forwarded to the parent initializer.
        """
        self.data_dir = data_dir

        # Instantiate both splits once without transforms and discard them.
        # NOTE(review): presumably this only makes sure the data is downloaded
        # before OmniglotTargetTransform is constructed -- confirm.
        for is_background in (True, False):
            datasets.Omniglot(self.data_dir, download=True, background=is_background)

        train_labeler = OmniglotTargetTransform(self.data_dir, background=True)
        eval_labeler = OmniglotTargetTransform(self.data_dir, background=False)

        self.dataset = datasets.Omniglot(
            self.data_dir, background=True, download=True,
            transform=omni_transforms, target_transform=train_labeler)
        self.eval_dataset = datasets.Omniglot(
            self.data_dir, background=False, download=True,
            transform=omni_transforms, target_transform=eval_labeler)

        def collect_targets(split):
            # Index every sample and keep only its (transformed) target.
            return np.array([split[position][1] for position in range(len(split))])

        self.targets = collect_targets(self.dataset)
        eval_targets = collect_targets(self.eval_dataset)

        super().__init__(self.dataset, batch_size, shuffle, validation_split,
                         num_workers, self.targets,
                         drop_train_last=False, drop_valid_last=False,
                         evaluation=(self.eval_dataset, eval_targets))
Ejemplo n.º 6
0
def omniglot():
    """Chain the download configs collected for both Omniglot splits.

    ``collect_download_configs`` is called once with a loader for the
    background split and once with a loader for the evaluation split; the
    two results are concatenated with ``itertools.chain``.

    Returns:
        An ``itertools.chain`` iterator over the collected configs.
    """
    return itertools.chain(*[
        collect_download_configs(
            # Bind ``background`` as a default argument so each loader keeps
            # its own value even if it is invoked after the comprehension
            # variable has moved on (late-binding closure pitfall).
            lambda background=background: datasets.Omniglot(
                ROOT, background=background, download=True),
            name=f"Omniglot, {'background' if background else 'evaluation'}",
        ) for background in (True, False)
    ])
Ejemplo n.º 7
0
    def download(self):
        """Download both Omniglot splits and move their contents into ``self.root_dir``.

        The archives are fetched into ``data``; every entry found under
        ``images_background`` and ``images_evaluation`` is then moved into
        ``self.root_dir`` and the ``data/omniglot-py`` directory is removed.
        """
        download_root = 'data'
        origin_dir = os.path.join(download_root, 'omniglot-py')
        processed_dir = self.root_dir

        dset.Omniglot(root=download_root, background=False, download=True)
        dset.Omniglot(root=download_root, background=True, download=True)

        # An already-existing target directory is fine; any other failure
        # (e.g. missing permissions) is reported here instead of being
        # silently swallowed and resurfacing later as a confusing move error.
        os.makedirs(processed_dir, exist_ok=True)

        for split in ('images_background', 'images_evaluation'):
            split_dir = os.path.join(origin_dir, split)
            for entry in os.listdir(split_dir):
                shutil.move(os.path.join(split_dir, entry), processed_dir)

        shutil.rmtree(origin_dir)
0
 def __init__(self, height=32, length=32):
     """Load Omniglot, build the tasks and set up the image transforms.

     Args:
         height: first element of the ``transforms.Resize`` target size.
         length: second element of the ``transforms.Resize`` target size.
     """
     self.channels, self.height, self.length = 1, height, length
     self.data = datasets.Omniglot(root='./data', download=True)

     # Task construction and the train/validation split run before the
     # transforms below are assigned.
     self.make_tasks()
     self.split_validation_and_training_task()

     target_size = (self.height, self.length)
     self.resize = transforms.Resize(target_size)
     self.to_tensor = transforms.ToTensor()
Ejemplo n.º 9
0
 def __init__(self, height=32, length=32):
     """Load Omniglot, build and split the tasks, then set up the image transforms.

     Args:
         height: first element of the ``transforms.Resize`` target size.
         length: second element of the ``transforms.Resize`` target size.
     """
     self.channels, self.height, self.length = 1, height, length
     self.data = datasets.Omniglot(root='./data', download=True)

     # The tasks are created and split before the transforms are assigned.
     self.task_maker()
     self.split_dataset()

     target_size = (self.height, self.length)
     self.resize = transforms.Resize(target_size)
     self.tensor = transforms.ToTensor()
 def __init__(self, height=32, length=32,
              root='C:/Users/kashi/Documents/CMPE_258/proj/FIGR-master/FIGR-master/data/'):
     """Load Omniglot, build the tasks and set up the image transforms.

     Args:
         height: first element of the ``transforms.Resize`` target size.
         length: second element of the ``transforms.Resize`` target size.
         root: directory handed to ``datasets.Omniglot``. The default keeps
             the previously hard-coded machine-specific path so existing
             callers are unaffected; pass another directory to run elsewhere.
     """
     self.channels = 1
     self.height = height
     self.length = length
     self.data = datasets.Omniglot(root=root, download=True)

     # Task construction and the train/validation split run before the
     # transforms below are assigned.
     self.make_tasks()
     self.split_validation_and_training_task()
     self.resize = transforms.Resize((self.height, self.length))
     self.to_tensor = transforms.ToTensor()
Ejemplo n.º 11
0
def load_data_and_initialize_loaders(data_name, train_batch, test_batch):
    """Load a named dataset and wrap its train/test splits in shuffled loaders.

    Args:
        data_name: case-insensitive dataset name: ``'mnist'``, ``'fashion'``
            / ``'fashionmnist'`` or ``'omniglot'``.
        train_batch: batch size of the training loader.
        test_batch: batch size of the test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: if ``data_name`` is not one of the supported names.
    """
    data_name = data_name.lower()
    # Validate up front: previously an unknown name fell through every branch
    # and failed later with an UnboundLocalError on ``train_data``.
    if data_name not in ('mnist', 'fashion', 'fashionmnist', 'omniglot'):
        raise ValueError("Data name not recognized: {!r}".format(data_name))

    kwargs = {'num_workers': 1, 'pin_memory': True}

    if data_name == 'omniglot':
        train_data = datasets.Omniglot(root='./data',
                                       background=True,
                                       download=True,
                                       transform=transforms.ToTensor())
        test_data = datasets.Omniglot(root='./data',
                                      background=False,
                                      download=True,
                                      transform=transforms.ToTensor())
    else:
        dataset_cls = datasets.MNIST if data_name == 'mnist' else datasets.FashionMNIST
        train_data = dataset_cls('./data',
                                 train=True,
                                 download=True,
                                 transform=transforms.ToTensor())
        # Only the training split requests the download, as before.
        test_data = dataset_cls('./data',
                                train=False,
                                transform=transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=train_batch,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=test_batch,
                                              shuffle=True,
                                              **kwargs)
    return train_loader, test_loader
Ejemplo n.º 12
0
def load_omniglot(args):
    """Return a shuffled Omniglot ``DataLoader`` with a fixed batch size of 32.

    The CUDA seed is set to 1 first. Data is stored under ``data_o/``,
    prefixed with the scratch directory when ``args.scratch`` is truthy.
    Incomplete final batches are dropped.

    Args:
        args: object exposing a ``scratch`` attribute.

    Returns:
        The training ``DataLoader``.
    """
    torch.cuda.manual_seed(1)

    data_dir = 'data_o/'
    if args.scratch:
        data_dir = '/scratch/eecs-share/ratzlafn/' + data_dir

    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    omniglot_set = datasets.Omniglot(data_dir, download=True, transform=pipeline)

    return torch.utils.data.DataLoader(omniglot_set,
                                       batch_size=32,
                                       shuffle=True,
                                       num_workers=1,
                                       pin_memory=True,
                                       drop_last=True)
Ejemplo n.º 13
0
def few_shot_omniglot(root, train=True, batch_size=128, timesteps=15,
                      n_jobs=0, resize=28, **kwargs):
    """Build a few-shot ``DataLoader`` over resized, inverted Omniglot images.

    Args:
        root: dataset root directory.
        train: passed to Omniglot as ``background``.
        batch_size: forwarded to ``FewShotSampler``.
        timesteps: forwarded to both ``FewShotCollate`` and ``FewShotSampler``.
        n_jobs: number of ``DataLoader`` workers.
        resize: edge length of the square the images are resized to.
        **kwargs: extra keyword arguments for ``FewShotSampler``.

    Returns:
        A ``DataLoader`` driven by the few-shot batch sampler and collate
        function.
    """
    preprocessing = transforms.Compose([
        transforms.Resize([resize, resize], interpolation=Image.NEAREST),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 1 - x)
    ])
    train_set = dset.Omniglot(root=root, download=True, background=train,
                              transform=preprocessing)

    collate = FewShotCollate(timesteps=timesteps)

    # The sampler is sized from the dataset's per-character image lists: the
    # outer length is used as the class count, the first entry's length as
    # the number of images per class.
    per_class_images = train_set._character_images
    sampler = FewShotSampler(
        num_cls=len(per_class_images),
        batch_size=batch_size, separate=True,
        timesteps=timesteps, **kwargs,
        img_per_class=len(per_class_images[0]),
    )

    loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_sampler=sampler,
        collate_fn=collate, num_workers=n_jobs
    )
    return loader
Ejemplo n.º 14
0
    def __load_data(subset) -> pd.DataFrame:
        """Build a DataFrame with one row per image of an Omniglot subset.

        Args:
            subset: subset name used in the ``images_{subset}`` directory
                name; ``'background'`` selects the background split for the
                download, anything else the evaluation split.

        Returns:
            A DataFrame with the columns ``subset``, ``alphabet``,
            ``class_name`` (``"<alphabet>.<character>"``), ``filepath``,
            ``id`` (the row index) and ``class_id`` (position of
            ``class_name`` among the sorted unique class names).
        """
        # Download the data if necessary.
        datasets.Omniglot(root=config.DATA_PATH,
                          download=True,
                          background=(subset == 'background'))

        print(f'loading omniglot dataset ({subset})')
        images_root = f'{config.DATA_PATH}/omniglot-py/images_{subset}/'

        # Walk the tree once; the listing provides both the progress-bar
        # total and the records themselves.
        listing = [(root, files) for root, _, files in os.walk(images_root)]
        total_images = sum(len(files) for _, files in listing)

        # Inspect the file system and attach attributes to every image.
        images = []
        with tqdm(total=total_images) as progress:
            for root, files in listing:
                # os.path handles the platform's separator, unlike
                # ``root.split('/')`` which breaks on Windows paths.
                parent_dir, character = os.path.split(root)
                alphabet = os.path.basename(parent_dir)
                class_name = alphabet + '.' + character
                for f in files:
                    images.append({
                        'subset': subset,
                        'alphabet': alphabet,
                        'class_name': class_name,
                        'filepath': os.path.join(root, f)
                    })
                    progress.update(1)

        # Convert to a DataFrame.
        df = pd.DataFrame(images)
        df = df.assign(id=df.index.values)  # add the index value as the ID column

        # Assign a unique integer ID per class and add it as the class_id column.
        class_name_to_id = {
            class_name: class_id
            for class_id, class_name in enumerate(sorted(df['class_name'].unique()))
        }
        df = df.assign(class_id=df['class_name'].map(class_name_to_id))
        return df
Ejemplo n.º 15
0
def load_omniglot(root_dir=None,
                  batch_size=20,
                  shuffle=True,
                  transform=None,
                  download=True):
    """Load Omniglot into memory and return train/test/validation loaders.

    Both splits are materialised as ``(N, 105, 105)`` arrays, optionally
    shuffled with NumPy, and wrapped in ``TensorDataset``s whose "labels"
    are all-zero tensors of the same shape as the data. The last 1345
    training samples form the validation subset.

    Args:
        root_dir: dataset root; defaults to a ``datasets`` directory next to
            the running script (``sys.argv[0]``).
        batch_size: batch size of all three loaders.
        shuffle: shuffle both arrays in place and shuffle the test loader.
        transform: image transform; defaults to ``transforms.ToTensor()``.
        download: forwarded to ``datasets.Omniglot``.

    Returns:
        ``(train_loader, test_loader, valid_loader, dataset_type)`` where
        ``dataset_type`` is the string ``"binary"``.
    """
    dataset_type = "binary"

    if root_dir is None:
        root_dir = pathlib.Path(sys.argv[0]).parents[0] / 'datasets'

    if transform is None:
        transform = transforms.ToTensor()

    background_set = datasets.Omniglot(root_dir,
                                       transform=transform,
                                       download=download)
    evaluation_set = datasets.Omniglot(root_dir,
                                       transform=transform,
                                       download=download,
                                       background=False)

    def as_scaled_array(image_dataset):
        # NOTE(review): the values are divided by 255 even though the default
        # ToTensor transform is documented to scale to [0, 1] already --
        # confirm whether this double scaling is intended.
        stacked = np.zeros((len(image_dataset), 105, 105))
        for position, (picture, _) in enumerate(image_dataset):
            stacked[position] = picture.numpy() / 255
        return stacked

    train_array = as_scaled_array(background_set)
    test_array = as_scaled_array(evaluation_set)

    if shuffle:
        np.random.shuffle(train_array)
        np.random.shuffle(test_array)

    train_tensor = torch.from_numpy(train_array)
    test_tensor = torch.from_numpy(test_array)

    # no labels: zero tensors stand in for them
    train_tensors = data_utils.TensorDataset(train_tensor.float(),
                                             torch.zeros(train_tensor.shape))
    test_tensors = data_utils.TensorDataset(test_tensor.float(),
                                            torch.zeros(test_tensor.shape))

    all_positions = list(range(len(train_tensors)))
    cut = len(train_tensors) - 1345  # given by god to make the batch size reasonable
    train_sampler = SubsetSampler(all_positions[:cut])
    valid_sampler = SubsetSampler(all_positions[cut:])

    # Train and validation loaders both read the training tensors and differ
    # only in the sampler.
    train_loader = torch.utils.data.DataLoader(train_tensors,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               shuffle=False)
    valid_loader = torch.utils.data.DataLoader(train_tensors,
                                               batch_size=batch_size,
                                               sampler=valid_sampler,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_tensors,
                                              batch_size=batch_size,
                                              shuffle=shuffle)

    return train_loader, test_loader, valid_loader, dataset_type
Ejemplo n.º 16
0
def _sample_teacher_outputs(model, loader, device, rounds=500):
    """Run ``model`` over ``loader`` ``rounds`` times and stack the softmax outputs.

    Each batch is flattened to ``(batch, -1)`` before the forward pass.

    Returns:
        A CPU tensor whose axes are (example, round, softmax output).
    """
    with torch.no_grad():
        all_samples = []
        for _ in range(rounds):
            samples_a_round = []
            for data, _target in loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                # ``dim=1`` is what the implicit-dim call resolved to for
                # these 2-D outputs; spelling it out avoids the deprecation
                # warning.
                samples_a_round.append(F.softmax(model(data), dim=1))
            all_samples.append(torch.cat(samples_a_round).cpu())
        return torch.stack(all_samples).permute(1, 0, 2)


def main():
    """Train or load the approximation models and evaluate them on MNIST and an OOD set.

    Steps:
      1. Parse the command line and select the device.
      2. Build MNIST train/test loaders and the out-of-distribution loader
         (``omniglot`` or ``SEMEION``).
      3. Load the SGLD teacher model and report its test result.
      4. Either train the two approximation models (saving them whenever the
         test accuracy improves) or load them from disk.
      5. Load or generate the teacher samples for the test and OOD data, then
         call ``eval_approx``.

    Raises:
        ValueError: if ``--ood-name`` is not a supported dataset name.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Amortized approximation on MNIST')
    # ``%(default)s`` keeps the help text in sync with the actual defaults.
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--approx-epochs', type=int, default=200, metavar='N',
                        help='number of epochs to approx (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--S', type=int, default=100, metavar='N',
                        help='number of posterior samples from the Bayesian model')
    parser.add_argument('--model-path', type=str, default='../saved_models/mnist_sgld/', metavar='N',
                        help='path prefix for saved models and samples (default: %(default)s)')
    parser.add_argument('--from-approx-model', type=int, default=1, metavar='N',
                        help='if our model is loaded or trained')
    parser.add_argument('--test-ood-from-disk', type=int, default=1,
                        help='generate test samples or load from disk')
    parser.add_argument('--ood-name', type=str, default='omniglot',
                        help='name of the used ood dataset')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    def make_transform():
        # Every dataset goes through the same 28x28 / [-1, 1] preprocessing.
        return transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    tr_data = MNIST('../data', train=True, transform=make_transform(), download=True)
    te_data = MNIST('../data', train=False, transform=make_transform(), download=True)

    train_loader = torch.utils.data.DataLoader(
        tr_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    # NOTE(review): the test loader uses --batch-size; --test-batch-size is
    # parsed but never read here -- confirm whether that is intended.
    test_loader = torch.utils.data.DataLoader(
        te_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    if args.ood_name == 'omniglot':
        ood_data = datasets.Omniglot('../../data', download=True, transform=make_transform())
    elif args.ood_name == 'SEMEION':
        ood_data = datasets.SEMEION('../../data', download=True, transform=make_transform())
    else:
        # Previously an unknown name surfaced as an UnboundLocalError below.
        raise ValueError("Unknown OOD dataset name: {!r}".format(args.ood_name))

    ood_loader = torch.utils.data.DataLoader(
        ood_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    model = mnist_mlp(dropout=False).to(device)

    # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
    model.load_state_dict(torch.load(args.model_path + 'sgld-mnist.pt', map_location=device))

    test(args, model, device, test_loader)

    if args.from_approx_model == 0:
        output_samples = torch.load(args.model_path + 'mnist-sgld-train-samples.pt')

    # --------------- training approx ---------
    fmodel = mnist_mlp_h().to(device)
    gmodel = mnist_mlp_g().to(device)

    if args.from_approx_model == 0:
        g_optimizer = optim.SGD(gmodel.parameters(), lr=args.lr)
        f_optimizer = optim.SGD(fmodel.parameters(), lr=args.lr)
        best_acc = 0
        for epoch in range(1, args.approx_epochs + 1):
            train_approx(args, fmodel, gmodel, device, train_loader, f_optimizer, g_optimizer, output_samples, epoch)
            acc = test(args, fmodel, device, test_loader)
            # Checkpoint both models whenever the test accuracy improves.
            if acc > best_acc:
                torch.save(fmodel.state_dict(), args.model_path + 'sgld-mnist-mmd-mean.pt')
                torch.save(gmodel.state_dict(), args.model_path + 'sgld-mnist-mmd-conc.pt')
                best_acc = acc
    else:
        fmodel.load_state_dict(
            torch.load(args.model_path + 'sgld-mnist-mmd-mean.pt', map_location=device))
        gmodel.load_state_dict(
            torch.load(args.model_path + 'sgld-mnist-mmd-conc.pt', map_location=device))

    print('generating teacher particles for testing&ood data ...')
    # generate particles for test and ood dataset
    # NOTE(review): the teacher is put in train mode before sampling, as in
    # the original -- presumably deliberate; confirm.
    model.train()

    test_samples_path = args.model_path + 'mnist-sgld-test-samples.pt'
    ood_samples_path = args.model_path + 'mnist-sgld-' + args.ood_name + '-samples.pt'

    if args.test_ood_from_disk == 1:
        teacher_test_samples = torch.load(test_samples_path)
        teacher_ood_samples = torch.load(ood_samples_path)
    else:
        # obtain ensemble outputs and cache them on disk
        teacher_test_samples = _sample_teacher_outputs(model, test_loader, device)
        torch.save(teacher_test_samples, test_samples_path)

        teacher_ood_samples = _sample_teacher_outputs(model, ood_loader, device)
        torch.save(teacher_ood_samples, ood_samples_path)

    eval_approx(args, fmodel, gmodel, device, test_loader, ood_loader, teacher_test_samples, teacher_ood_samples)
Ejemplo n.º 17
0
Archivo: vae.py Proyecto: vgurev/LRS_NF
def run(seed):
    """Train a VAE on binarised images and evaluate it on the test set.

    All hyper-parameters are read from the module-level ``args`` namespace.
    The function builds the data loaders, the prior / approximate posterior /
    likelihood, trains for ``args.num_training_steps`` steps while
    checkpointing the model with the best validation ELBO, then reloads that
    checkpoint and writes per-example ELBO and log-probability lower-bound
    arrays plus a one-line text summary into the log directory.

    Args:
        seed: Seed for the NumPy and PyTorch RNGs used for the data split and
            training. The test-set evaluation re-seeds both with the fixed
            value 5.

    Raises:
        AssertionError: If CUDA is not available.
        ValueError: If ``args.dataset_name``, ``args.linear_type``,
            ``args.prior_type`` or ``args.kl_multiplier_schedule`` is not one
            of the supported values.
    """

    assert torch.cuda.is_available()
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    np.random.seed(seed)
    torch.manual_seed(seed)

    def dataset_root(name):
        # Every dataset lives in its own sub-directory of the data root.
        return os.path.join(utils.get_data_root(), name)

    # Create training data. Pixels are stochastically binarised on every
    # access by sampling a Bernoulli with the pixel intensity as probability.
    data_transform = tvtransforms.Compose(
        [tvtransforms.ToTensor(),
         tvtransforms.Lambda(torch.bernoulli)])

    if args.dataset_name == 'mnist':
        dataset = datasets.MNIST(root=dataset_root('mnist'),
                                 train=True,
                                 download=True,
                                 transform=data_transform)
        test_dataset = datasets.MNIST(root=dataset_root('mnist'),
                                      train=False,
                                      download=True,
                                      transform=data_transform)
    elif args.dataset_name == 'fashion-mnist':
        dataset = datasets.FashionMNIST(root=dataset_root('fashion-mnist'),
                                        train=True,
                                        download=True,
                                        transform=data_transform)
        test_dataset = datasets.FashionMNIST(
            root=dataset_root('fashion-mnist'),
            train=False,
            download=True,
            transform=data_transform)
    elif args.dataset_name == 'omniglot':
        # torchvision's Omniglot has no ``train`` keyword; the split is
        # selected with ``background`` (True = "background"/training set,
        # False = "evaluation" set). Passing ``train=`` raises a TypeError.
        # NOTE(review): the likelihood below is fixed to 1x28x28 images and no
        # resize is applied here -- confirm the image size this dataset
        # yields and add a Resize if it differs.
        dataset = datasets.Omniglot(root=dataset_root('omniglot'),
                                    background=True,
                                    download=True,
                                    transform=data_transform)
        test_dataset = datasets.Omniglot(root=dataset_root('omniglot'),
                                         background=False,
                                         download=True,
                                         transform=data_transform)
    elif args.dataset_name == 'emnist':
        # EMNIST gets an extra -90 degree rotation and a horizontal flip
        # before being converted to a tensor and binarised.
        rotate = partial(tvF.rotate, angle=-90)
        hflip = tvF.hflip
        data_transform = tvtransforms.Compose([
            tvtransforms.Lambda(rotate),
            tvtransforms.Lambda(hflip),
            tvtransforms.ToTensor(),
            tvtransforms.Lambda(torch.bernoulli)
        ])
        dataset = datasets.EMNIST(root=dataset_root('emnist'),
                                  split='letters',
                                  train=True,
                                  transform=data_transform,
                                  download=True)
        test_dataset = datasets.EMNIST(root=dataset_root('emnist'),
                                       split='letters',
                                       train=False,
                                       transform=data_transform,
                                       download=True)
    else:
        raise ValueError

    # The last |split| shuffled indices of the training set are held out for
    # validation.
    if args.dataset_name == 'omniglot':
        split = -1345
    elif args.dataset_name == 'emnist':
        split = -20000
    else:
        split = -10000
    indices = np.arange(len(dataset))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[:split], indices[split:]
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    train_loader = data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4 if args.dataset_name == 'emnist' else 0)
    train_generator = data_.batch_generator(train_loader)
    val_loader = data.DataLoader(dataset=dataset,
                                 batch_size=1024,
                                 sampler=val_sampler,
                                 shuffle=False,
                                 drop_last=False)
    # Only a single validation batch (at most 1024 images) is used for
    # monitoring during training.
    val_batch = next(iter(val_loader))[0]
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
    )

    def create_linear_transform():
        if args.linear_type == 'lu':
            return transforms.CompositeTransform([
                transforms.RandomPermutation(args.latent_features),
                transforms.LULinear(args.latent_features, identity_init=True)
            ])
        elif args.linear_type == 'svd':
            return transforms.SVDLinear(args.latent_features,
                                        num_householder=4,
                                        identity_init=True)
        elif args.linear_type == 'perm':
            return transforms.RandomPermutation(args.latent_features)
        else:
            raise ValueError

    def create_base_transform(i, context_features=None):
        # NOTE(review): the transform family is chosen by ``args.prior_type``
        # even when this is called for the approximate posterior -- confirm
        # that this is intended.

        def create_mask():
            # Alternate the masked half between consecutive flow steps.
            return utils.create_alternating_binary_mask(
                features=args.latent_features, even=(i % 2 == 0))

        def create_transform_net(in_features, out_features):
            # Conditioner network shared by all coupling transforms.
            return nn_.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)

        def spline_coupling(transform_class):
            # Keyword arguments common to both piecewise coupling transforms.
            return transform_class(
                mask=create_mask(),
                transform_net_create_fn=create_transform_net,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                apply_unconditional_transform=args.
                apply_unconditional_transform,
            )

        def spline_autoregressive(transform_class):
            # Keyword arguments common to both piecewise autoregressive
            # transforms.
            return transform_class(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_bins=args.num_bins,
                tails='linear',
                tail_bound=args.tail_bound,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)

        if args.prior_type == 'affine-coupling':
            return transforms.AffineCouplingTransform(
                mask=create_mask(),
                transform_net_create_fn=create_transform_net)
        elif args.prior_type == 'rq-coupling':
            return spline_coupling(
                transforms.PiecewiseRationalQuadraticCouplingTransform)
        elif args.prior_type == 'rl-coupling':
            return spline_coupling(
                transforms.PiecewiseRationalLinearCouplingTransform)
        elif args.prior_type == 'affine-autoregressive':
            return transforms.MaskedAffineAutoregressiveTransform(
                features=args.latent_features,
                hidden_features=args.hidden_features,
                context_features=context_features,
                num_blocks=args.num_transform_blocks,
                use_residual_blocks=True,
                random_mask=False,
                activation=F.relu,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm)
        elif args.prior_type == 'rq-autoregressive':
            return spline_autoregressive(
                transforms.
                MaskedPiecewiseRationalQuadraticAutoregressiveTransform)
        elif args.prior_type == 'rl-autoregressive':
            return spline_autoregressive(
                transforms.MaskedPiecewiseRationalLinearAutoregressiveTransform
            )
        else:
            raise ValueError

    # ---------------
    # prior
    # ---------------
    def create_prior():
        if args.prior_type == 'standard-normal':
            prior = distributions_.StandardNormal((args.latent_features, ))

        else:
            # Standard normal pushed through ``num_flow_steps`` pairs of
            # (linear, base) transforms followed by a final linear transform.
            distribution = distributions_.StandardNormal(
                (args.latent_features, ))
            transform = transforms.CompositeTransform([
                transforms.CompositeTransform(
                    [create_linear_transform(),
                     create_base_transform(i)])
                for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            prior = flows.Flow(transform, distribution)

        return prior

    # ---------------
    # inputs encoder
    # ---------------
    def create_inputs_encoder():
        # The diagonal-normal posterior carries its own convolutional context
        # encoder, so no separate inputs encoder is created for it.
        if args.approximate_posterior_type == 'diagonal-normal':
            inputs_encoder = None
        else:
            inputs_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
        return inputs_encoder

    # ---------------
    # approximate posterior
    # ---------------
    def create_approximate_posterior():
        if args.approximate_posterior_type == 'diagonal-normal':
            context_encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            approximate_posterior = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)

        else:
            context_encoder = nn.Linear(args.context_features,
                                        2 * args.latent_features)
            distribution = distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=context_encoder)

            transform = transforms.CompositeTransform([
                transforms.CompositeTransform([
                    create_linear_transform(),
                    create_base_transform(
                        i, context_features=args.context_features)
                ]) for i in range(args.num_flow_steps)
            ])
            transform = transforms.CompositeTransform(
                [transform, create_linear_transform()])
            approximate_posterior = flows.Flow(
                transforms.InverseTransform(transform), distribution)

        return approximate_posterior

    # ---------------
    # likelihood
    # ---------------
    def create_likelihood():
        latent_decoder = nn_.ConvDecoder(
            latent_features=args.latent_features,
            channels_multiplier=16,
            dropout_probability=args.dropout_probability_encoder_decoder)

        likelihood = distributions_.ConditionalIndependentBernoulli(
            shape=[1, 28, 28], context_encoder=latent_decoder)

        return likelihood

    prior = create_prior()
    approximate_posterior = create_approximate_posterior()
    likelihood = create_likelihood()
    inputs_encoder = create_inputs_encoder()

    model = vae.VariationalAutoencoder(
        prior=prior,
        approximate_posterior=approximate_posterior,
        likelihood=likelihood,
        inputs_encoder=inputs_encoder)

    n_params = utils.get_num_parameters(model)
    print('There are {} trainable parameters in this model.'.format(n_params))

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=args.num_training_steps, eta_min=0)

    def get_kl_multiplier(step):
        if args.kl_multiplier_schedule == 'constant':
            return args.kl_multiplier_initial
        elif args.kl_multiplier_schedule == 'linear':
            # Ramps from the initial value to twice the initial value over the
            # warm-up fraction of training, then stays there.
            multiplier = min(
                step / (args.num_training_steps * args.kl_warmup_fraction), 1.)
            return args.kl_multiplier_initial * (1. + multiplier)
        else:
            # Previously fell through and returned None, which only surfaced
            # as an obscure failure inside the ELBO computation.
            raise ValueError

    # create summary writer and write to log directory
    timestamp = cutils.get_timestamp()
    if cutils.on_cluster():
        timestamp += '||{}'.format(os.environ['SLURM_JOB_ID'])
    log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
    # Retry until the writer can be created; a FileExistsError is treated as
    # transient.
    while True:
        try:
            writer = SummaryWriter(log_dir=log_dir, max_queue=20)
            break
        except FileExistsError:
            sleep(5)
    filename = os.path.join(log_dir, 'config.json')
    with open(filename, 'w') as file:
        json.dump(vars(args), file)

    best_val_elbo = -np.inf
    tbar = tqdm(range(args.num_training_steps))
    for step in tbar:
        model.train()
        optimizer.zero_grad()

        batch = next(train_generator)[0].to(device)
        elbo = model.stochastic_elbo(batch,
                                     kl_multiplier=get_kl_multiplier(step))
        loss = -torch.mean(elbo)
        loss.backward()
        optimizer.step()
        scheduler.step(step)

        if (step + 1) % args.monitor_interval == 0:
            model.eval()
            with torch.no_grad():
                elbo = model.stochastic_elbo(val_batch.to(device))
                mean_val_elbo = elbo.mean()

            # Checkpoint whenever the validation ELBO improves.
            if mean_val_elbo > best_val_elbo:
                best_val_elbo = mean_val_elbo
                path = os.path.join(
                    cutils.get_checkpoint_root(),
                    '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
                torch.save(model.state_dict(), path)

            writer.add_scalar(tag='val-elbo',
                              scalar_value=mean_val_elbo,
                              global_step=step)

            writer.add_scalar(tag='best-val-elbo',
                              scalar_value=best_val_elbo,
                              global_step=step)

            with torch.no_grad():
                samples = model.sample(64)
            fig, ax = plt.subplots(figsize=(10, 10))
            cutils.gridimshow(make_grid(samples.view(64, 1, 28, 28), nrow=8),
                              ax)
            writer.add_figure(tag='vae-samples', figure=fig, global_step=step)
            # Close this specific figure rather than "the current one".
            plt.close(fig)

    # load best val model
    path = os.path.join(
        cutils.get_checkpoint_root(),
        '{}-best-val-{}.t'.format(args.dataset_name, timestamp))
    model.load_state_dict(torch.load(path))
    model.eval()

    # Fixed seed so the stochastic test-set estimates are comparable between
    # runs.
    np.random.seed(5)
    torch.manual_seed(5)

    # compute elbo on test set
    with torch.no_grad():
        elbo = torch.Tensor([])
        log_prob_lower_bound = torch.Tensor([])
        for batch in tqdm(test_loader):
            inputs = batch[0].to(device)
            elbo_ = model.stochastic_elbo(inputs)
            elbo = torch.cat([elbo, elbo_])
            log_prob_lower_bound_ = model.log_prob_lower_bound(
                inputs, num_samples=1000)
            log_prob_lower_bound = torch.cat(
                [log_prob_lower_bound, log_prob_lower_bound_])
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-elbo.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(elbo))
    path = os.path.join(
        log_dir, '{}-prior-{}-posterior-{}-log-prob-lower-bound.npy'.format(
            args.dataset_name, args.prior_type,
            args.approximate_posterior_type))
    np.save(path, utils.tensor2numpy(log_prob_lower_bound))

    # save elbo and log prob lower bound
    # The reported interval is two standard errors of the mean.
    mean_elbo = elbo.mean()
    std_elbo = elbo.std()
    mean_log_prob_lower_bound = log_prob_lower_bound.mean()
    std_log_prob_lower_bound = log_prob_lower_bound.std()
    s = 'ELBO: {:.2f} +- {:.2f}, LOG PROB LOWER BOUND: {:.2f} +- {:.2f}'.format(
        mean_elbo.item(), 2 * std_elbo.item() / np.sqrt(len(test_dataset)),
        mean_log_prob_lower_bound.item(),
        2 * std_log_prob_lower_bound.item() / np.sqrt(len(test_dataset)))
    filename = os.path.join(log_dir, 'test-results.txt')
    with open(filename, 'w') as file:
        file.write(s)
Ejemplo n.º 18
0
 def _get_dataset(self, train: bool, transform: Any) -> torch.utils.data.Dataset:
     """Build the Omniglot dataset stored under ``self.data_folder``.

     Args:
         train: Forwarded as Omniglot's ``background`` flag.
         transform: Forwarded unchanged as the dataset's ``transform``.

     Returns:
         The constructed ``datasets.Omniglot`` instance. Nothing is
         downloaded (``download=False``).
     """
     omniglot_options = dict(background=train,
                             download=False,
                             transform=transform)
     return datasets.Omniglot(self.data_folder, **omniglot_options)
Ejemplo n.º 19
0
import torch
import torch.nn.functional as F
import utils
from torchvision import datasets, transforms
import VAE_model as vae
import easy_conv_vae as conv_vae
import maml_class
import numpy as np
import pickle
from torch import optim
import matplotlib.pyplot as plt
import random
from tqdm import tqdm
import seaborn as sns

# Load Omniglot from ./data (download=True fetches it when it is missing).
data = datasets.Omniglot(root='./data', download=True)
# NOTE(review): these two counts are not referenced in the visible part of
# this script -- presumably used further down; confirm.
source_task_number=500
task_set_number = 200
# Build the task collection with the project helper. The meaning of the
# positional arguments (50, 20, 10) is defined by utils.get_dataset --
# TODO confirm against that helper.
data_set = utils.get_dataset(data,50, 20, 10)
# Shuffle in place; no seed is set in the visible code, so the order varies
# between runs.
random.shuffle(data_set)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

epochs=200


betas = [1.]
for j,beta in enumerate(betas):
    model = conv_vae.VAE(z_dim=128)
    model.to(device)
    maml = maml_class.MAML(model=model, data=data_set, inner_lr=1e-2,
Ejemplo n.º 20
0
def train_gan(opt):
    """Build the discriminator, generator and dataset, then start training.

    Creates the ``images`` and ``chkpts`` output directories under
    ``opt.savingroot/opt.dataset``, instantiates the networks selected by
    ``opt.model_type`` (wrapped in ``nn.DataParallel`` on CUDA), loads the
    dataset selected by ``opt.data_r`` and hands everything to ``train_g``.

    Args:
        opt: Options namespace. Reads ``savingroot``, ``dataset``,
            ``model_type``, ``loss_type``, ``num_classes``, ``nc``, ``nz``,
            ``SN``, ``image_size``, ``data_r`` and ``dataroot``.

    Raises:
        ValueError: If ``opt.model_type`` or ``opt.data_r`` is not a
            supported value. (Previously these cases left the networks or the
            dataset unbound and failed later with a NameError.)
    """
    os.makedirs(os.path.join(opt.savingroot, opt.dataset, 'images'),
                exist_ok=True)
    os.makedirs(os.path.join(opt.savingroot, opt.dataset, 'chkpts'),
                exist_ok=True)

    # Fail early, before anything is allocated on the GPU.
    if opt.model_type not in ('sa', 'big'):
        raise ValueError('unknown model_type: {!r}'.format(opt.model_type))

    # Build networks.
    # Only the auxiliary-classifier losses need the AC head; 'Projection' and
    # any other loss type run without it.
    AC = opt.loss_type in ('Twin_AC', 'AC')
    if opt.model_type == 'sa':
        netd_g = nn.DataParallel(
            SA_Discriminator(n_class=opt.num_classes,
                             nc=opt.nc,
                             AC=AC,
                             Resolution=opt.image_size,
                             ch=64).cuda())
        netg = nn.DataParallel(
            SA_Generator(n_class=opt.num_classes,
                         code_dim=opt.nz,
                         nc=opt.nc,
                         SN=opt.SN,
                         Resolution=opt.image_size,
                         ch=32).cuda())
    else:  # 'big'
        netd_g = nn.DataParallel(
            Discriminator(n_classes=opt.num_classes,
                          resolution=opt.image_size,
                          AC=AC).cuda())
        netg = nn.DataParallel(
            Generator(n_classes=opt.num_classes,
                      resolution=opt.image_size,
                      SN=opt.SN).cuda())

    # Load the requested dataset; ``tsfm`` is the module-level transform.
    if opt.data_r == 'MNIST':
        dataset = dset.MNIST(root=opt.dataroot, download=True, transform=tsfm)
    elif opt.data_r == 'CIFAR10':
        dataset = dset.CIFAR10(root=opt.dataroot,
                               download=True,
                               transform=tsfm)
    elif opt.data_r == 'CIFAR100':
        dataset = dset.CIFAR100(root=opt.dataroot,
                                download=True,
                                transform=tsfm)
    elif opt.data_r == 'CUB':
        dataset = dset.ImageFolder(
            root='/home/yanwuxu/CUB_200_2011_processed/ImageNet/ImageNet/',
            transform=tsfm)
    elif opt.data_r == 'VGGFACE':
        dataset = ILSVRC_HDF5(root='../data/VGGFACE64.hdf5', transform=tsfm)
    elif opt.data_r == 'IMAGENET100':
        dataset = Load_numpy_data(root='../data/ImageNet100.pt',
                                  transform=tsfm)
    elif opt.data_r == 'MNIST_overlap':
        dataset = Load_gray_data(root='../data/overlap_MNIST.pt',
                                 transform=tsfm)
    elif opt.data_r == 'OMNIGLOT':
        # Alternative loader kept for reference:
        # Load_gray_data(root='omniglot.pt', transform=tsfm)
        dataset = dset.Omniglot('../result', transform=tsfm, download=True)
    else:
        raise ValueError('unknown data_r: {!r}'.format(opt.data_r))

    print('training_start')
    print(opt.loss_type)

    step = 0

    train_g(netd_g, netg, dataset, step, opt)
import utils
import easy_conv_vae as conv_vae
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

number_of_tasks = 500
train_number = 20 * number_of_tasks

transform = transforms.Compose(
    [transforms.Resize((28, 28)),
     transforms.ToTensor()])
data = datasets.Omniglot(root='./data', transform=transform)
# Only the first ``train_number + 1000`` samples are used, so load and
# transform exactly those, once. The previous ``list(data)`` calls decoded
# and transformed the *entire* dataset twice just to slice a prefix.
_needed = min(len(data), train_number + 1000)
_samples = [data[i] for i in range(_needed)]
train_set = _samples[:train_number]
test_set = _samples[train_number:train_number + 1000]
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=20,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set,
                                          batch_size=20,
                                          shuffle=True)

model = conv_vae.VAE(z_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 20
beta = 1.
# Reconstruction + KL divergence losses summed over all elements and batch
Ejemplo n.º 22
0
        nn.init.normal_(m.weight,0,2e-2)
        nn.init.normal_(m.bias, 0.5, 1e-2)

# Smoke test: push two random batches through the network and print the
# shape of the result.
net=Siamese_Net()
# print(net)
# Two batches of 64 single-channel 105x105 inputs.
dummyx1=torch.randn(64,1,105,105)
dummyx2=torch.randn(64,1,105,105)
o=net(dummyx1,dummyx2)
print(o.shape)

"""# **Omniglot Dataset**"""

# Download both Omniglot splits into /content: background=True selects the
# "background" set, background=False the "evaluation" set.
omni_train=datasets.Omniglot(root='/content', 
                       background= True,  
                       download = True
                       )

omni_test=datasets.Omniglot(root='/content', 
                       background= False,  
                       download = True
                       )

# Omniglot dataset - the class was not renamed because it would have had to be renamed everywhere
class Face_Dataset(Dataset):
  """Dataset over a directory whose entries are treated as classes.

  The constructor records the root directory and the job name, lists the
  entries of ``root_dir`` as the available classes and stores their count.
  ``ways`` and ``transform`` are accepted but not used by the code shown here.
  """

  def __init__(self, root_dir, job='train', ways=10, transform=None):
    super().__init__()
    class_entries = os.listdir(root_dir)
    self.root_dir, self.job = root_dir, job
    self.all_classes = class_entries
    self.num_classes = len(class_entries)
def _collect_ensemble_outputs(model, loader, param_samples, device):
    """Evaluate every posterior parameter sample on every input of ``loader``.

    For each state dict in ``param_samples`` the weights are loaded into
    ``model`` and the softmax of the model output is computed for all batches
    (inputs are flattened to ``(batch, -1)`` first). ``model`` is left holding
    the last sample's weights.

    Returns:
        A CPU tensor indexed as ``[input, sample, class]``.
    """
    per_sample_outputs = []
    with torch.no_grad():
        for state_dict in param_samples:
            model.load_state_dict(state_dict)
            batch_outputs = []
            for data, _ in loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                # Explicit ``dim``: the implicit choice is deprecated.
                batch_outputs.append(F.softmax(model(data), dim=1))
            per_sample_outputs.append(torch.cat(batch_outputs).cpu())
    return torch.stack(per_sample_outputs).permute(1, 0, 2)


def main():
    """Train an MNIST MLP with SGLD and dump teacher ensemble predictions.

    Steps:
      1. Train for ``--epochs`` epochs, saving the weights after each epoch.
      2. Continue taking SGLD steps, keeping a copy of the weights after every
         step until ``--S`` posterior samples have been collected; pickle them.
      3. For the MNIST train set, MNIST test set, Omniglot and SEMEION, save
         the softmax outputs of every posterior sample to ``--model-path``.

    All output file names are formed by appending to ``--model-path``, so the
    path is expected to end with a separator.
    """
    import copy
    import pickle as pkl

    # Training settings
    parser = argparse.ArgumentParser(
        description='run approximation to LeNet on Mnist')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--dropout-rate',
                        type=float,
                        default=0.5,
                        metavar='p_drop',
                        help='dropout rate')
    parser.add_argument(
        '--S',
        type=int,
        default=500,
        metavar='N',
        help='number of posterior samples from the Bayesian model')
    parser.add_argument(
        '--model-path',
        type=str,
        default='../saved_models/mnist_sgld/',
        metavar='PATH',
        help='path prefix under which models and samples are saved '
        '(default: ../saved_models/mnist_sgld/)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    mnist_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, ), (0.5, ))])

    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data', train=True, download=True, transform=mnist_transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data', train=False, transform=mnist_transform),
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    model = mnist_mlp(dropout=False).to(device)
    optimizer = SGLD(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        train_bayesian(args, model, device, train_loader, optimizer, epoch)
        print("epoch: {}".format(epoch))
        test(args, model, device, test_loader)

        # save models
        torch.save(model.state_dict(), args.model_path + 'sgld-mnist.pt')

    # save samples: keep one copy of the weights after every SGLD step until
    # args.S samples have been collected (cycling over the data if needed).
    param_samples = []
    while True:
        for idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            data = data.view(data.shape[0], -1)
            output = model(data)
            loss = F.nll_loss(F.log_softmax(output, dim=1), target)
            loss.backward()
            optimizer.step()
            param_samples.append(copy.deepcopy(model.state_dict()))
            if len(param_samples) >= args.S:
                print('1', len(param_samples))
                break
        if len(param_samples) >= args.S:
            print('2', len(param_samples))
            break
    with open(args.model_path + "sgld_samples.pkl", "wb") as f:
        print('3', len(param_samples))
        pkl.dump(param_samples, f)

    test(args, model, device, test_loader)

    # Rebuild the training loader without shuffling so that the saved teacher
    # outputs line up with the dataset order. The test loader above is
    # already unshuffled and is reused as is.
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data', train=True, download=True, transform=mnist_transform),
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)

    # All loops below iterate over the samples actually collected instead of
    # a hard-coded 500, so --S values other than 500 work.

    # generate teacher train samples
    torch.save(
        _collect_ensemble_outputs(model, train_loader, param_samples, device),
        args.model_path + 'mnist-sgld-train-samples.pt')

    # generate teacher test samples
    torch.save(
        _collect_ensemble_outputs(model, test_loader, param_samples, device),
        args.model_path + 'mnist-sgld-test-samples.pt')

    # Out-of-distribution inputs are resized to MNIST's 28x28 and normalised
    # the same way as the MNIST data.
    ood_transform = transforms.Compose([
        # transforms.ToPILImage(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    # generate teacher omniglot samples
    ood_data = datasets.Omniglot('../../data',
                                 download=True,
                                 transform=ood_transform)
    ood_loader = torch.utils.data.DataLoader(ood_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)
    torch.save(
        _collect_ensemble_outputs(model, ood_loader, param_samples, device),
        args.model_path + 'mnist-sgld-omniglot-samples.pt')

    # generate teacher SEMEION samples
    ood_data = datasets.SEMEION('../../data',
                                download=True,
                                transform=ood_transform)
    ood_loader = torch.utils.data.DataLoader(ood_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             **kwargs)
    torch.save(
        _collect_ensemble_outputs(model, ood_loader, param_samples, device),
        args.model_path + 'mnist-sgld-SEMEION-samples.pt')
Ejemplo n.º 24
0
def get_datasets(
    dataset_name,
    frac_val=FRAC_VAL,
    batch_size=8,
    img_shape=None,
    nn_architecture=None,
    train_params=None,
    synthetic_params=None,
    class_2d=None,
    kwargs=None,
):
    """Build the train/validation pair for the dataset called `dataset_name`.

    Args:
        dataset_name: Key selecting the dataset; see the branches below for
            the accepted values.
        frac_val: Validation fraction. NOTE(review): only forwarded on the
            "omniglot" path and to `get_loaders_brain`; every other path
            calls `split_dataset` without it -- confirm this is intended.
        batch_size: Only used by the brain ("mri"/"segmentation"/"fmri") path.
        img_shape: Optional image shape; its first entry is dropped
            (`img_shape[1:]`) before being handed to the dataset builders.
        nn_architecture, train_params, synthetic_params: Only used by the
            "synthetic" path.
        class_2d: Only used by the "cryo_exp_class_2d" path.
        kwargs: Extra loader options; falls back to the module-level `KWARGS`
            when None.

    Returns:
        A `(train, val)` 2-tuple. For the brain datasets these are the two
        loaders returned by `get_loaders_brain`; for every other dataset they
        are the two dataset objects.

    Raises:
        ValueError: If `dataset_name` matches none of the known names.
    """
    kwargs = KWARGS if kwargs is None else kwargs
    img_shape_no_channel = img_shape[1:] if img_shape is not None else None

    # TODO(nina): Consistency in datasets: add channels for all
    logging.info("Loading data from dataset: %s" % dataset_name)

    # --- Brain imaging: the only path that yields loaders, not datasets.
    if dataset_name in ["mri", "segmentation", "fmri"]:
        train_loader, val_loader = get_loaders_brain(
            dataset_name, frac_val, batch_size, img_shape_no_channel, kwargs)
        return train_loader, val_loader

    # --- Sources that already come split into train / validation.
    if dataset_name == "mnist":
        train_dataset, val_dataset = get_dataset_mnist()
        return train_dataset, val_dataset

    if dataset_name == "connectomes":
        train_dataset, val_dataset = get_dataset_connectomes(
            img_shape_no_channel=img_shape_no_channel)
        return train_dataset, val_dataset

    if dataset_name == "connectomes_simu":
        train_dataset, val_dataset = get_dataset_connectomes_simu(
            img_shape_no_channel=img_shape_no_channel)
        return train_dataset, val_dataset

    if dataset_name == "connectomes_schizophrenia":
        # The third returned element is not needed here.
        train_dataset, val_dataset, _ = get_dataset_connectomes_schizophrenia()
        return train_dataset, val_dataset

    # --- Omniglot: the only split that receives the caller's frac_val.
    if dataset_name == "omniglot":
        if img_shape_no_channel is None:
            transform = transforms.ToTensor()
        else:
            transform = transforms.Compose([
                transforms.Resize(img_shape_no_channel),
                transforms.ToTensor()
            ])
        dataset = datasets.Omniglot("../data",
                                    download=True,
                                    transform=transform)
        train_dataset, val_dataset = split_dataset(dataset, frac_val=frac_val)
        return train_dataset, val_dataset

    # --- Everything else is one dataset, split with split_dataset's defaults.
    if dataset_name in [
            "cryo_sim",
            "randomrot1D_nodisorder",
            "randomrot1D_multiPDB",
            "randomrot_nodisorder",
    ]:
        dataset = get_dataset_cryo(dataset_name, img_shape_no_channel, kwargs)
    elif dataset_name == "cryo_sphere":
        dataset = get_dataset_cryo_sphere(img_shape_no_channel, kwargs)
    elif dataset_name == "cryo_exp":
        dataset = get_dataset_cryo_exp(img_shape_no_channel, kwargs)
    elif dataset_name == "cryo_exp_class_2d":
        dataset = get_dataset_cryo_exp_class_2d(img_shape_no_channel,
                                                class_2d)  # , kwargs)
    elif dataset_name == "cryo_exp_3d":
        dataset = get_dataset_cryo_exp_3d(img_shape_no_channel, kwargs)
    elif dataset_name == "synthetic":
        dataset = make_synthetic_dataset_and_decoder(
            synthetic_params=synthetic_params,
            nn_architecture=nn_architecture,
            train_params=train_params,
        )
    else:
        raise ValueError("Unknown dataset name: %s" % dataset_name)

    train_dataset, val_dataset = split_dataset(dataset)
    return train_dataset, val_dataset
Ejemplo n.º 25
0
    # Fixed settings for this run; max_length is not used in the visible part
    # of this snippet.
    batch_size = 32
    max_length = 15

    # Elementwise helpers for the transform pipeline below:
    #   rescaling:     x -> 2 * (x - 0.5)
    #   rescaling_inv: x -> 0.5 * x + 0.5   (algebraic inverse of rescaling)
    #   flip:          x -> -x
    rescaling = lambda x: (x - .5) * 2.
    rescaling_inv = lambda x: .5 * x + .5
    flip = lambda x: -x
    kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}
    # Calls .resize((28, 28)) on its input; it runs before ToTensor, so the
    # input is presumably a PIL image -- TODO confirm.
    resizing = lambda x: x.resize((28, 28))
    # Pipeline order: resize -> tensor -> rescale -> negate.
    omni_transforms = transforms.Compose(
        [resizing, transforms.ToTensor(), rescaling,
         flip])  # TODO: check this; rescaling may not be wanted here

    # NOTE(review): identical to the kwargs dict already assigned above; this
    # second assignment is redundant.
    kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}
    # Loader over the Omniglot "background" set (background=True).
    train_loader = torch.utils.data.DataLoader(datasets.Omniglot(
        '../vhe/data',
        download=True,
        background=True,
        transform=omni_transforms),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               **kwargs)

    # Loader over the Omniglot non-background set (background=False). It is
    # shuffled as well and uses the same transform pipeline.
    test_loader = torch.utils.data.DataLoader(datasets.Omniglot(
        '../vhe/data',
        download=True,
        background=False,
        transform=omni_transforms),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              **kwargs)
Ejemplo n.º 26
0
def load_dataset(
    root,
    name,
    dataset_size=None,
    transform=None,
    split="train",
    fraction=1.0,
    batch_size=None,
):
    """Load a named image dataset, optionally subsample it, and drop labels.

    Args:
        root: Directory under which each dataset keeps its own sub-folder.
        name: Dataset name, matched case-insensitively with "-" and "_"
            treated as equivalent (so "fashion-mnist", "Fashion_MNIST" and
            "lsun-bed" are all accepted). Names starting with "test" must
            look like "test_<channels>_<side>" and produce a `FakeData`
            dataset of that image size.
        dataset_size: Optional upper bound on the number of samples kept.
            Also the size of the fake dataset on the "test_*" path.
        transform: Transform handed to the underlying dataset. For "celeba"
            it is applied after a 178-pixel center crop.
        split: Either "train" or "test".
        fraction: Fraction of the dataset to keep, in (0, 1].
        batch_size: If given, the kept subset is never smaller than this.

    Returns:
        The (possibly subsampled) dataset wrapped in `IgnoreLabelDataset`.

    Raises:
        NotImplementedError: If `name` is not a known dataset.
        AssertionError: If `split`, `fraction` or `dataset_size` is invalid.
    """
    # Hyphens are folded into underscores here, so every comparison below
    # must use the underscore spelling of the dataset name.
    name = name.lower().replace("-", "_")
    assert split in ("train", "test")

    if name == "mnist":
        data = datasets.MNIST(
            os.path.join(root, "MNIST"),
            train=split == "train",
            download=True,
            transform=transform,
        )
    elif name == "fashion_mnist":
        # Was compared against "fashion-mnist", which can never match after
        # the normalisation above.
        data = datasets.FashionMNIST(
            os.path.join(root, "FASHION-MNIST"),
            train=split == "train",
            download=True,
            transform=transform,
        )
    elif name == "cifar10":
        data = datasets.CIFAR10(
            os.path.join(root, "CIFAR10"),
            train=split == "train",
            download=True,
            transform=transform,
        )
    elif name == "svhn":
        data = datasets.SVHN(
            os.path.join(root, "SVHN"), split=split, download=True, transform=transform
        )
    elif name == "stl10":
        # The training split also pulls in the unlabeled images.
        if split == "train":
            split += "+unlabeled"
        data = datasets.STL10(
            os.path.join(root, "STL10"), split=split, download=True, transform=transform
        )
    elif name == "lsun_bed":
        # Was compared against "lsun-bed" (unreachable for the same reason as
        # fashion-mnist). `split` is not consulted on this path: the
        # bedroom_train class is always loaded.
        data = datasets.LSUN(
            os.path.join(root, "LSUN"), classes=["bedroom_train"], transform=transform
        )
    elif name == "omniglot":
        data = datasets.Omniglot(
            os.path.join(root, "OMNIGLOT"),
            background=split == "train",
            download=True,
            transform=transform,
        )
    elif name.startswith("test"):
        # "test_<c>_<d>" -> fake images of shape (c, d, d) with random
        # binary labels.
        _, c, d = name.split("_")
        data = datasets.FakeData(
            size=dataset_size,
            image_size=(int(c), int(d), int(d)),
            num_classes=2,
            transform=transform,
        )
        data.labels = np.random.randint(0, 2, dataset_size)
    elif name == "scaly":
        data = SingleFolderDataset(os.path.join(root, "SCALY"), transform=transform)
    elif name == "celeba":
        # deal with non squared celebA
        center_crop = transforms.CenterCrop(178)
        # Only chain the caller's transform when one was supplied; composing
        # with None would fail as soon as an item is fetched.
        steps = [center_crop] if transform is None else [center_crop, transform]
        transform = transforms.Compose(steps)
        data = SingleFolderDataset(
            os.path.join(root, "celebA", "img_align_celeba"), transform=transform
        )
    else:
        raise NotImplementedError(name)

    assert 0 < fraction <= 1
    assert dataset_size is None or dataset_size <= len(data)
    if dataset_size is None:
        dataset_size = len(data)

    # Keep the requested fraction, capped by dataset_size, but never fewer
    # samples than one batch, one per visible CUDA device, or one in total.
    size = min(int(fraction * len(data)), dataset_size)
    size = max(
        size, batch_size if batch_size is not None else 0, torch.cuda.device_count(), 1
    )
    if size < len(data):
        # Uniform random subset without replacement.
        points = np.random.choice(range(len(data)), replace=False, size=size)
        data = torch.utils.data.Subset(data, points)
    return IgnoreLabelDataset(data)
Ejemplo n.º 27
0
    def forward(self, x):
        """Apply conv1..conv4 in order, flatten per sample, and return logits.

        The output of the last conv stage is reshaped to
        `(x.size(0), -1)` before being passed to `self.logits`.
        """
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)

        # Collapse everything except the leading dimension.
        flattened = x.view(x.size(0), -1)
        return self.logits(flattened)


if __name__ == "__main__":
    # Demo run: build an Omniglot task distribution from 28x28 tensors and
    # meta-train a classifier with MAML for 100 iterations.
    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    tasks = Omniglot_Task_Distribution(
        datasets.Omniglot('./Omniglot/', transform=trans),
        20,
    )

    N = 5
    K = 5
    # One task is sampled up front; the result is not used again below.
    task = tasks.sample_task(N, K, 15)

    meta_model = Classifier(N)
    # NOTE: MAML is given K=10 here, independently of the module-level K = 5.
    maml = MAML(
        meta_model.cuda(),
        tasks,
        inner_lr=0.01,
        meta_lr=0.001,
        K=10,
        inner_steps=1,
        tasks_per_meta_batch=32,
        criterion=nn.CrossEntropyLoss(),
    )
    maml.main_loop(num_iterations=100)
Ejemplo n.º 28
0
elif 'cifar' in args.dataset : 
    # CIFAR-10 branch: both loaders use args.batch_size and are shuffled.
    # Only the training dataset is constructed with download=True.
    train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True, 
        download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
    
    test_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False, 
                    transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)
    
    # On this branch only the logistic-mixture loss/sampler pair is provided;
    # a falsy args.nr_logistic_mix is rejected.
    if args.nr_logistic_mix:
        loss_op   = lambda real, fake : discretized_mix_logistic_loss(real, fake)
        sample_op = lambda x : sample_from_discretized_mix_logistic(x, args.nr_logistic_mix)
    else:
        raise NotImplementedError("No 3D Softmax")

elif 'omni' in args.dataset :

    # Omniglot branch: loaders use a hard-coded batch_size of 1 (not
    # args.batch_size). background=True feeds the train loader,
    # background=False the test loader.
    train_loader = torch.utils.data.DataLoader(datasets.Omniglot(args.data_dir, download=True, 
                        background=True, transform=omni_transforms), batch_size=1, 
                            shuffle=True, **kwargs)

    #d = datasets.Omniglot(args.data_dir, download=True, 
    #                   background=True, transform=omni_transforms)
    
    test_loader = torch.utils.data.DataLoader(datasets.Omniglot(args.data_dir, download=True, 
                        background=False, transform=omni_transforms), batch_size=1, 
                            shuffle=True, **kwargs)
    
    # Select the 1d loss/sampler pair: logistic mixture when
    # args.nr_logistic_mix is truthy, softmax otherwise.
    if args.nr_logistic_mix:
        loss_op   = lambda real, fake : discretized_mix_logistic_loss_1d(real, fake)
        sample_op = lambda x : sample_from_discretized_mix_logistic_1d(x, args.nr_logistic_mix)
    else:
        loss_op   = lambda real, fake : softmax_loss_1d(real, fake)
        sample_op = lambda x : sample_from_softmax_1d(x)
Ejemplo n.º 29
0
import torchvision.datasets as tvd

if __name__ == "__main__":
    # Instantiate each dataset once per split with download=True so the
    # files are fetched into data_dir.
    data_dir = "./data"

    for train in (True, False):
        tvd.MNIST(data_dir, train=train, download=True)
        tvd.CIFAR10(data_dir, train=train, download=True)
        # Omniglot selects its split via `background` instead of `train`.
        tvd.Omniglot(data_dir, background=train, download=True)
Ejemplo n.º 30
0
 def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, background=True):
     """Create the Omniglot dataset and hand it to the parent loader.

     The dataset is built from `data_dir` with `download=True`, the given
     `background` flag and the module-level `omni_transforms`. It is stored
     on `self.dataset` and passed to the base-class initialiser together
     with the batching options.
     """
     self.data_dir = data_dir
     omniglot_options = dict(
         background=background,
         download=True,
         transform=omni_transforms,
     )
     self.dataset = datasets.Omniglot(self.data_dir, **omniglot_options)
     super().__init__(
         self.dataset,
         batch_size,
         shuffle,
         validation_split,
         num_workers,
     )