Example #1
def test(args, model, epoch=None, grad=False):
    model.eval()
    if args.dataset == 'cifar':
        _, test_loader = datagen.load_cifar(args)
    elif args.dataset == 'cifar100':
        _, test_loader = datagen.load_10_class_cifar100(args)

    test_loss = 0.
    correct = 0.
    criterion = nn.CrossEntropyLoss()
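    # accumulate cross-entropy loss and correct predictions over the test set; keep the graph only when grad=True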
    for i, (data, target) in enumerate(test_loader):
        data, target = data.cuda(), target.cuda()
        output = model(data)
        if grad is False:
            test_loss += criterion(output, target).item()
        else:
            test_loss += criterion(output, target)
        pred = output.data.max(1, keepdim=True)[1]
        output = None
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
    test_loss /= len(test_loader.dataset)
    acc = correct.item() / len(test_loader.dataset)
    if epoch:
        print('Epoch: {}, Average loss: {}, Accuracy: {}/{} ({}%)'.format(
            epoch, test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

    return acc, test_loss
Example #2
def test(args, Z, names, ext):
    model = weights_to_clf(Z, names)
    model.eval()
    if ext is False:
        _, test_loader = datagen.load_cifar(args)
    else:
        test_loader = load_cifar_ext(args)

    test_loss = 0.
    correct = 0.
    criterion = nn.CrossEntropyLoss()
    for i, (data, target) in enumerate(test_loader):
        data, target = data.cuda(), target.cuda()
        data = data.view(100, 3, 32, 32)
        target = target.view(-1)
        output = model(data)
        test_loss += criterion(output, target).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
    if ext is False:
        test_loss /= len(test_loader.dataset)
        acc = correct.item() / len(test_loader.dataset)
    else:
        test_loss /= 2000
        acc = correct.item() / 2000
    return acc, test_loss
Example #3
def train(args, model, grad=False):
    if args.dataset == 'cifar':
        train_loader, _ = datagen.load_cifar(args)
    elif args.dataset == 'cifar100':
        train_loader, _ = datagen.load_10_class_cifar100(args)

    train_loss, train_acc = 0., 0.
    criterion = nn.CrossEntropyLoss()
    if args.ft:
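        # fine-tuning: freeze every child module except the final layer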
        for child in list(model.children())[:-1]:
            # print ('removing {}'.format(child))
            for param in child.parameters():
                param.requires_grad = False

    optimizer = optim.Adam(model.parameters(), lr=2e-3, weight_decay=1e-4)
    #optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    for epoch in range(args.epochs):
        model.train()
        total = 0
        correct = 0
        for i, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            acc = 100. * correct / total
        acc, loss = test(args, model, epoch + 1)
    return acc, loss
Example #4
def test_cifar(args, Z, names, arch):
    _, test_loader = datagen.load_cifar(args)
    criterion = nn.CrossEntropyLoss()
    pop_size = args.batch_size
    with torch.no_grad():
        correct = 0.
        test_loss = 0
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            outputs = []
            for i in range(pop_size):
                params = [Z[0][i], Z[1][i], Z[2][i], Z[3][i], Z[4][i]]
                model = weights_to_clf(params, names, arch)
                output = model(data)
                outputs.append(output)
            pop_outputs = torch.stack(outputs)
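            # majority vote: take the per-sample mode of predicted labels across the population of generated classifiers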
            pop_labels = pop_outputs.max(2, keepdim=True)[1].view(
                pop_size, 100, 1)
            modes, idxs = torch.mode(pop_labels, dim=0, keepdim=True)
            modes = modes.view(100, 1)
            correct += modes.eq(target.data.view_as(modes)).long().cpu().sum()
            test_loss += criterion(output, target).item()  # sum up batch loss
        test_loss /= len(test_loader.dataset)
        acc = (correct.float() / len(test_loader.dataset)).item()
    return acc, test_loss
Example #5
    def __init__(self, args):
        self.s = args.s
        self.z = args.z
        self.batch_size = args.batch_size
        self.epochs = 200
        self.alpha = 1
        self.beta = args.beta
        self.target = args.target
        self.use_bn = args.use_bn
        self.bias = args.bias
        self.n_hidden = args.n_hidden
        self.pretrain_e = args.pretrain_e
        self.dataset = args.dataset
        self.test_ensemble = args.test_ensemble
        self.test_uncertainty = args.test_uncertainty
        self.vote = args.vote

        self.device = torch.device('cuda')
        torch.manual_seed(8734)        

        self.hypergan = HyperGAN(args, self.device)
        self.hypergan.print_hypergan()
        self.hypergan.attach_optimizers(5e-3, 1e-4, 5e-5)

        if self.dataset == 'mnist':
            self.data_train, self.data_test = datagen.load_mnist()
        elif self.dataset == 'cifar':
            self.data_train, self.data_test = datagen.load_cifar()

        self.best_test_acc = 0.
        self.best_test_loss = np.inf
Example #6
def load_data(args):
    if args.dataset == 'mnist':
        return datagen.load_mnist(args)
    elif args.dataset == 'cifar':
        return datagen.load_cifar(args)
    elif args.dataset == 'fmnist':
        return datagen.load_fashion_mnist(args)
    elif args.dataset == 'cifar_hidden':
        class_list = [0]  # just load class 0
        return datagen.load_cifar_hidden(args, class_list)
    else:
        print('Dataset not specified correctly')
        print('choose --dataset <mnist, fmnist, cifar, cifar_hidden>')
Example #7
    def __init__(self, args):
        self.lr = args.lr
        self.wd = args.wd
        self.epochs = 200
        self.dataset = args.dataset
        self.test_uncertainty = args.test_uncertainty
        self.vote = args.vote
        self.device = torch.device('cuda')
        torch.manual_seed(8734)        
        
        self.model = models.LeNet_Dropout().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr, weight_decay=self.wd)

        if self.dataset == 'mnist':
            self.data_train, self.data_test = datagen.load_mnist()
        elif self.dataset == 'cifar':
            self.data_train, self.data_test = datagen.load_cifar()

        self.best_test_acc = 0.
        self.best_test_loss = np.inf
        print (self.model)
Example #8
    def __init__(self, args):
        self.lr = args.lr
        self.wd = args.wd
        self.epochs = 200
        self.dataset = args.dataset
        self.test_uncertainty = args.test_uncertainty
        self.vote = args.vote
        self.n_models = args.n_models
        self.device = torch.device('cuda')
        torch.manual_seed(8734)

        self.ensemble = [
            models.LeNet().to(self.device) for _ in range(self.n_models)
        ]
        self.attach_optimizers()

        if self.dataset == 'mnist':
            self.data_train, self.data_test = datagen.load_mnist()
        elif self.dataset == 'cifar':
            self.data_train, self.data_test = datagen.load_cifar()

        self.best_test_acc = 0.
        self.best_test_loss = np.inf
        print(self.ensemble[0], ' X {}'.format(self.n_models))
Example #9
def train(args):

    torch.manual_seed(8734)

    netE = Encoder(args).cuda()
    W1 = GeneratorW1(args).cuda()
    W2 = GeneratorW2(args).cuda()
    W3 = GeneratorW3(args).cuda()
    W4 = GeneratorW4(args).cuda()
    W5 = GeneratorW5(args).cuda()
    netD = DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(),
                        lr=5e-3,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(),
                        lr=5e-5,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    if args.hidden:
        c_idx = [0, 1, 2, 3, 4]
        cifar_train, cifar_test = datagen.load_cifar_hidden(args, c_idx)
    else:
        cifar_train, cifar_test = datagen.load_cifar(args)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
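        # pretrain the encoder so its codes match the mean and covariance of the target latent samples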
        for j in range(2000):
            x = sample_z_like((e_batch_size, args.ze))
            z = sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = sample_z_like((
                args.batch_size,
                args.ze,
            ))
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            # Z Adversary
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
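            # discriminator update: learn to distinguish encoder codes from prior noise samples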
            for code in codes:
                noise = sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()

            # Generator (Mean test)
            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data,
                                          target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)

            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()
            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = (correct / 1)
                # compute norms with torch; numpy cannot operate on CUDA tensors directly
                norm_z1 = l1.data.norm().item()
                norm_z2 = l2.data.norm().item()
                norm_z3 = l3.data.norm().item()
                norm_z4 = l4.data.norm().item()
                norm_z5 = l5.data.norm().item()
                print('**************************************')
                print('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(
                    acc, loss, d_loss))
                print('Filter norm: ', norm_z1)
                print('Filter norm: ', norm_z2)
                print('Filter norm: ', norm_z3)
                print('Linear norm: ', norm_z4)
                print('Linear norm: ', norm_z5)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')
            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = sample_z_like((
                        args.batch_size,
                        args.ze,
                    ))
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code)
                    l2 = W2(w2_code)
                    l3 = W3(w3_code)
                    l4 = W4(w4_code)
                    l5 = W5(w5_code)
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5],
                                                  data, y)
                        test_acc += correct.item()
                        test_loss += loss.item()
                test_loss /= len(cifar_test.dataset) * args.batch_size
                test_acc /= len(cifar_test.dataset) * args.batch_size
                print('Test Accuracy: {}, Test Loss: {}'.format(
                    test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    utils.save_hypernet_cifar(args,
                                              [netE, W1, W2, W3, W4, W5, netD],
                                              test_acc)
                    print('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
Example #10
def train(args):

    torch.manual_seed(8734)

    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(),
                        lr=5e-3,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(),
                         lr=5e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(),
                         lr=5e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(),
                         lr=5e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(),
                         lr=5e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(),
                         lr=5e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(),
                        lr=5e-5,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)

    m_best_test_acc, m_best_test_loss = 0., np.inf
    c_best_test_acc, c_best_test_loss = 0., np.inf
    args.m_best_loss, args.m_best_acc = m_best_test_loss, m_best_test_acc
    args.c_best_loss, args.c_best_acc = c_best_test_loss, c_best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        mask1 = torch.zeros(e_batch_size, args.ze).cuda()
        mask2 = torch.ones(e_batch_size, args.ze).cuda()
        for j in range(500):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            if j % 2 == 0: x = torch.cat((x, mask1), dim=0)
            if j % 2 == 1: x = torch.cat((x, mask2), dim=0)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (mnist,
                        cifar) in enumerate(zip(mnist_train, cifar_train)):
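            # alternate batches between MNIST and CIFAR, appending a constant dataset mask to the latent input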
            if batch_idx % 2 == 0:
                data, target = mnist
                mask = torch.zeros(args.batch_size, args.ze).cuda()
            else:
                data, target = cifar
                mask = torch.ones(args.batch_size, args.ze).cuda()

            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = utils.sample_d(x_dist, args.batch_size)
            z = torch.cat((z, mask), dim=0)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            # Z Adversary
            for code in codes:
                noise = utils.sample_d(z_dist, args.batch_size)
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            # Generator (Mean test)
            netD.zero_grad()
            z = utils.sample_d(x_dist, args.batch_size)
            z = torch.cat((z, mask), dim=0)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            d_real = []
            for code in codes:
                d = netD(code)
                d_real.append(d)
            netD.zero_grad()
            d_loss = torch.stack(d_real).log().mean() * 10.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data,
                                          target)
                scaled_loss = args.beta * loss
                if loss != loss:  # NaN guard: abort if the classifier loss diverges
                    sys.exit(0)
                scaled_loss.backward(retain_graph=True)
                d_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(),
                                retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()

            loss = loss.item()
            """ Update Statistics """
            if batch_idx % 50 == 0 or batch_idx % 50 == 1:
                acc = correct
                print('**************************************')
                if batch_idx % 50 == 0:
                    print('MNIST Test: Enc, Dz, Lscale: {} test'.format(
                        args.beta))
                if batch_idx % 50 == 1:
                    print('CIFAR Test: Enc, Dz, Lscale: {} test'.format(
                        args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(
                    acc, loss, d_loss))
                print('best test loss: {}, {}'.format(args.m_best_loss,
                                                      args.c_best_loss))
                print('best test acc: {}, {}'.format(args.m_best_acc,
                                                     args.c_best_acc))
                print('**************************************')
            if batch_idx > 1 and batch_idx % 100 == 0:
                m_test_acc = 0.
                m_test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    z = torch.cat(
                        (z, torch.zeros(args.batch_size, args.ze).cuda()),
                        dim=0)
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code)
                    l2 = W2(w2_code)
                    l3 = W3(w3_code)
                    l4 = W4(w4_code)
                    l5 = W5(w5_code)
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5],
                                                  data, y)
                        m_test_acc += correct.item()
                        m_test_loss += loss.item()
                m_test_loss /= len(mnist_test.dataset) * args.batch_size
                m_test_acc /= len(mnist_test.dataset) * args.batch_size
                print('MNIST Test Accuracy: {}, Test Loss: {}'.format(
                    m_test_acc, m_test_loss))

                c_test_acc = 0.
                c_test_loss = 0
                for i, (data, y) in enumerate(cifar_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    z = torch.cat(
                        (z, torch.ones(args.batch_size, args.ze).cuda()),
                        dim=0)
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code)
                    l2 = W2(w2_code)
                    l3 = W3(w3_code)
                    l4 = W4(w4_code)
                    l5 = W5(w5_code)
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5],
                                                  data, y)
                        c_test_acc += correct.item()
                        c_test_loss += loss.item()
                c_test_loss /= len(cifar_test.dataset) * args.batch_size
                c_test_acc /= len(cifar_test.dataset) * args.batch_size
                print('CIFAR Test Accuracy: {}, Test Loss: {}'.format(
                    c_test_acc, c_test_loss))

                if m_test_loss < m_best_test_loss or m_test_acc > m_best_test_acc:
                    #utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                    print('==> new best stats, saving')
                    if m_test_loss < m_best_test_loss:
                        m_best_test_loss = m_test_loss
                        args.m_best_loss = m_test_loss
                    if m_test_acc > m_best_test_acc:
                        m_best_test_acc = m_test_acc
                        args.m_best_acc = m_test_acc

                if c_test_loss < c_best_test_loss or c_test_acc > c_best_test_acc:
                    #utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                    print('==> new best stats, saving')
                    if c_test_loss < c_best_test_loss:
                        c_best_test_loss = c_test_loss
                        args.c_best_loss = c_test_loss
                    if c_test_acc > c_best_test_acc:
                        c_best_test_acc = c_test_acc
                        args.c_best_acc = c_test_acc
Example #11
def train(args):
    
    torch.manual_seed(1)
    
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print (netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    
    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.
    
    cifar_train, cifar_test = datagen.load_cifar(args)#_hidden(args, [0, 1, 2, 3, 4])
    one = torch.tensor(1).cuda()
    mone = (one * -1).cuda()
    print ("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_z_like((e_batch_size, args.ze))
            z = utils.sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print ('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print ('Finished Pretraining Encoder')
                break
    bb = 0
    print ('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            utils.batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z1 = utils.sample_z_like((args.batch_size, args.z,))
            z2 = utils.sample_z_like((args.batch_size, args.z,))
            z3 = utils.sample_z_like((args.batch_size, args.z,))
            z4 = utils.sample_z_like((args.batch_size, args.z,))
            z5 = utils.sample_z_like((args.batch_size, args.z,))
            #codes = netE(z)
            codes = [z1, z2, z3, z4, z5]
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            
            """
            # Z Adversary 
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1-d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            """
            clf_loss = 0
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                clf_loss += loss
            clf_loss *= args.beta
            clf_loss.backward()
               
            optimE.step(); optimW1.step(); optimW2.step()
            optimW3.step(); optimW4.step(); optimW5.step()
            utils.batch_zero_grad([optimE, optimW1, optimW2, optimW3, optimW4, optimW5])
            loss = loss.item()
            """ Update Statistics """
            if batch_idx % 50 == 0:
                bb += 1
                acc = correct 
                print ('**************************************')
                print ("epoch: {}".format(bb))
                print ('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print ('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, 0))# d_loss))
                print ('best test loss: {}'.format(args.best_loss))
                print ('best test acc: {}'.format(args.best_acc))
                print ('best clf acc: {}'.format(best_clf_acc))
                print ('**************************************')
            if batch_idx % 100 == 0:
                with torch.no_grad():
                    test_acc = 0.
                    test_loss = 0.
                    for i, (data, y) in enumerate(cifar_test):
                        w1_code = utils.sample_z_like((args.batch_size, args.z,))
                        w2_code = utils.sample_z_like((args.batch_size, args.z,))
                        w3_code = utils.sample_z_like((args.batch_size, args.z,))
                        w4_code = utils.sample_z_like((args.batch_size, args.z,))
                        w5_code = utils.sample_z_like((args.batch_size, args.z,))
                        #w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)

                        l1 = W1(w1_code)
                        l2 = W2(w2_code)
                        l3 = W3(w3_code)
                        l4 = W4(w4_code)
                        l5 = W5(w5_code)
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                        #clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                test_loss /= len(cifar_test.dataset) * args.batch_size
                test_acc /= len(cifar_test.dataset) * args.batch_size

                stats.update_logger(l1, l2, l3, l4, l5, logger)
                stats.update_acc(logger, test_acc)
                #stats.update_grad(logger, grads, norms)
                stats.save_logger(logger, args.exp)
                stats.plot_logger(logger)
                
                print ('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                #print ('Clf Accuracy: {}, Clf Loss: {}'.format(clf_acc, clf_loss))
                if test_loss < best_test_loss:
                    best_test_loss, args.best_loss = test_loss, test_loss
                if test_acc > best_test_acc:
                    best_test_acc, args.best_acc = test_acc, test_acc
Example #12
    def train(self):
        cifar_train, cifar_test = datagen.load_cifar()
        best_test_acc, best_test_loss = 0., np.inf
        self.best_loss, self.best_acc = best_test_loss, best_test_acc

        one = torch.FloatTensor([1]).cuda()
        mone = (one * -1).cuda()
        if self.pretrain_e:
            print("==> pretraining encoder")
            self.pretrain_encoder()

        print('==> Begin Training')
        for epoch in range(1000):
            for batch_idx, (data, target) in enumerate(cifar_train):
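                # sample a latent batch, mix it into per-layer codes, then generate a population of network weights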
                s = torch.randn(self.batch_size, self.s).cuda()
                codes = self.hypergan.mixer(s)
                params = self.hypergan.generator(codes)

                # Z Adversary

                for code in codes:
                    noise = torch.randn(self.batch_size, self.z).cuda()
                    d_real = self.hypergan.discriminator(noise)
                    d_fake = self.hypergan.discriminator(code)
                    d_real_loss = -1 * torch.log((1 - d_real).mean())
                    d_fake_loss = -1 * torch.log(d_fake.mean())
                    d_real_loss.backward(retain_graph=True)
                    d_fake_loss.backward(retain_graph=True)
                    d_loss = d_real_loss + d_fake_loss
                self.hypergan.optim_disc.step()

                d_loss = 0
                losses, corrects = [], []
                for (layers) in zip(*params):
                    correct, loss = self.train_clf(layers,
                                                   data,
                                                   target,
                                                   val=True)
                    losses.append(loss)
                    corrects.append(correct)
                loss = torch.stack(losses).mean()
                correct = torch.stack(corrects).mean()
                scaled_loss = self.beta * loss
                scaled_loss.backward()

                self.hypergan.optim_mixer.step()
                self.hypergan.update_generator()
                self.hypergan.zero_grad()

                loss = loss.item()
                """ Update Statistics """
                if batch_idx % 100 == 0:
                    acc = (100 * (correct / self.batch_size))
                    print('**************************************')
                    print('CIFAR Test, epoch: {}'.format(epoch))
                    print('Acc: {}, G Loss: {}, D Loss: {}'.format(
                        acc, loss, d_loss))
                    print('best test loss: {}'.format(self.best_loss))
                    print('best test acc: {}'.format(self.best_acc))
                    print('**************************************')

            with torch.no_grad():
                test_acc = 0.
                test_loss = 0.
                total_correct = 0.
                for i, (data, target) in enumerate(cifar_test):
                    z = torch.randn(self.batch_size, self.s).cuda()
                    codes = self.hypergan.mixer(z)
                    params = self.hypergan.generator(codes)

                    losses, corrects = [], []
                    for (layers) in zip(*params):
                        correct, loss = self.train_clf(layers,
                                                       data,
                                                       target,
                                                       val=True)
                        losses.append(loss)
                        corrects.append(correct)
                    losses = torch.stack(losses)
                    corrects = torch.stack(corrects)
                    test_acc += corrects.mean().item()
                    total_correct += corrects.sum().item()
                    test_loss += losses.mean().item()
                test_loss /= len(cifar_test.dataset)
                test_acc /= len(cifar_test.dataset)
                total_correct /= self.batch_size

                print('[Epoch {}] Test Loss: {}, Test Accuracy: {},  ({}/{})'.
                      format(epoch, test_loss, test_acc, total_correct,
                             len(cifar_test.dataset)))

                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        self.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        self.best_acc = best_test_acc
Example #13
def train(args):

    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(),
                        lr=5e-5,
                        betas=(0.5, 0.9),
                        weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(),
                         lr=5e-5,
                         betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(),
                         lr=5e-5,
                         betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(),
                         lr=5e-5,
                         betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(),
                         lr=5e-5,
                         betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(),
                         lr=5e-5,
                         betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(),
                        lr=5e-5,
                        betas=(0.5, 0.9),
                        weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 5)

    one = torch.tensor(1).cuda()
    mone = one * -1
    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = torch.stack(netE(x)).view(-1, args.z * 5)
            qz = utils.sample_d(qz_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            qz = utils.sample_d(qz_dist, args.batch_size)
            codes = netE(z)
            noise = utils.sample_d(qz_dist, args.batch_size)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args,
                                          netD,
                                          ze,
                                          codes,
                                          log_pz,
                                          cifar=True)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()

            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            gp, grads, norms = ops.calc_gradient_penalty(z,
                                                         [W1, W2, W3, W4, W5],
                                                         netE,
                                                         cifar=True)
            reduce = lambda x: x.mean(0).mean(0).item()
            grads = [reduce(grad) for grad in grads]
            clf_loss = 0.
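            # evaluate each sampled weight set as a classifier on the current batch and accumulate the losses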
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data,
                                          target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size
            one_qz = torch.ones((160, 1), requires_grad=True).cuda()
            log_qz = ops.log_density(torch.ones(160, 1), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss
            total_hyper_loss.backward()

            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            optimE.zero_grad()
            optimW1.zero_grad()
            optimW2.zero_grad()
            optimW3.zero_grad()
            optimW4.zero_grad()
            optimW5.zero_grad()

            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Acc: {}, MD Loss: {}, D Loss: {}'.format(
                    acc, total_hyper_loss, d_loss))
                #print ('penalties: ', [gp[x].item() for x in range(len(gp))])
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(cifar_test):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        l4 = W4(codes[3])
                        l5 = W5(codes[4])
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args,
                                                      [g1, g2, g3, g4, g5],
                                                      data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)

                    print('Test Accuracy: {}, Test Loss: {}'.format(
                        test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(
                        clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_cifar(
                            args, [netE, netD, W1, W2, W3, W4, W5], clf_acc)
Example #14
def train(args):

    torch.manual_seed(8734)

    netG = Generator(args).cuda()
    netD = Discriminator(args).cuda()
    print(netG, netD)

    optimG = optim.Adam(netG.parameters(),
                        lr=1e-4,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(),
                        lr=1e-4,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)

    cifar_train, cifar_test = datagen.load_cifar(args)
    train = inf_gen(cifar_train)
    print('saving reals')
    reals, _ = next(train)

    if not os.path.exists('results/'):
        os.makedirs('results')

    save_image(reals, 'results/reals.png')

    one = torch.tensor(1.).cuda()
    mone = (one * -1)
    total_batches = 0

    print('==> Begin Training')
    for iter in range(args.epochs):
        total_batches += 1
        ops.batch_zero_grad([netG, netD])
        for p in netD.parameters():
            p.requires_grad = True
        for _ in range(args.disc_iters):
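            # critic updates: maximize D(real) - D(fake) with a gradient penalty (WGAN-GP)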
            data, targets = next(train)
            netD.zero_grad()
            d_real = netD(data).mean()
            d_real.backward(mone, retain_graph=True)
            noise = torch.randn(args.batch_size, args.z,
                                requires_grad=True).cuda()
            with torch.no_grad():
                fake = netG(noise)
            fake.requires_grad_(True)
            d_fake = netD(fake)
            d_fake = d_fake.mean()
            d_fake.backward(one, retain_graph=True)
            gp = ops.grad_penalty_3dim(args, netD, data, fake)
            gp.backward()
            d_cost = d_fake - d_real + gp
            wasserstein_d = d_real - d_fake
            optimD.step()

        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        noise = torch.randn(args.batch_size, args.z, requires_grad=True).cuda()
        fake = netG(noise)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        g_cost = -G
        optimG.step()

        if iter % 100 == 0:
            print('iter: ', iter, 'train D cost', d_cost.cpu().item())
            print('iter: ', iter, 'train G cost', g_cost.cpu().item())
        if iter % 300 == 0:
            val_d_costs = []
            for i, (data, target) in enumerate(cifar_test):
                data = data.cuda()
                d = netD(data)
                val_d_cost = -d.mean().item()
                val_d_costs.append(val_d_cost)
            utils.generate_image(args, iter, netG)
Example #15
def train(args):

    torch.manual_seed(8734)

    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(),
                        lr=5e-3,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(),
                         lr=1e-4,
                         betas=(0.5, 0.9),
                         weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(),
                        lr=5e-5,
                        betas=(0.5, 0.9),
                        weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):

            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
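            # average each layer generator's output over the batch to obtain a single candidate network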
            l1 = W1(codes[0]).mean(0)
            l2 = W2(codes[1]).mean(0)
            l3 = W3(codes[2]).mean(0)
            l4 = W4(codes[3]).mean(0)
            l5 = W5(codes[4]).mean(0)

            # Z Adversary
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
            for code in codes:
                noise = utils.sample_d(z_dist, args.batch_size)
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])

            correct, loss = train_clf(args, [l1, l2, l3, l4, l5],
                                      data,
                                      target,
                                      val=True)
            scaled_loss = args.beta * loss
            scaled_loss.backward()

            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()
            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = (correct / 1)
                print('**************************************')
                print('{} CIFAR Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(
                    acc, loss, d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')
            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                total_correct = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    l1 = W1(codes[0]).mean(0)
                    l2 = W2(codes[1]).mean(0)
                    l3 = W3(codes[2]).mean(0)
                    l4 = W4(codes[3]).mean(0)
                    l5 = W5(codes[4]).mean(0)
                    correct, loss = train_clf(args, [l1, l2, l3, l4, l5],
                                              data,
                                              y,
                                              val=True)
                    test_acc += correct.item()
                    total_correct += correct.item()
                    test_loss += loss.item()
                test_loss /= len(cifar_test.dataset)
                test_acc /= len(cifar_test.dataset)

                print('Test Accuracy: {}, Test Loss: {},  ({}/{})'.format(
                    test_acc, test_loss, total_correct,
                    len(cifar_test.dataset)))

                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    utils.save_clf(args, [l1, l2, l3, l4, l5], test_acc)
                    #utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc