Example #1
def sample_model(hypernet, arch):
    # Draw a batch of weight tensors from the hypernetwork and pick one
    # network at random (the hypernet emits batches of 32 samples).
    w_batch = utils.sample_hypernet(hypernet)
    rand = np.random.randint(32)
    sample_w = (w_batch[0][rand], w_batch[1][rand], w_batch[2][rand])
    model = utils.weights_to_clf(sample_w, arch, args.stat['layer_names'])
    model.eval()
    return model
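A minimal usage sketch for sample_model, assuming a trained hypernetwork checkpoint and the get_network helper from Example #3; the checkpoint path is illustrative, and sample_model also reads the module-level args:

# Hypothetical usage -- the path and surrounding setup are assumptions.
hypernet = utils.load_hypernet('saved_models/hypermnist.pt')
arch = get_network(args)
model = sample_model(hypernet, arch)
_, test_loader = datagen.load_mnist(args)
x, y = next(iter(test_loader))
print(model(x).argmax(dim=1)[:10], y[:10])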
Example #2
def measure_acc(args, hypernet, arch):
    _, test_loader = datagen.load_mnist(args)
    e1, e5, e10, e100 = 0., 0., 0., 0.
    for n in [1, 5, 10, 100]:
        # Sample an ensemble of n networks from the hypernet and take a
        # majority vote over their predictions on the test set.
        weights = utils.sample_hypernet(hypernet, n)
        for i, (data, y) in enumerate(test_loader):
            n_votes = []
            for k in range(n):
                sample_w = (weights[0][k], weights[1][k], weights[2][k])
                model = utils.weights_to_clf(sample_w, arch,
                                             args.stat['layer_names'])
                # Each sampled network votes with its argmax class.
                preds = model(data).data.max(1)[1]
                n_votes.append(preds.cpu().numpy())
            votes = np.array(n_votes)
            vote_modes = torch.tensor(stats.mode(votes, axis=0)[0].ravel())
            correct = vote_modes.eq(y.data.view_as(vote_modes)).long().sum()
            if n == 1:
                e1 += correct
            elif n == 5:
                e5 += correct
            elif n == 10:
                e10 += correct
            elif n == 100:
                e100 += correct

    n_test = len(test_loader.dataset)
    e1 = e1.item() / n_test
    e5 = e5.item() / n_test
    e10 = e10.item() / n_test
    e100 = e100.item() / n_test
    print('Ensemble Test Accuracy -- 1: {}, 5: {}, 10: {}, 100: {}'.format(
        e1, e5, e10, e100))
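To make the voting step concrete, here is a self-contained sketch of the majority-vote mechanism, using only numpy and scipy (newer SciPy versions return the mode without the leading singleton axis, which is why ravel() is used above):

import numpy as np
from scipy import stats

# Argmax predictions from 3 sampled networks on a batch of 5 examples.
votes = np.array([[7, 2, 1, 0, 4],
                  [7, 2, 1, 0, 9],
                  [7, 3, 1, 0, 4]])
# The mode along axis 0 is the most common class per example.
majority = np.asarray(stats.mode(votes, axis=0)[0]).ravel()
print(majority)  # [7 2 1 0 4]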
Example #3
def load_models(args, path):

    model = get_network(args)
    paths = glob(path + '*.pt')
    print(path)
    paths = [p for p in paths if 'mnist' in p]
    natpaths = natsort.natsorted(paths)
    accs = []
    losses = []
    natpaths = [x for x in natpaths if 'hypermnist_mi_0.987465625' in x]
    for i, path in enumerate(natpaths):
        print("loading model {}".format(path))
        if args.hyper:
            hn = utils.load_hypernet(path)
            for _ in range(10):
                samples = utils.sample_hypernet(hn)
                print('sampled a batch of {} networks'.format(len(samples[0])))
                for j, sample in enumerate(
                        zip(samples[0], samples[1], samples[2])):
                    model = utils.weights_to_clf(sample, model,
                                                 args.stat['layer_names'])
                    acc, loss = test(args, model)
                    print(j, ': Test Acc: {}, Loss: {}'.format(acc, loss))
                    accs.append(acc)
                    losses.append(loss)
            print(accs, losses)
        else:
            ckpt = torch.load(path)
            state = ckpt['state_dict']
            try:
                model.load_state_dict(state)
            except RuntimeError:
                # Partial load: keep only checkpoint keys that exist in
                # the current model, then load the merged state dict.
                model_dict = model.state_dict()
                filtered = {k: v for k, v in state.items() if k in model_dict}
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
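The except branch above is the standard PyTorch pattern for partially loading a checkpoint into a model whose architecture differs slightly. A generic, self-contained version of that pattern (the shape check is an addition not in the original):

import torch

def load_partial(model, ckpt_path):
    # Keep only checkpoint entries whose key and shape match the model,
    # merge them into the model's own state_dict, and load that.
    state = torch.load(ckpt_path, map_location='cpu')['state_dict']
    model_dict = model.state_dict()
    compatible = {k: v for k, v in state.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return model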
Example #4
def train(args):
    from torch import optim
    #torch.manual_seed(8734)
    # The encoder netE maps noise to per-layer codes; each generator
    # below emits the weights of one layer of the target networks.
    netE = models.Encoderz(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    E1 = models.GeneratorE1(args).cuda()
    E2 = models.GeneratorE2(args).cuda()
    D1 = models.GeneratorD2(args).cuda()
    D2 = models.GeneratorD3(args).cuda()
    D3 = models.GeneratorD4(args).cuda()
    print(netE, netD)
    print(E1, E2, D1, D2, D3)

    optimE = optim.Adam(netE.parameters(),
                        lr=5e-4,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(),
                        lr=1e-4,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)

    Eoptim = [
        optim.Adam(E1.parameters(),
                   lr=1e-4,
                   betas=(0.5, 0.9),
                   weight_decay=1e-4),
        optim.Adam(E2.parameters(),
                   lr=1e-4,
                   betas=(0.5, 0.9),
                   weight_decay=1e-4),
    ]
    Doptim = [
        optim.Adam(D1.parameters(),
                   lr=1e-4,
                   betas=(0.5, 0.9),
                   weight_decay=1e-4),
        optim.Adam(D2.parameters(),
                   lr=1e-4,
                   betas=(0.5, 0.9),
                   weight_decay=1e-4),
        optim.Adam(D3.parameters(),
                   lr=1e-4,
                   betas=(0.5, 0.9),
                   weight_decay=1e-4)
    ]

    Enets = [E1, E2]
    Dnets = [D1, D2, D3]

    best_test_loss = np.inf
    args.best_loss = best_test_loss

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    # Scalar gradient seeds for backward(); mone flips the sign so the
    # log terms below are maximized rather than minimized.
    one = torch.tensor(1, dtype=torch.float).cuda()
    mone = one * -1
    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        # Match the first two moments of the encoder codes to the prior
        # z_dist before adversarial training starts.
        for j in range(100):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            netE.zero_grad()
            # Use a name other than `optim` to avoid shadowing the module.
            for opt in Eoptim:
                opt.zero_grad()
            for opt in Doptim:
                opt.zero_grad()
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            # Discriminator step: samples from the prior are "real",
            # encoder codes are "fake".
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = torch.log((1 - d_real).mean())
                d_fake_loss = torch.log(d_fake.mean())
                d_real_loss.backward(mone, retain_graph=True)
                d_fake_loss.backward(mone, retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            netD.zero_grad()
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            Eweights, Dweights = [], []
            # The first len(Enets) codes drive the encoder-weight
            # generators; the rest drive the decoder-weight generators.
            for i, net in enumerate(Enets):
                Eweights.append(net(codes[i]))
            for j, net in enumerate(Dnets):
                Dweights.append(net(codes[len(Enets) + j]))
            d_real = []
            for code in codes:
                d_real.append(netD(code))

            netD.zero_grad()
            # Generator-side adversarial term on the codes.
            d_loss = torch.stack(d_real).log().mean() * 10.

            for layers in zip(*(Eweights + Dweights)):
                loss, _ = train_clf(args, layers, data, target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
                d_loss.backward(mone, retain_graph=True)
            optimE.step()
            for opt in Eoptim:
                opt.step()
            for opt in Doptim:
                opt.step()
            loss = loss.item()

            if batch_idx % 50 == 0:
                print('**************************************')
                print('AE MNIST Test, beta: {}'.format(args.beta))
                print('MSE Loss: {}'.format(loss))
                print('D loss: {}'.format(d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    Eweights, Dweights = [], []
                    for j, net in enumerate(Enets):
                        Eweights.append(net(codes[j]))
                    for k, net in enumerate(Dnets):
                        Dweights.append(net(codes[len(Enets) + k]))
                    for layers in zip(*(Eweights + Dweights)):
                        loss, out = train_clf(args, layers, data, y)
                        test_loss += loss.item()
                    # Stop after ~10 test batches.
                    if i == 10:
                        break
                test_loss /= 10 * len(y) * args.batch_size
                print('Test Loss: {}'.format(test_loss))
                if test_loss < best_test_loss:
                    print('==> new best stats, saving')
                    #utils.save_clf(args, z_test, test_acc)
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                archE = sampleE(args).cuda()
                archD = sampleD(args).cuda()
                rand = np.random.randint(args.batch_size)
                eweight = list(zip(*Eweights))[rand]
                dweight = list(zip(*Dweights))[rand]
                modelE = utils.weights_to_clf(eweight, archE,
                                              args.statE['layer_names'])
                modelD = utils.weights_to_clf(dweight, archD,
                                              args.statD['layer_names'])
                utils.generate_image(args, batch_idx, modelE, modelD,
                                     data.cuda())
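For reference, the encoder pretraining step above relies on ops.pretrain_loss returning separate mean and covariance terms. A hypothetical stand-in, assuming it matches the first two moments of the codes to the prior samples z:

import torch

def pretrain_loss(code, z):
    # Hypothetical moment-matching loss: drive the empirical mean and
    # covariance of the encoder codes toward those of the prior samples.
    mean_loss = (code.mean(0) - z.mean(0)).pow(2).mean()
    code_c = code - code.mean(0, keepdim=True)
    z_c = z - z.mean(0, keepdim=True)
    cov_code = code_c.t() @ code_c / (code.size(0) - 1)
    cov_z = z_c.t() @ z_c / (z.size(0) - 1)
    cov_loss = (cov_code - cov_z).pow(2).mean()
    return mean_loss, cov_loss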