Exemplo n.º 1
0
def save_clf(args, Z, acc):
    """Save a classifier whose layer weights are replaced by the tensors in Z.

    A fresh classifier is built for ``args.dataset`` and moved to the GPU.
    For every name in ``args.stat['layer_names']`` the ``<name>.weight``
    entry of its state dict is overwritten with the positionally matching
    tensor from ``Z`` (pairs are formed with ``zip``, so extra items on
    either side are ignored). The patched model's state dict is then
    written with ``torch.save``.

    Args:
        args: Run configuration. Reads ``dataset`` ('mnist' or 'cifar'),
            ``stat['layer_names']``, ``exp`` and ``scratch``.
        Z: Iterable of weight tensors, one per layer name.
        acc: Value embedded in the output file name.

    Raises:
        ValueError: If ``args.dataset`` is neither 'mnist' nor 'cifar'.
        AssertionError: If a tensor from ``Z`` is equal to the freshly
            initialised weight it is meant to replace.
    """
    # The classifier modules are imported lazily so that only the one
    # matching the requested dataset is loaded.
    if args.dataset == 'mnist':
        import models.mnist_clf as models
        model = models.Small2().cuda()
    elif args.dataset == 'cifar':
        import models.cifar_clf as models
        model = models.MedNet().cuda()
    else:
        # Without this branch `model` would be unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError('unsupported dataset: {!r}'.format(args.dataset))

    state = model.state_dict()
    for name, params in zip(args.stat['layer_names'], Z):
        key = name + '.weight'
        initial = state[key]
        state[key] = params.detach()
        # Sanity check: the replacement must differ from the initial weight,
        # otherwise the substitution had no effect.
        assert not state[key].equal(initial), \
            'weights for {} are identical to the initial ones'.format(key)
    # Loading once after all entries are patched is equivalent to reloading
    # the same dict after every single replacement.
    model.load_state_dict(state)

    path = 'exp_models/hyper{}_clf_{}_{}.pt'.format(args.dataset, args.exp, acc)
    if args.scratch:
        path = '/scratch/eecs-share/ratzlafn/HyperGAN/' + path
    print ('saving hypernet to {}'.format(path))
    torch.save({'state_dict': model.state_dict()}, path)
Exemplo n.º 2
0
def get_network(args):
    """Build the network selected by ``args.net`` and move it to the GPU.

    Args:
        args: Configuration object; ``args.net`` must be 'small' or 'small2'.

    Returns:
        A ``models.Small`` or ``models.Small2`` instance after ``.cuda()``.

    Raises:
        NotImplementedError: For any other value of ``args.net``.
    """
    requested = args.net
    if requested == 'small':
        return models.Small().cuda()
    if requested == 'small2':
        return models.Small2().cuda()
    # Anything else is an architecture this factory does not know about.
    raise NotImplementedError
Exemplo n.º 3
0
        pop_mean = pop_outputs.mean(0).view(10000, 10)
        ent = entropy(pop_mean.cpu().numpy().T)
    return ent



def get_stats(l1, l2, l3, inspect):
    """Return the mean and standard deviation of per-tensor norms.

    Args:
        l1: Unused; kept so existing call sites keep working.
        l2: Unused; kept so existing call sites keep working.
        l3: Unused; kept so existing call sites keep working.
        inspect: Sequence of torch tensors (or a tensor indexed along its
            first dimension). One norm is computed per element with
            ``np.linalg.norm`` and its default arguments.

    Returns:
        Tuple ``(mean, std)`` of the computed norms.
    """
    # Detach from autograd and copy to host memory before handing the data
    # to NumPy: it cannot convert tensors that live on a CUDA device.
    norms = np.array(
        [np.linalg.norm(t.detach().cpu().numpy()) for t in inspect]
    )
    return norms.mean(), norms.std()


# Evaluation script: for every checkpoint file found below, load the
# hypernet, sample layer weights from it, test them with the Small2
# classifier, and record accuracy, loss and a norm-dispersion statistic.
args = arg.load_mnist_args()
# Classifier instance passed to test_mnist; created once and moved to the GPU.
model = mnist_clf.Small2().cuda()
modeldef = netdef.nets()[args.target]
names = modeldef['layer_names']
paths = natsorted(glob('saved_models/mnist/noadds/*.pt'))
# Re-sort by the float parsed from characters 44 up to the 3-character
# suffix of each path. sorted() is stable, so ties keep the natsorted order.
# NOTE(review): offset 44 is hard-coded to one specific file-name layout;
# which number is encoded there cannot be told from here - confirm.
paths = sorted(paths, key=lambda i: float(i[44:-3]))
# Per-checkpoint results; cv collects std/mean of the inspected norms.
accs, losses, cv = [], [], []
for path in paths:
    # netD is returned by the loader but not used in this loop.
    netE, netD, W1, W2, W3 = utils.load_hypernet_mnist(args, path)
    # All networks are switched to eval mode before sampling; `codes` is unused.
    # NOTE(review): the trailing 1 is presumably a sample count - confirm.
    l1, l2, l3, codes = utils.sample_hypernet_mnist(args, [netE.eval(), W1.eval(), W2.eval(), W3.eval()], 1)
    acc, loss = test_mnist(args, [l1, l2, l3], names, model, 1)
    # Only l2 is inspected; get_stats ignores its first three arguments.
    m, s = get_stats(l1, l2, l3, inspect=l2)
    print ('Acc: {}, Loss: {} Nmean: {} Nstd: {}\n'.format(acc, loss, m, s))
    accs.append(acc)
    losses.append(loss)
    # Coefficient of variation (std divided by mean) of the norms.
    cv.append(s/m)
accs = np.array(accs)