Ejemplo n.º 1
0
def train(args):
    """Train five CIFAR weight generators on directly sampled latent codes.

    Builds the encoder, the generators ``W1``..``W5`` and the latent
    discriminator, optionally pretrains the encoder, and then, for every
    training batch, samples one batch of codes per generator, generates the
    layer weights and backpropagates the (``args.beta``-scaled) loss returned
    by ``train_clf``.

    In this variant the codes fed to the generators are sampled with
    ``utils.sample_z_like`` rather than produced by the encoder, and the
    latent adversary is disabled, so ``netE`` only takes part in the optional
    pretraining and ``netD``/``optimD`` are constructed but not trained.

    Args:
        args: Experiment configuration. Attributes read here: ``pretrain_e``,
            ``ze``, ``z``, ``batch_size``, ``epochs``, ``beta`` and ``exp``.
            ``best_loss``, ``best_acc``, ``best_clf_loss`` and
            ``best_clf_acc`` are written back onto it.

    Returns:
        None. Progress is printed and recorded through ``stats``.

    Note:
        ``logger`` and ``train_clf`` are not defined in this function; they
        are expected to exist at module level.
    """
    torch.manual_seed(1)

    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    # Only needed by the (currently disabled) latent adversary below.
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    # Fixed: this used to bind a stray local ``args_best_clf_loss`` instead of
    # setting the attribute on ``args``.
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)

    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_z_like((e_batch_size, args.ze))
            z = utils.sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for code in codes:
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                # Every code comes from the same netE(x) forward pass, so the
                # graph has to survive the repeated backward calls.
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            # The stopping test only looks at the loss of the last code.
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    bb = 0  # number of status reports printed so far
    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            utils.batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            # One independent batch of codes per generator (the encoder is
            # bypassed in this variant).
            codes = [
                utils.sample_z_like((args.batch_size, args.z,))
                for _code_idx in range(5)
            ]
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            # Disabled latent adversary, kept for reference:
            # for code in codes:
            #     noise = utils.sample_z_like((args.batch_size, args.z))
            #     d_real = netD(noise)
            #     d_fake = netD(code)
            #     d_real_loss = -1 * torch.log((1-d_real).mean())
            #     d_fake_loss = -1 * torch.log(d_fake.mean())
            #     d_real_loss.backward(retain_graph=True)
            #     d_fake_loss.backward(retain_graph=True)
            #     d_loss = d_real_loss + d_fake_loss
            # optimD.step()

            # Each zip element holds the five layer tensors of one generated
            # network; their losses are summed before a single backward pass.
            clf_loss = 0
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                clf_loss += loss
            clf_loss *= args.beta
            clf_loss.backward()

            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            utils.batch_zero_grad([optimE, optimW1, optimW2, optimW3, optimW4, optimW5])
            # Reported value is the loss of the last generated network only.
            loss = loss.item()

            # Periodic training status report.
            if batch_idx % 50 == 0:
                bb += 1
                acc = correct
                print('**************************************')
                print("epoch: {}".format(bb))
                print('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, 0))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(best_clf_acc))
                print('**************************************')

            # Periodic evaluation on the test loader with freshly sampled codes.
            if batch_idx % 100 == 0:
                with torch.no_grad():
                    test_acc = 0.
                    test_loss = 0.
                    for data, y in cifar_test:
                        w1_code = utils.sample_z_like((args.batch_size, args.z,))
                        w2_code = utils.sample_z_like((args.batch_size, args.z,))
                        w3_code = utils.sample_z_like((args.batch_size, args.z,))
                        w4_code = utils.sample_z_like((args.batch_size, args.z,))
                        w5_code = utils.sample_z_like((args.batch_size, args.z,))

                        l1 = W1(w1_code)
                        l2 = W2(w2_code)
                        l3 = W3(w3_code)
                        l4 = W4(w4_code)
                        l5 = W5(w5_code)
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                # Totals are averaged over test samples times generated networks.
                test_loss /= len(cifar_test.dataset) * args.batch_size
                test_acc /= len(cifar_test.dataset) * args.batch_size

                stats.update_logger(l1, l2, l3, l4, l5, logger)
                stats.update_acc(logger, test_acc)
                stats.save_logger(logger, args.exp)
                stats.plot_logger(logger)

                print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss:
                    best_test_loss, args.best_loss = test_loss, test_loss
                if test_acc > best_test_acc:
                    best_test_acc, args.best_acc = test_acc, test_acc
Ejemplo n.º 2
0
def train(args):
    """Adversarially train an MNIST hypernetwork (encoder + five generators).

    Each training batch runs two phases:

    1. Discriminator step: ``netD`` is updated on sampled noise versus the
       codes produced by ``netE``.
    2. Hypernetwork step: fresh codes are mapped to layer weights by
       ``W1``..``W5``; for every generated network the ``args.beta``-scaled
       loss from ``train_clf`` and the discriminator term are backpropagated,
       then the encoder and all five generators are stepped.

    Every 199 batches the test loader is evaluated and, when accuracy
    exceeds 0.85 on a new best, the networks are saved.

    Args:
        args: Experiment configuration. Attributes read here: ``pretrain_e``,
            ``ze``, ``z``, ``batch_size``, ``epochs``, ``beta`` and
            ``model``. ``best_loss`` and ``best_acc`` are written back.

    Returns:
        None. Progress is printed to stdout.

    Note:
        ``train_clf`` is not defined in this function; it is expected to
        exist at module level.
    """
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    adam_kwargs = dict(betas=(0.5, 0.9), weight_decay=1e-4)
    optimE = optim.Adam(netE.parameters(), lr=5e-4, **adam_kwargs)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, **adam_kwargs)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, **adam_kwargs)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, **adam_kwargs)
    # Fixed: optimW4 and optimW5 were both built over W3.parameters(), so W4
    # and W5 were never updated and W3 was stepped by three optimizers.
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, **adam_kwargs)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, **adam_kwargs)
    optimD = optim.Adam(netD.parameters(), lr=1e-4, **adam_kwargs)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    # Scalar -1 passed to backward() so the log-likelihood terms are ascended.
    neg_one = torch.tensor(-1, dtype=torch.float).cuda()

    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(100):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for code in codes:
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                # Every code comes from the same netE(x) forward pass, so the
                # graph has to survive the repeated backward calls.
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            # The stopping test only looks at the loss of the last code.
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            for net in (netE, W1, W2, W3, W4, W5):
                net.zero_grad()

            # ---- Phase 1: discriminator update ----
            # NOTE(review): netD is only zeroed after its own step, so
            # gradients deposited in netD by phase 2 of the previous batch
            # are included here; likewise d_fake_loss.backward() leaves
            # gradients in netE that phase 2 does not clear. Confirm this is
            # intended.
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = torch.log((1 - d_real).mean())
                d_fake_loss = torch.log(d_fake.mean())
                d_real_loss.backward(neg_one, retain_graph=True)
                d_fake_loss.backward(neg_one, retain_graph=True)
            optimD.step()
            netD.zero_grad()

            # ---- Phase 2: encoder / generator update ----
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            d_real = [netD(code) for code in codes]

            d_loss = torch.stack(d_real).log().mean() * 10.
            # Each zip element holds the five layer tensors of one generated
            # network. The discriminator term is backpropagated once per
            # generated network, as in the original code.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                # NOTE(review): train_clf's return order is not visible here;
                # confirm that it really is (correct, loss).
                correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data,
                                          target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
                d_loss.backward(neg_one, retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            # Reported value is the loss of the last generated network only.
            loss = loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('{} MNIST Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, Loss: {}'.format(acc, loss))
                print('D loss: {}'.format(d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                for data, y in mnist_test:
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    l1 = W1(codes[0])
                    l2 = W2(codes[1])
                    l3 = W3(codes[2])
                    l4 = W4(codes[3])
                    l5 = W5(codes[4])
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5],
                                                  data, y)
                        test_acc += correct.item()
                        test_loss += loss.item()
                # Totals are averaged over test samples times generated networks.
                test_loss /= len(mnist_test.dataset) * args.batch_size
                test_acc /= len(mnist_test.dataset) * args.batch_size

                print('Test Accuracy: {}, Test Loss: {}'.format(
                    test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    if test_acc > .85:
                        # NOTE(review): the "cifar" saver is used for this
                        # MNIST model; confirm that is the intended helper.
                        utils.save_hypernet_cifar(args,
                                                  [netE, W1, W2, W3, W4, W5],
                                                  test_acc)
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
Ejemplo n.º 3
0
def train(args):
    """Adversarially train an MNIST auto-encoder hypernetwork.

    ``netE`` maps sampled inputs to one code per generated layer. The codes
    are consumed in order by the generators in ``Enets`` (``E1``, ``E2``) and
    then ``Dnets`` (``D1``..``D3``). Each training batch first updates the
    latent discriminator ``netD`` and then the encoder and generators, using
    the ``args.beta``-scaled loss returned by ``train_clf`` plus a
    discriminator term.

    Every 199 batches a fixed number of test batches is evaluated and one
    randomly chosen generated encoder/decoder pair is rendered through
    ``utils.generate_image``.

    Args:
        args: Experiment configuration. Attributes read here: ``pretrain_e``,
            ``ze``, ``z``, ``batch_size``, ``epochs``, ``beta``, ``statE``
            and ``statD``. ``best_loss`` is written back.

    Returns:
        None. Progress is printed to stdout.

    Note:
        ``train_clf``, ``sampleE`` and ``sampleD`` are not defined in this
        function; they are expected to exist at module level.
    """
    from torch import optim

    netE = models.Encoderz(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    E1 = models.GeneratorE1(args).cuda()
    E2 = models.GeneratorE2(args).cuda()
    # The three "D" generators are built from GeneratorD2..GeneratorD4.
    D1 = models.GeneratorD2(args).cuda()
    D2 = models.GeneratorD3(args).cuda()
    D3 = models.GeneratorD4(args).cuda()
    print(netE, netD)
    print(E1, E2, D1, D2, D3)

    adam_kwargs = dict(betas=(0.5, 0.9), weight_decay=1e-4)
    optimE = optim.Adam(netE.parameters(), lr=5e-4, **adam_kwargs)
    optimD = optim.Adam(netD.parameters(), lr=1e-4, **adam_kwargs)

    Enets = [E1, E2]
    Dnets = [D1, D2, D3]
    Eoptim = [optim.Adam(net.parameters(), lr=1e-4, **adam_kwargs)
              for net in Enets]
    Doptim = [optim.Adam(net.parameters(), lr=1e-4, **adam_kwargs)
              for net in Dnets]

    best_test_loss = np.inf
    args.best_loss = best_test_loss

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    # Scalar -1 passed to backward() so the log-likelihood terms are ascended.
    neg_one = torch.tensor(-1, dtype=torch.float).cuda()
    # Number of test batches used for each periodic evaluation.
    max_test_batches = 10

    def generate_weights(codes):
        """Feed one code to each generator: Enets take the first codes, Dnets the rest."""
        e_weights = [net(codes[k]) for k, net in enumerate(Enets)]
        offset = len(Enets)
        d_weights = [net(codes[offset + k]) for k, net in enumerate(Dnets)]
        return e_weights, d_weights

    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(100):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for code in codes:
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                # Every code comes from the same netE(x) forward pass, so the
                # graph has to survive the repeated backward calls.
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            # The stopping test only looks at the loss of the last code.
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _epoch in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            netE.zero_grad()
            # The loop variable used to be called ``optim`` and shadowed the
            # imported module; it is renamed to keep the module usable.
            for opt in Eoptim + Doptim:
                opt.zero_grad()

            # ---- Phase 1: discriminator update ----
            # NOTE(review): netD is only zeroed after its own step, so
            # gradients deposited in netD by phase 2 of the previous batch
            # are included here. Confirm this is intended.
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = torch.log((1 - d_real).mean())
                d_fake_loss = torch.log(d_fake.mean())
                d_real_loss.backward(neg_one, retain_graph=True)
                d_fake_loss.backward(neg_one, retain_graph=True)
            optimD.step()
            netD.zero_grad()

            # ---- Phase 2: encoder / generator update ----
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            Eweights, Dweights = generate_weights(codes)
            d_real = [netD(code) for code in codes]

            d_loss = torch.stack(d_real).log().mean() * 10.

            # Each zip element holds the layer tensors of one generated
            # auto-encoder. The discriminator term is backpropagated once per
            # generated network, as in the original code.
            for layers in zip(*(Eweights + Dweights)):
                loss, _out = train_clf(args, layers, data, target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
                d_loss.backward(neg_one, retain_graph=True)
            optimE.step()
            for opt in Eoptim + Doptim:
                opt.step()
            # Reported value is the loss of the last generated network only.
            loss = loss.item()

            if batch_idx % 50 == 0:
                print('**************************************')
                print('AE MNIST Test, beta: {}'.format(args.beta))
                print('MSE Loss: {}'.format(loss))
                print('D loss: {}'.format(d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_loss = 0.
                # Fixed: the batch counter used to share the name ``i`` with
                # the code index, which reset it on every batch, so the
                # ``== 10`` early exit never triggered and the whole test set
                # was evaluated while the divisor assumed 10 batches.
                n_batches = 0
                for data, y in mnist_test:
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    Eweights, Dweights = generate_weights(codes)
                    for layers in zip(*(Eweights + Dweights)):
                        loss, _out = train_clf(args, layers, data, y)
                        test_loss += loss.item()
                    n_batches += 1
                    if n_batches == max_test_batches:
                        break
                # Divide by the batches actually evaluated, the size of the
                # last batch and the number of generated networks per batch.
                test_loss /= n_batches * len(y) * args.batch_size
                print('Test Loss: {}'.format(test_loss))
                if test_loss < best_test_loss:
                    print('==> new best stats, saving')
                    best_test_loss = test_loss
                    args.best_loss = test_loss

                # Render one randomly chosen generated encoder/decoder pair,
                # using the weights and data of the last evaluated test batch.
                archE = sampleE(args).cuda()
                archD = sampleD(args).cuda()
                rand = np.random.randint(args.batch_size)
                eweight = list(zip(*Eweights))[rand]
                dweight = list(zip(*Dweights))[rand]
                modelE = utils.weights_to_clf(eweight, archE,
                                              args.statE['layer_names'])
                modelD = utils.weights_to_clf(dweight, archD,
                                              args.statD['layer_names'])
                utils.generate_image(args, batch_idx, modelE, modelD,
                                     data.cuda())
def train(args):
    """Train a CIFAR hypernetwork (encoder + five weight generators).

    For every training batch the encoder ``netE`` turns a sampled input into
    five codes, ``W1``..``W5`` turn those codes into layer weights, and the
    loss returned by ``train_clf`` is backpropagated once per generated
    network before the encoder and all five generators are stepped. The
    latent discriminator ``netD`` is constructed (and saved with the other
    networks) but its training step is disabled in this variant.

    Every 100 batches the test loader is evaluated; when the accuracy
    reported by ``test_clf`` improves, the networks are saved.

    Args:
        args: Experiment configuration. Attributes read here: ``pretrain_e``,
            ``ze``, ``z``, ``batch_size`` and ``epochs``. ``best_loss``,
            ``best_acc``, ``best_clf_loss`` and ``best_clf_acc`` are written
            back.

    Returns:
        None. Progress is printed and recorded through ``stats``.

    Note:
        ``logger``, ``train_clf`` and ``test_clf`` are not defined in this
        function; they are expected to exist at module level.
    """
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    adam_kwargs = dict(betas=(0.5, 0.9), weight_decay=1e-4)
    optimE = optim.Adam(netE.parameters(), lr=5e-3, **adam_kwargs)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, **adam_kwargs)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, **adam_kwargs)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, **adam_kwargs)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, **adam_kwargs)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, **adam_kwargs)
    # Only needed by the (currently disabled) discriminator step.
    optimD = optim.Adam(netD.parameters(), lr=5e-5, **adam_kwargs)
    hyper_optims = [optimE, optimW1, optimW2, optimW3, optimW4, optimW5]

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)
    # Kept from the original: these are only referenced by disabled code
    # paths, but create_d's side effects are not visible from here.
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 5)

    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_z_like((e_batch_size, args.ze))
            z = utils.sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for code in codes:
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                # Every code comes from the same netE(x) forward pass, so the
                # graph has to survive the repeated backward calls.
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            # The stopping test only looks at the loss of the last code.
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            # Fixed: gradients were never cleared in this loop (the zero_grad
            # calls had been commented out), so every step used the sum of
            # the gradients of all previous batches.
            for opt in hyper_optims:
                opt.zero_grad()

            z = utils.sample_z_like((args.batch_size, args.ze))
            codes = netE(z)
            # Disabled discriminator step, kept for reference:
            # noise = utils.sample_d(qz_dist, args.batch_size)
            # log_pz = ops.log_density(ze, 2).view(-1, 1)
            # d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz, cifar=True)
            # optimD.zero_grad()
            # d_loss.backward(retain_graph=True)
            # optimD.step()
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            # Each zip element holds the five layer tensors of one generated
            # network; gradients are accumulated network by network.
            clf_loss = 0.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data,
                                          target)
                clf_loss += loss
                loss.backward(retain_graph=True)
            G_loss = clf_loss
            total_hyper_loss = G_loss

            optimE.step()
            optimW1.step()
            optimW2.step()
            # Fixed: optimW3.step() was missing, so W3 was never trained.
            optimW3.step()
            optimW4.step()
            optimW5.step()

            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Iter: {}'.format(len(logger['acc'])))
                # ``loss`` is the loss of the last generated network only.
                print('Acc: {}, MD Loss: {}'.format(acc, loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for data, y in cifar_test:
                        z = utils.sample_z_like((args.batch_size, args.ze))
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        l4 = W4(codes[3])
                        l5 = W5(codes[4])
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args,
                                                      [g1, g2, g3, g4, g5],
                                                      data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    # Totals are averaged over test samples times generated
                    # networks.
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    # Uses the weights generated for the last test batch.
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    stats.update_acc(logger, test_acc)

                    print('Test Accuracy: {}, Test Loss: {}'.format(
                        test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(
                        clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_cifar(
                            args, [netE, netD, W1, W2, W3, W4, W5], clf_acc)