def eval_cifar5_hypergan(hypergan, ens_size, s_dim, outlier=False):
    hypergan.eval_()
    # NOTE: both branches currently load the same CIFAR-10 loaders; the outlier
    # branch is presumably intended to load the held-out (out-of-distribution) classes.
    if outlier:
        trainloader, testloader = datagen.load_cifar10()
    else:
        trainloader, testloader = datagen.load_cifar10()

    model_outputs = torch.zeros(ens_size, len(testloader.dataset), 5)  # 5 in-distribution classes
    for i, (data, target) in enumerate(testloader):
        data = data.cuda()
        target = target.cuda()
        # sample one latent code per ensemble member and generate its weights
        z = torch.randn(ens_size, s_dim).to(hypergan.device)
        codes = hypergan.mixer(z)
        params = hypergan.generator(codes)
        outputs = []
        for layers in zip(*params):
            output = hypergan.eval_f(layers, data)
            outputs.append(output)
        outputs = torch.stack(outputs)
        model_outputs[:, i * len(data):(i + 1) * len(data), :] = outputs

    # Soft voting (entropy of the ensemble-averaged confidence)
    probs_soft = F.softmax(model_outputs, dim=-1)  # [ens, data, 5]
    preds_soft = probs_soft.mean(0)  # [data, 5]
    entropy = entropy_fn(preds_soft.T.cpu().numpy())

    # Hard voting (variance of the predicted classes across the ensemble)
    probs_hard = F.softmax(model_outputs, dim=-1)  # [ens, data, 5]
    preds_hard = probs_hard.argmax(-1).cpu().float()  # [ens, data]
    variance = preds_hard.var(0)  # [data]
    hypergan.train_()

    return entropy, variance
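A hypothetical usage sketch (not part of the original snippet): the returned uncertainty measures can be compared between the in-distribution and outlier splits; the `ens_size` and `s_dim` values below are illustrative.

entropy_in, var_in = eval_cifar5_hypergan(hypergan, ens_size=10, s_dim=256, outlier=False)
entropy_out, var_out = eval_cifar5_hypergan(hypergan, ens_size=10, s_dim=256, outlier=True)
# higher predictive entropy / vote variance suggests out-of-distribution inputs
print('in-distribution entropy:     {:.4f}'.format(entropy_in.mean()))
print('out-of-distribution entropy: {:.4f}'.format(entropy_out.mean()))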
Example #2
if dataset == 'MNIST':
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x + noise * torch.randn(x.size())),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = dsets.MNIST(root="data", download=True, transform=transform)
    val_dataset = dsets.MNIST(root="data", download=True, train=False, transform=transform)

elif dataset == 'CIFAR-10':
    if test == 'clean':
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        train_dataset = dsets.CIFAR10(root="data", download=True, transform=transform)
        val_dataset = dsets.CIFAR10(root="data", download=True, train=False, transform=transform)
    else:
        loader_train, loader_test, loader_val = datagen.load_cifar10()


'''
MAKING DATASET ITERABLE
'''
if task == 'clean':
    print('length of training dataset:', len(train_dataset))
    n_iterations = num_epochs * (len(train_dataset) / batch_size)
    n_iterations = int(n_iterations)
    print('Number of iterations: ', n_iterations)

    loader_train = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    loader_val = data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
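A minimal sketch of how the loaders built above might be consumed; `model`, `criterion`, and `optimizer` are assumed to be defined elsewhere and are not part of the original snippet.

for epoch in range(num_epochs):
    for images, labels in loader_train:
        images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        loss = criterion(model(images), labels)  # e.g. nn.CrossEntropyLoss
        loss.backward()
        optimizer.step()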

Example #3
    def __init__(self,
                 dataset,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):

        self.dataset = dataset
        self.num_particles = num_particles
        print("running {}-ensemble on {}".format(num_particles, dataset))

        self._use_wandb = False
        self._save_model = False

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6).cuda() for _ in range(num_particles)
        ]

        self.models = models
        self.optimizer = [
            torch.optim.Adam(m.parameters(), lr=1e-3, weight_decay=1e-4)
            for m in models
        ]
        self.schedulers = [
            torch.optim.lr_scheduler.StepLR(o, 100, .1) for o in self.optimizer
        ]

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            # `model_id` is assumed to be defined elsewhere in the surrounding module
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.optimizer = [
                torch.optim.Adam(model.parameters(),
                                 lr=resume_lr,
                                 weight_decay=1e-4) for model in models
            ]
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        loss_type = 'ce'
        if loss_type == 'ce':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        elif loss_type == 'kliep':
            self.loss_fn = MattLoss().get_loss_dict()['kliep']

        if self._use_wandb:
            wandb.init(project="open-category-experiments", name="MLE CIFAR")
            for model in models:
                wandb.watch(model)
            config = wandb.config
            config.algo = 'ensemble'
            config.dataset = dataset
            config.kernel_fn = 'none'
            config.num_particles = num_particles
            config.loss_fn = loss_type
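
A minimal sketch, not part of the original class, of how the per-particle models and optimizers set up above could be stepped once on a batch (`data`, `target`) drawn from `self.train_loader`.

    def train_step(self, data, target):
        # one hypothetical gradient step per ensemble member (particle)
        losses = []
        for model, optimizer in zip(self.models, self.optimizer):
            optimizer.zero_grad()
            loss = self.loss_fn(model(data), target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return sum(losses) / len(losses)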
Example #4
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):

        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
            return_activations = False
        elif kernel_fn == 'cka':
            self.kernel = kernel_cka
            return_activations = True
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6,
                          return_activations=return_activations).cuda()
            for _ in range(num_particles)
        ]

        self.models = models
        param_set, state_dict = extract_parameters(self.models)

        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        self.activation_length = self.models[0].activation_length
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.kernel_width_averager = Averager(shape=())
Example #5
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=50,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):

        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        self._use_wandb = False
        self._save_model = False

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6).cuda() for _ in range(num_particles)
        ]

        self.models = models
        param_set, state_dict = extract_parameters(self.models)

        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        loss_type = 'ce'
        if loss_type == 'ce':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        elif loss_type == 'kliep':
            self.loss_fn = MattLoss().get_loss_dict()['kliep']
        self.kernel_width_averager = Averager(shape=())

        if self._use_wandb:
            wandb.init(project="open-category-experiments",
                       name="SVGD {}".format(self.dataset))
            for model in models:
                wandb.watch(model)
            config = wandb.config
            config.algo = algo
            config.dataset = dataset
            config.kernel_fn = kernel_fn
            config.num_particles = num_particles
            config.loss_fn = loss_type
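
The constructor above only sets up the particles. As a rough, self-contained sketch (not the repository's actual update rule), the SVGD direction for flattened particle parameters with an RBF kernel and median-heuristic bandwidth could look like the following, where `theta` would come from `self.param_set.view(num_particles, -1)` and `score` holds each particle's gradient of the log posterior.

import math
import torch

def svgd_direction(theta, score):
    # theta: [n, d] flattened particle parameters; score: [n, d] grad log p(theta_i)
    n = theta.shape[0]
    sq_dists = torch.cdist(theta, theta) ** 2
    h = sq_dists.median() / math.log(n + 1)    # median-heuristic kernel bandwidth
    k = torch.exp(-sq_dists / (h + 1e-8))      # [n, n] RBF kernel matrix
    # analytic gradient of the RBF kernel, summed over source particles
    grad_k = (k.sum(1, keepdim=True) * theta - k @ theta) * (2.0 / (h + 1e-8))
    return (k @ score + grad_k) / n            # SVGD ascent direction per particle

The resulting direction would then be written into `self.param_set.grad` (negated, since Adam minimizes) before calling `self.optimizer.step()`, matching the Adam-over-param_set setup above.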