def sample_model(hypernet, arch):
    """Build one classifier from a randomly chosen hypernet weight sample.

    Draws a batch of layer weights from ``hypernet``, picks a single random
    index, assembles the three per-layer weight tensors into a classifier,
    and returns it in eval mode.

    NOTE(review): ``args`` is a free variable here — presumably a
    module-level namespace in the original file; confirm before reuse.
    Also assumes the sampled batch holds 32 networks — TODO confirm.
    """
    batch = utils.sample_hypernet(hypernet)
    idx = np.random.randint(32)
    # One weight tensor per layer, all taken at the same batch index.
    chosen = tuple(layer_weights[idx] for layer_weights in batch[:3])
    clf = utils.weights_to_clf(chosen, arch, args.stat['layer_names'])
    clf.eval()
    return clf
# Example #2
# 0
def measure_acc(args, hypernet, arch):
    """Measure majority-vote ensemble accuracy of hypernet-sampled networks.

    For each ensemble size ``n`` in {1, 5, 10, 100}: sample ``n`` weight
    sets from ``hypernet``, build a classifier from each, run every test
    batch through all ``n`` classifiers, take the elementwise mode of their
    outputs as the ensemble vote, and count votes that match the labels.

    Bug fixes vs. the original:
      * the single-model branch tested ``n == 2`` — never true for the
        sampled sizes — so ``e1`` stayed a plain float and ``e1.item()``
        raised AttributeError;
      * the data loops iterated the undefined name ``mnist_test`` instead
        of the ``test_loader`` returned by ``datagen.load_mnist``.

    Returns a dict mapping ensemble size -> accuracy (the original computed
    these values and discarded them; returning them is backward-compatible).
    """
    _, test_loader = datagen.load_mnist(args)
    # Preserved from the original: these are never accumulated, so the
    # printed accuracy/loss are 0 — kept to avoid changing the output line.
    test_loss = 0.
    test_acc = 0.
    # Per-ensemble-size correct-vote counters (replaces e1/e5/e10/e100).
    correct = {n: 0. for n in (1, 5, 10, 100)}
    for n in (1, 5, 10, 100):
        weights = utils.sample_hypernet(hypernet, n)
        for data, y in test_loader:
            outputs = []
            for k in range(n):
                sample_w = (weights[0][k], weights[1][k], weights[2][k])
                model = utils.weights_to_clf(sample_w, arch,
                                             args.stat['layer_names'])
                outputs.append(model(data).cpu().numpy())
            # Majority vote across the n sampled models.
            vote_modes = torch.tensor(stats.mode(np.array(outputs), axis=0)[0])
            correct[n] += vote_modes.eq(
                y.data.view_as(vote_modes)).long().cpu().sum()

    n_examples = len(test_loader.dataset)
    test_loss /= n_examples * args.batch_size
    test_acc /= n_examples * args.batch_size
    ensemble_acc = {n: correct[n].item() / n_examples
                    for n in (1, 5, 10, 100)}
    print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
    return ensemble_acc
# Example #3
# 0
def load_models(args, path):
    """Load saved MNIST models (or hypernets) from ``path`` and evaluate them.

    Globs ``*.pt`` checkpoints under ``path``, filters to MNIST ones, and:
      * if ``args.hyper``: loads each checkpoint as a hypernet, repeatedly
        samples batches of networks from it, and prints/collects the test
        accuracy and loss of every sampled network;
      * otherwise: loads the checkpoint's ``state_dict`` into the model,
        falling back to a filtered partial load when the keys mismatch.

    Bug fix vs. the original: ``model.load_state_dict()`` was called with
    no argument, which raises TypeError (uncaught by ``except
    RuntimeError``) — it now loads ``state``.  The fallback also loads the
    merged ``model_dict`` rather than only the filtered subset, so a strict
    load does not re-raise on missing keys.
    """
    model = get_network(args)
    ckpt_paths = glob(path + '*.pt')
    print(path)
    ckpt_paths = [p for p in ckpt_paths if 'mnist' in p]
    ckpt_paths = natsort.natsorted(ckpt_paths)
    # NOTE(review): hard-coded filter kept from the original; presumably
    # pins the run to one specific checkpoint — confirm before removing.
    ckpt_paths = [p for p in ckpt_paths if 'hypermnist_mi_0.987465625' in p]
    accs = []
    losses = []
    for ckpt_path in ckpt_paths:
        print("loading model {}".format(ckpt_path))
        if args.hyper:
            hn = utils.load_hypernet(ckpt_path)
            for _ in range(10):
                samples = utils.sample_hypernet(hn)
                print('sampled a batches of {} networks'.format(len(
                    samples[0])))
                for i, sample in enumerate(
                        zip(samples[0], samples[1], samples[2])):
                    model = utils.weights_to_clf(sample, model,
                                                 args.stat['layer_names'])
                    acc, loss = test(args, model)
                    print(i, ': Test Acc: {}, Loss: {}'.format(acc, loss))
                    accs.append(acc)
                    losses.append(loss)
            print(accs, losses)
        else:
            ckpt = torch.load(ckpt_path)
            state = ckpt['state_dict']
            try:
                model.load_state_dict(state)
            except RuntimeError:
                # Key mismatch: keep only entries whose names exist in the
                # model, merge them over the current weights, and load that.
                model_dict = model.state_dict()
                filtered = {k: v for k, v in state.items() if k in model_dict}
                model_dict.update(filtered)
                model.load_state_dict(model_dict)