Example #1
def evaluate_recon(sess, dcgan, args, dataset, train_iter):
    for (gi, ei), recon in dcgan.reconstructers.items():
        for c in range(dataset.num_classes):
            inputs_c, _ = dataset.next_batch(args.batch_size, class_id=c)
            recons_c = sess.run(recon, feed_dict={dcgan.inputs: inputs_c})
            filename = "B_DCGAN_RECON_g%i_e%i_c%i" % (gi, ei, c)
            print_images(recons_c,
                         filename,
                         train_iter,
                         directory=args.out_dir)
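Note: print_images is a project helper that is not shown in these examples. A
minimal stand-in, assuming it only tiles a batch of images and writes one PNG
named after the tag and iteration (the real helper may do more):

import os
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render to files, no display needed
import matplotlib.pyplot as plt

def print_images(images, tag, train_iter, directory="."):
    # tile the batch into a near-square grid and save a single PNG
    images = np.asarray(images)
    n = int(np.ceil(np.sqrt(len(images))))
    fig, axes = plt.subplots(n, n, squeeze=False, figsize=(n, n))
    for ax, img in zip(axes.ravel(), images):
        ax.imshow(img.squeeze(), cmap="gray")
    for ax in axes.ravel():
        ax.axis("off")
    fig.savefig(os.path.join(directory, "%s_%i.png" % (tag, train_iter)))
    plt.close(fig)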
Example #2
# Imports for these examples; helper functions (get_session, BDCGAN_Semi,
# get_supervised_batches, get_test_batches, get_test_accuracy, print_images)
# come from the surrounding project.
import json
import os
import time

import numpy as np
import tensorflow as tf


def b_dcgan(dataset, args):

    z_dim = args.z_dim
    x_dim = dataset.x_dim
    batch_size = args.batch_size
    dataset_size = dataset.dataset_size

    session = get_session()
    tf.set_random_seed(args.random_seed)
    # the TF graph is built for a fixed batch_size, so every op takes the same batch size
    dcgan = BDCGAN_Semi(x_dim,
                        z_dim,
                        dataset_size,
                        batch_size=batch_size,
                        J=args.J,
                        J_d=args.J_d,
                        M=args.M,
                        num_layers=args.num_layers,
                        lr=args.lr,
                        optimizer=args.optimizer,
                        gf_dim=args.gf_dim,
                        df_dim=args.df_dim,
                        ml=(args.ml and args.J == 1 and args.M == 1
                            and args.J_d == 1),
                        num_classes=dataset.num_classes)

    print "Starting session"
    session.run(tf.global_variables_initializer())

    print "Starting training loop"

    num_train_iter = args.train_iter

    if hasattr(dataset, "supervised_batches"):
        # datasets that don't fit in memory implement their own data feeder
        supervised_batches = dataset.supervised_batches(args.N, batch_size)
    else:
        supervised_batches = get_supervised_batches(dataset, args.N,
                                                    batch_size,
                                                    range(dataset.num_classes))

    test_image_batches, test_label_batches = get_test_batches(
        dataset, batch_size)

    optimizer_dict = {
        "disc_semi": dcgan.d_optims_semi_adam,
        "sup_d": dcgan.s_optim_adam,
        "gen": dcgan.g_optims_semi_adam
    }

    base_learning_rate = args.lr  # for now we use the same learning rate for Ds and Gs
    lr_decay_rate = args.lr_decay
    num_disc = args.J_d

    for train_iter in range(num_train_iter):

        if train_iter == 5000:
            print "Switching to user-specified optimizer"
            optimizer_dict = {
                "disc_semi": dcgan.d_optims_semi,
                "sup_d": dcgan.s_optim,
                "gen": dcgan.g_optims_semi
            }

        learning_rate = base_learning_rate * np.exp(-lr_decay_rate * min(
            1.0, (train_iter * batch_size) / float(dataset_size)))

        image_batch, _ = dataset.next_batch(batch_size, class_id=None)
        labeled_image_batch, labels = next(supervised_batches)

        ### compute disc losses
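        # one z sample per generator: shape [batch_size, z_dim, num_gen]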
        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim, dcgan.num_gen])
        disc_info = session.run(
            optimizer_dict["disc_semi"] +
            dcgan.d_losses,
            feed_dict={
                dcgan.labeled_inputs: labeled_image_batch,
                dcgan.labels: labels,
                dcgan.inputs: image_batch,
                dcgan.z: batch_z,
                dcgan.d_semi_learning_rate: learning_rate
            })

        d_losses = disc_info[num_disc:num_disc * 2]

        ### compute generative losses
        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim, dcgan.num_gen])
        gen_info = session.run(optimizer_dict["gen"] + dcgan.g_losses,
                               feed_dict={
                                   dcgan.z: batch_z,
                                   dcgan.inputs: image_batch,
                                   dcgan.g_learning_rate: learning_rate
                               })
        g_losses = [g_ for g_ in gen_info if g_ is not None]

        ### vanilla supervised loss
        _, s_loss = session.run([optimizer_dict["sup_d"], dcgan.s_loss],
                                feed_dict={
                                    dcgan.inputs: labeled_image_batch,
                                    dcgan.lbls: labels
                                })

        if train_iter > 0 and train_iter % args.n_save == 0:

            print "Iter %i" % train_iter
            print "Disc losses = %s" % (", ".join(
                ["%.2f" % dl for dl in d_losses]))
            print "Gen losses = %s" % (", ".join(
                ["%.2f" % gl for gl in g_losses]))

            # get test set performance on real labels only for both GAN-based classifier and standard one
            s_acc, ss_acc = get_test_accuracy(session, dcgan,
                                              test_image_batches,
                                              test_label_batches)
            print "Sup classification acc: %.2f" % (s_acc)
            print "Semi-sup classification acc: %.2f" % (ss_acc)

            print "saving results and samples"

            results = {
                "disc_losses": map(float, d_losses),
                "gen_losses": map(float, g_losses),
                "supervised_acc": float(s_acc),
                "semi_supervised_acc": float(ss_acc),
                "timestamp": time.time()
            }

            with open(
                    os.path.join(args.out_dir, 'results_%i.json' % train_iter),
                    'w') as fp:
                json.dump(results, fp)

            if args.save_samples:
                for zi in range(dcgan.num_gen):
                    _imgs = []
                    for _ in range(10):
                        z_sampler = np.random.uniform(-1,
                                                      1,
                                                      size=(batch_size, z_dim))
                        sampled_imgs = session.run(
                            dcgan.gen_samplers[zi * dcgan.num_mcmc],
                            feed_dict={dcgan.z_sampler: z_sampler})
                        _imgs.append(sampled_imgs)
                    sampled_imgs = np.concatenate(_imgs)
                    print_images(sampled_imgs,
                                 "B_DCGAN_%i_%.2f" %
                                 (zi, g_losses[zi * dcgan.num_mcmc]),
                                 train_iter,
                                 directory=args.out_dir)

                print_images(image_batch,
                             "RAW",
                             train_iter,
                             directory=args.out_dir)

            if args.save_weights:
                var_dict = {}
                for var in tf.trainable_variables():
                    var_dict[var.name] = session.run(var.name)

                np.savez_compressed(
                    os.path.join(args.out_dir, "weights_%i.npz" % train_iter),
                    **var_dict)

            print "done"
Example #3
# Imports as in Example #2; project helpers assumed in scope: get_session,
# BDCGAN, print_images.
def b_dcgan(dataset, args):

    z_dim = args.z_dim
    x_dim = dataset.x_dim
    batch_size = args.batch_size
    dataset_size = dataset.dataset_size

    session = get_session()
    tf.set_random_seed(args.random_seed)

    dcgan = BDCGAN(x_dim,
                   z_dim,
                   dataset_size,
                   batch_size=batch_size,
                   J=args.J,
                   J_d=args.J_d,
                   M=args.M,
                   num_layers=args.num_layers,
                   lr=args.lr,
                   optimizer=args.optimizer,
                   gf_dim=args.gf_dim,
                   df_dim=args.df_dim,
                   ml=(args.ml and args.J == 1 and args.M == 1
                       and args.J_d == 1))

    print "Starting session"
    session.run(tf.global_variables_initializer())

    print "Starting training loop"

    num_train_iter = args.train_iter

    optimizer_dict = {"disc": dcgan.d_optims_adam, "gen": dcgan.g_optims_adam}

    base_learning_rate = args.lr  # for now we use the same learning rate for Ds and Gs
    lr_decay_rate = args.lr_decay
    num_disc = args.J_d

    for train_iter in range(num_train_iter):

        if train_iter == 5000:
            print("Switching to user-specified optimizer")
            # note: this assigns the same Adam optimizers again, so the
            # "switch" is a no-op in this variant
            optimizer_dict = {
                "disc": dcgan.d_optims_adam,
                "gen": dcgan.g_optims_adam
            }

        learning_rate = base_learning_rate * np.exp(-lr_decay_rate * min(
            1.0, (train_iter * batch_size) / float(dataset_size)))

        image_batch, _ = dataset.next_batch(batch_size, class_id=None)

        ### compute disc losses
        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim, dcgan.num_gen])
        disc_info = session.run(optimizer_dict["disc"] + dcgan.d_losses,
                                feed_dict={
                                    dcgan.inputs: image_batch,
                                    dcgan.z: batch_z,
                                    dcgan.d_learning_rate: learning_rate
                                })
        d_losses = [d_ for d_ in disc_info if d_ is not None]

        ### compute generative losses
        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim, dcgan.num_gen])
        gen_info = session.run(optimizer_dict["gen"] + dcgan.g_losses,
                               feed_dict={
                                   dcgan.z: batch_z,
                                   dcgan.inputs: image_batch,
                                   dcgan.g_learning_rate: learning_rate
                               })
        g_losses = [g_ for g_ in gen_info if g_ is not None]

        if train_iter > 0 and train_iter % args.n_save == 0:

            print "Iter %i" % train_iter
            print "Disc losses = %s" % (", ".join(
                ["%.2f" % dl for dl in d_losses]))
            print "Gen losses = %s" % (", ".join(
                ["%.2f" % gl for gl in g_losses]))

            print "saving results and samples"

            results = {
                "disc_losses": map(float, d_losses),
                "gen_losses": map(float, g_losses),
                "timestamp": time.time()
            }

            with open(
                    os.path.join(args.out_dir, 'results_%i.json' % train_iter),
                    'w') as fp:
                json.dump(results, fp)

            if args.save_samples:
                for zi in range(dcgan.num_gen):
                    _imgs = []
                    for _ in range(10):
                        z_sampler = np.random.uniform(-1,
                                                      1,
                                                      size=(batch_size, z_dim))
                        sampled_imgs = session.run(
                            dcgan.gen_samplers[zi * dcgan.num_mcmc],
                            feed_dict={dcgan.z_sampler: z_sampler})
                        _imgs.append(sampled_imgs)
                    sampled_imgs = np.concatenate(_imgs)
                    print_images(sampled_imgs,
                                 "B_DCGAN_%i_%.2f" %
                                 (zi, g_losses[zi * dcgan.num_mcmc]),
                                 train_iter,
                                 directory=args.out_dir)

                print_images(image_batch,
                             "RAW",
                             train_iter,
                             directory=args.out_dir)

            if args.save_weights:
                var_dict = {}
                for var in tf.trainable_variables():
                    var_dict[var.name] = session.run(var.name)

                np.savez_compressed(
                    os.path.join(args.out_dir, "weights_%i.npz" % train_iter),
                    **var_dict)

            print "done"
Example #4
# Imports as in Example #2; additionally uses the FastGradientMethod and
# BasicIterativeMethod attack constructors (cleverhans-style API) plus the
# project's get_gan_labels and get_adv_test_accuracy helpers.
def b_dcgan(dataset, args):

    z_dim = args.z_dim
    x_dim = dataset.x_dim
    batch_size = args.batch_size
    dataset_size = dataset.dataset_size

    session = get_session()

    test_x = tf.placeholder(tf.float32, shape=(batch_size, 28, 28, 1))
    x = tf.placeholder(tf.float32, shape=(batch_size, 28, 28, 1))
    y = tf.placeholder(tf.float32, shape=(batch_size, 10))

    unlabeled_batch_ph = tf.placeholder(tf.float32,
                                        shape=(batch_size, 28, 28, 1))
    labeled_image_ph = tf.placeholder(tf.float32,
                                      shape=(batch_size, 28, 28, 1))
    if args.random_seed is not None:
        tf.set_random_seed(args.random_seed)
    # the TF graph is built for a fixed batch_size, so every op takes the same batch size
    dcgan = BDCGAN(
        x_dim,
        z_dim,
        dataset_size,
        batch_size=batch_size,
        J=args.J,
        M=args.M,
        lr=args.lr,
        optimizer=args.optimizer,
        gen_observed=args.gen_observed,
        adv_train=args.adv_train,
        num_classes=dataset.num_classes if args.semi_supervised else 1)
    if args.adv_test and args.semi_supervised:
        if args.basic_iterative:
            fgsm = BasicIterativeMethod(dcgan, sess=session)
            dcgan.adv_constructor = fgsm
            fgsm_params = {
                'eps': args.eps,
                'eps_iter': float(args.eps / 4),
                'nb_iter': 4,
                'ord': np.inf,
                'clip_min': 0.,
                'clip_max': 1.
            }
        else:
            fgsm = FastGradientMethod(dcgan, sess=session)
            dcgan.adv_constructor = fgsm
            eval_params = {'batch_size': batch_size}
            fgsm_params = {'eps': args.eps, 'clip_min': 0., 'clip_max': 1.}
        adv_x = fgsm.generate(x, **fgsm_params)
        adv_test_x = fgsm.generate(test_x, **fgsm_params)
        preds = dcgan.get_probs(adv_x)
    if args.adv_train:
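        # index 0 of the (K+1)-way discriminator output is the "fake" class
        # (cf. sampled_probs[:, 1:] below), so the targeted attack pushes
        # unlabeled inputs toward "fake"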
        unlabeled_targets = np.zeros([batch_size, dcgan.K + 1])
        unlabeled_targets[:, 0] = 1
        fgsm_targeted_params = {
            'eps': args.eps,
            'clip_min': 0.,
            'clip_max': 1.,
            'y_target': unlabeled_targets
        }

    saver = tf.train.Saver()

    print("Starting session")
    session.run(tf.global_variables_initializer())

    prev_iters = 0
    if args.load_chkpt:
        saver.restore(session, args.chkpt)
        # Assume checkpoint is of the form "model_300"
        prev_iters = int(args.chkpt.split('/')[-1].split('_')[1])
        print("Model restored from iteration:", prev_iters)

    print("Starting training loop")
    num_train_iter = args.train_iter

    if hasattr(dataset, "supervised_batches"):
        # datasets that don't fit in memory implement their own data feeder
        supervised_batches = dataset.supervised_batches(args.N, batch_size)
    else:
        supervised_batches = get_supervised_batches(
            dataset, args.N, batch_size, list(range(dataset.num_classes)))

    if args.semi_supervised:
        test_image_batches, test_label_batches = get_test_batches(
            dataset, batch_size)

        optimizer_dict = {
            "semi_d": dcgan.d_optim_semi_adam,
            "sup_d": dcgan.s_optim_adam,
            "adv_d": dcgan.d_optim_adam,
            "gen": dcgan.g_optims_adam
        }
    else:
        optimizer_dict = {
            "adv_d": dcgan.d_optim_adam,
            "gen": dcgan.g_optims_adam
        }

    base_learning_rate = args.lr  # for now we use the same learning rate for Ds and Gs
    lr_decay_rate = args.lr_decay

    for train_iter in range(1 + prev_iters, 1 + num_train_iter):

        if train_iter == 5000:
            print("Switching to user-specified optimizer")
            if args.semi_supervised:
                optimizer_dict = {
                    "semi_d": dcgan.d_optim_semi,
                    "sup_d": dcgan.s_optim,
                    "adv_d": dcgan.d_optim,
                    "gen": dcgan.g_optims
                }
            else:
                optimizer_dict = {
                    "adv_d": dcgan.d_optim,
                    "gen": dcgan.g_optims
                }

        learning_rate = base_learning_rate * np.exp(-lr_decay_rate * min(
            1.0, (train_iter * batch_size) / float(dataset_size)))

        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
        image_batch, batch_label = dataset.next_batch(batch_size,
                                                      class_id=None)
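        # 11 = num_classes + 1 (ten MNIST-style classes plus the "fake"
        # class); note these targets are not used again below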
        batch_targets = np.zeros([batch_size, 11])
        batch_targets[:, 0] = 1

        if args.semi_supervised:
            labeled_image_batch, labels = next(supervised_batches)
            if args.adv_train:
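                # each generate() call adds new attack ops to the graph, so
                # building these tensors once before the loop would avoid
                # unbounded graph growth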
                adv_labeled = session.run(
                    fgsm.generate(labeled_image_ph, **fgsm_targeted_params),
                    feed_dict={labeled_image_ph: labeled_image_batch})
                adv_unlabeled = session.run(
                    fgsm.generate(unlabeled_batch_ph, **fgsm_params),
                    feed_dict={unlabeled_batch_ph: image_batch})
                _, d_loss = session.run(
                    [optimizer_dict["semi_d"], dcgan.d_loss_semi],
                    feed_dict={
                        dcgan.labeled_inputs: labeled_image_batch,
                        dcgan.labels: get_gan_labels(labels),
                        dcgan.inputs: image_batch,
                        dcgan.z: batch_z,
                        dcgan.d_semi_learning_rate: learning_rate,
                        dcgan.adv_unlab: adv_unlabeled,
                        dcgan.adv_labeled: adv_labeled
                    })
            else:
                _, d_loss = session.run(
                    [optimizer_dict["semi_d"], dcgan.d_loss_semi],
                    feed_dict={
                        dcgan.labeled_inputs: labeled_image_batch,
                        dcgan.labels: get_gan_labels(labels),
                        dcgan.inputs: image_batch,
                        dcgan.z: batch_z,
                        dcgan.d_semi_learning_rate: learning_rate
                    })

            _, s_loss = session.run([optimizer_dict["sup_d"], dcgan.s_loss],
                                    feed_dict={
                                        dcgan.inputs: labeled_image_batch,
                                        dcgan.lbls: labels
                                    })

        else:
            # regular GAN
            _, d_loss = session.run(
                [optimizer_dict["adv_d"], dcgan.d_loss],
                feed_dict={
                    dcgan.inputs: image_batch,
                    dcgan.z: batch_z,
                    dcgan.d_learning_rate: learning_rate
                })

        if args.wasserstein:
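            # WGAN-style weight clipping keeps the critic approximately
            # 1-Lipschitz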
            session.run(dcgan.clip_d, feed_dict={})

        g_losses = []
        for gi in range(dcgan.num_gen):

            # compute g_sample loss
            batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
            for m in range(dcgan.num_mcmc):
                _, g_loss = session.run([
                    optimizer_dict["gen"][gi * dcgan.num_mcmc + m],
                    dcgan.generation["g_losses"][gi * dcgan.num_mcmc + m]
                ],
                                        feed_dict={
                                            dcgan.z: batch_z,
                                            dcgan.g_learning_rate:
                                            learning_rate
                                        })
                g_losses.append(g_loss)

        if train_iter > 0 and train_iter % args.n_save == 0:
            print("Iter %i" % train_iter)
            # collect samples
            if args.save_samples:  # saving samples
                all_sampled_imgs = []
                for gi in range(dcgan.num_gen):
                    _imgs, _ps = [], []
                    for _ in range(10):
                        sample_z = np.random.uniform(-1,
                                                     1,
                                                     size=(batch_size, z_dim))
                        sampled_imgs, sampled_probs = session.run([
                            dcgan.generation["gen_samplers"][gi *
                                                             dcgan.num_mcmc],
                            dcgan.generation["d_probs"][gi * dcgan.num_mcmc]
                        ],
                                                                  feed_dict={
                                                                      dcgan.z:
                                                                      sample_z
                                                                  })
                        _imgs.append(sampled_imgs)
                        _ps.append(sampled_probs)

                    sampled_imgs = np.concatenate(_imgs)
                    sampled_probs = np.concatenate(_ps)
                    all_sampled_imgs.append(
                        [sampled_imgs, sampled_probs[:, 1:].sum(1)])

            print("Disc loss = %.2f, Gen loss = %s" %
                  (d_loss, ", ".join(["%.2f" % gl for gl in g_losses])))

            if args.semi_supervised:
                # get test set performance on real labels only for both GAN-based classifier and standard one

                s_acc, ss_acc, non_adv_acc, ex_prob = get_test_accuracy(
                    session, dcgan, test_image_batches, test_label_batches)
                if args.adv_test:
                    adv_set = []
                    for test_images in test_image_batches:
                        adv_set.append(
                            session.run(adv_x, feed_dict={x: test_images}))
                    adv_sup_acc, adv_ss_acc, correct_uncertainty, incorrect_uncertainty, adv_acc, adv_ex_prob = get_adv_test_accuracy(
                        session, dcgan, adv_set, test_label_batches)
                    print("Adversarial semi-sup accuracy with filter: %.2f" %
                          adv_sup_acc)
                    print("Adverarial semi-sup accuracy: %.2f" % adv_ss_acc)
                    print("Uncertainty for correct predictions: %.2f" %
                          correct_uncertainty)
                    print("Uncertainty for incorrect predictions: %.2f" %
                          incorrect_uncertainty)
                    print("non_adversarial_classification_accuracy: %.2f" %
                          non_adv_acc)
                    print("adversarial_classification_accuracy: %.2f" %
                          adv_acc)

                    if args.save_samples:
                        print("saving adversarial test images and test images")
                        # renamed loop variables: the original (x, y) names
                        # clobbered the attack placeholders defined above
                        for i, (adv_img, test_img) in enumerate(
                                zip(adv_set[-1], test_image_batches[-1])):
                            np.save(
                                args.out_dir + '/adv_test' + str(train_iter) +
                                '_' + str(i), adv_img)
                            np.save(
                                args.out_dir + '/test' + str(train_iter) +
                                '_' + str(i), test_img)
                            if i == 4:  # save the first 5 adversarial images
                                break

                print("Supervised acc: %.2f" % (s_acc))
                print("Semi-sup acc: %.2f" % (ss_acc))

            print("saving results and samples")

            results = {
                "disc_loss": float(d_loss),
                "gen_losses": list(map(float, g_losses))
            }
            if args.semi_supervised:
                results["supervised_acc"] = float(s_acc)
                results["semi_supervised_acc"] = float(ss_acc)
                results["non_adversarial_classification_accuracy"] = float(
                    non_adv_acc)
                results["timestamp"] = time.time()
                results["previous_chkpt"] = args.chkpt
                # the adversarial metrics only exist when the attack ran
                if args.adv_test:
                    results["adversarial_classification_accuracy"] = float(
                        adv_acc)
                    results["adversarial_uncertainty_correct"] = float(
                        correct_uncertainty)
                    results["adversarial_uncertainty_incorrect"] = float(
                        incorrect_uncertainty)
                    results["adversarial_filtered_semi_supervised_acc"] = float(
                        adv_sup_acc)
                    results["adversarial_unfiltered_semi_supervised_acc"] = float(
                        adv_ss_acc)

            with open(
                    os.path.join(args.out_dir, 'results_%i.json' % train_iter),
                    'w') as fp:
                json.dump(results, fp)

            if args.save_samples:
                for gi in range(dcgan.num_gen):
                    print_images(all_sampled_imgs[gi],
                                 "B_DCGAN_%i_%.2f" %
                                 (gi, g_losses[gi * dcgan.num_mcmc]),
                                 train_iter,
                                 directory=args.out_dir)

                print_images(image_batch,
                             "RAW",
                             train_iter,
                             directory=args.out_dir)

            if args.save_weights:
                var_dict = {}
                for var in tf.trainable_variables():
                    var_dict[var.name] = session.run(var.name)

                np.savez_compressed(
                    os.path.join(args.out_dir, "weights_%i.npz" % train_iter),
                    **var_dict)

            print("Done saving weights")

        if train_iter > 0 and train_iter % args.save_chkpt == 0:
            save_path = saver.save(
                session, os.path.join(args.out_dir, "model_%i" % train_iter))
            print("Model checkpointed in file: %s" % save_path)

    session.close()
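Note: resuming in the example above relies on the checkpoint basename encoding
the iteration ("model_<iter>", as written by saver.save). A small helper
equivalent to the split('/') parsing used there:

import os

def iters_from_checkpoint(chkpt_path):
    # "out/model_300" -> 300; assumes the "model_<iter>" naming used above
    return int(os.path.basename(chkpt_path).split('_')[-1])

assert iters_from_checkpoint("out/model_300") == 300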
Example #5
# Imports as in Example #2; project helpers assumed in scope: get_session,
# BDCGAN, get_supervised_batches, get_test_batches, get_test_accuracy,
# get_gan_labels, print_images.
def b_dcgan(dataset, args):

    z_dim = args.z_dim
    x_dim = dataset.x_dim
    batch_size = args.batch_size
    dataset_size = dataset.dataset_size

    session = get_session()
    if args.random_seed is not None:
        tf.set_random_seed(args.random_seed)
    # the TF graph is built for a fixed batch_size, so every op takes the same batch size
    dcgan = BDCGAN(
        x_dim,
        z_dim,
        dataset_size,
        batch_size=batch_size,
        J=args.J,
        M=args.M,
        lr=args.lr,
        optimizer=args.optimizer,
        gen_observed=args.gen_observed,
        num_classes=dataset.num_classes if args.semi_supervised else 1)

    print("Starting session")
    session.run(tf.global_variables_initializer())

    print("Starting training loop")

    num_train_iter = args.train_iter

    if hasattr(dataset, "supervised_batches"):
        # datasets that don't fit in memory implement their own data feeder
        supervised_batches = dataset.supervised_batches(args.N, batch_size)
    else:
        supervised_batches = get_supervised_batches(
            dataset, args.N, batch_size, list(range(dataset.num_classes)))

    test_image_batches, test_label_batches = get_test_batches(
        dataset, batch_size)

    optimizer_dict = {
        "semi_d": dcgan.d_optim_semi_adam,
        "sup_d": dcgan.s_optim_adam,
        "adv_d": dcgan.d_optim_adam,
        "gen": dcgan.g_optims_adam
    }

    base_learning_rate = args.lr  # for now we use the same learning rate for Ds and Gs
    lr_decay_rate = args.lr_decay

    for train_iter in range(num_train_iter):

        if train_iter == 5000:
            print("Switching to user-specified optimizer")
            optimizer_dict = {
                "semi_d": dcgan.d_optim_semi,
                "sup_d": dcgan.s_optim,
                "adv_d": dcgan.d_optim,
                "gen": dcgan.g_optims
            }

        learning_rate = base_learning_rate * np.exp(-lr_decay_rate * min(
            1.0, (train_iter * batch_size) / float(dataset_size)))

        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
        image_batch, _ = dataset.next_batch(batch_size, class_id=None)

        if args.semi_supervised:

            labeled_image_batch, labels = next(supervised_batches)

            _, d_loss = session.run(
                [optimizer_dict["semi_d"], dcgan.d_loss_semi],
                feed_dict={
                    dcgan.labeled_inputs: labeled_image_batch,
                    dcgan.labels: get_gan_labels(labels),
                    dcgan.inputs: image_batch,
                    dcgan.z: batch_z,
                    dcgan.d_semi_learning_rate: learning_rate
                })

            _, s_loss = session.run([optimizer_dict["sup_d"], dcgan.s_loss],
                                    feed_dict={
                                        dcgan.inputs: labeled_image_batch,
                                        dcgan.lbls: labels
                                    })

        else:
            # regular GAN
            _, d_loss = session.run(
                [optimizer_dict["adv_d"], dcgan.d_loss],
                feed_dict={
                    dcgan.inputs: image_batch,
                    dcgan.z: batch_z,
                    dcgan.d_learning_rate: learning_rate
                })

        if args.wasserstein:
            session.run(dcgan.clip_d, feed_dict={})

        g_losses = []
        for gi in range(dcgan.num_gen):

            # compute g_sample loss
            batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
            for m in range(dcgan.num_mcmc):
                _, g_loss = session.run([
                    optimizer_dict["gen"][gi * dcgan.num_mcmc + m],
                    dcgan.generation["g_losses"][gi * dcgan.num_mcmc + m]
                ],
                                        feed_dict={
                                            dcgan.z: batch_z,
                                            dcgan.g_learning_rate:
                                            learning_rate
                                        })
                g_losses.append(g_loss)

        if train_iter > 0 and train_iter % args.n_save == 0:

            print("Iter %i" % train_iter)
            # collect samples
            if args.save_samples:  # saving samples
                all_sampled_imgs = []
                for gi in range(dcgan.num_gen):
                    _imgs, _ps = [], []
                    for _ in range(10):
                        sample_z = np.random.uniform(-1,
                                                     1,
                                                     size=(batch_size, z_dim))
                        sampled_imgs, sampled_probs = session.run([
                            dcgan.generation["gen_samplers"][gi *
                                                             dcgan.num_mcmc],
                            dcgan.generation["d_probs"][gi * dcgan.num_mcmc]
                        ],
                                                                  feed_dict={
                                                                      dcgan.z:
                                                                      sample_z
                                                                  })
                        _imgs.append(sampled_imgs)
                        _ps.append(sampled_probs)

                    sampled_imgs = np.concatenate(_imgs)
                    sampled_probs = np.concatenate(_ps)
                    all_sampled_imgs.append(
                        [sampled_imgs, sampled_probs[:, 1:].sum(1)])

            print("Disc loss = %.2f, Gen loss = %s" %
                  (d_loss, ", ".join(["%.2f" % gl for gl in g_losses])))
            if args.semi_supervised:
                # get test set performance on real labels only for both GAN-based classifier and standard one
                s_acc, ss_acc = get_test_accuracy(session, dcgan,
                                                  test_image_batches,
                                                  test_label_batches)

                print("Sup classification acc: %.2f" % (s_acc))
                print("Semi-sup classification acc: %.2f" % (ss_acc))

            print("saving results and samples")

            results = {
                "disc_loss": float(d_loss),
                "gen_losses": list(map(float, g_losses))
            }
            if args.semi_supervised:
                results["supervised_acc"] = float(s_acc)
                results["semi_supervised_acc"] = float(ss_acc)
                results["timestamp"] = time.time()

            with open(
                    os.path.join(args.out_dir, 'results_%i.json' % train_iter),
                    'w') as fp:
                json.dump(results, fp)

            if args.save_samples:
                for gi in range(dcgan.num_gen):
                    print_images(all_sampled_imgs[gi],
                                 "B_DCGAN_%i_%.2f" %
                                 (gi, g_losses[gi * dcgan.num_mcmc]),
                                 train_iter,
                                 directory=args.out_dir)

                print_images(image_batch,
                             "RAW",
                             train_iter,
                             directory=args.out_dir)

            if args.save_weights:
                var_dict = {}
                for var in tf.trainable_variables():
                    var_dict[var.name] = session.run(var.name)

                np.savez_compressed(
                    os.path.join(args.out_dir, "weights_%i.npz" % train_iter),
                    **var_dict)

            print("done")
Example #6
# Imports as in Example #2, plus the project's evaluation/plotting helpers
# (print_losses, plot_losses, evaluate_recon, evaluate_latent,
# plot_latent_encodings, evaluate_classification).
def train_dcgan(dataset, args, dcgan, sess):

    print("Starting sess")
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    print("Starting training loop")

    num_train_iter = args.train_iter

    optimizer_dict = dcgan.opt_adam_dict

    lr_decay_rate = args.lr_decay
    num_disc = args.J_d
    saver = tf.train.Saver()
    running_losses = {}

    for m in ["g", "e", "d_real", "d_fake", "recon"]:
        running_losses["%s_losses" % m] = np.empty(num_train_iter)
    running_losses["d_accuracy"] = np.empty(num_train_iter)
    d_update_threshold = args.d_update_threshold
    if args.d_update_decay_steps == "":
        d_update_decay_steps = []
    else:
        d_update_decay_steps = list(
            map(int, args.d_update_decay_steps.split(",")))
    base_lrs = {"g": args.gen_lr, "e": args.enc_lr, "d": args.disc_lr}
    lrs = dict(base_lrs)
    for train_iter in range(num_train_iter):
        # only the discriminator rate is decayed; g/e keep their base rates
        lrs["d"] = base_lrs["d"] * np.exp(-args.lr_decay * min(
            1.0, train_iter * args.batch_size / dataset.dataset_size))
        print('lrs', ', '.join("%s %.7f" % (k, lr) for k, lr in lrs.items()))
        if train_iter == 5000:
            print("Switching to user-specified optimizer")
            optimizer_dict = dcgan.opt_user_dict

        if (train_iter + 1) in d_update_decay_steps:
            d_update_threshold = max(d_update_threshold - args.d_update_decay,
                                     args.d_update_bound)
            print("d update threshold:", d_update_threshold)

        image_batch, _ = dataset.next_batch(args.batch_size, class_id=None)

        ### compute disc losses
        batch_z = np.random.uniform(
            -1, 1, [args.batch_size, args.z_dim, dcgan.num_gen])
        # alternative prior: np.random.normal(0, 1, [args.batch_size, args.z_dim, dcgan.num_gen])

        d_feed_dict = {
            dcgan.inputs: image_batch,
            dcgan.z: batch_z,
            dcgan.d_learning_rate: lrs["d"]
        }
        d_losses_reals, d_losses_fakes = sess.run(
            [dcgan.d_losses_reals, dcgan.d_losses_fakes],
            feed_dict=d_feed_dict)

        d_real_acc = sess.run(dcgan.d_acc_reals, feed_dict=d_feed_dict)
        d_fake_acc = sess.run(dcgan.d_acc_fakes, feed_dict=d_feed_dict)
        d_mean_acc = np.mean(np.concatenate((d_real_acc, d_fake_acc)))
        # run the D optimizer only when mean D accuracy is below threshold;
        # with disc_skip_update set, additionally only on even iterations
        # (equivalent to the original condition, which relied on False == 0)
        if ((not args.disc_skip_update or train_iter % 2 == 0)
                and d_mean_acc < d_update_threshold):
            sess.run(optimizer_dict["disc"], feed_dict=d_feed_dict)
            d_updated = True
        else:
            d_updated = False
        ### compute encoder losses
        e_feed_dict = {
            dcgan.inputs: image_batch,
            dcgan.e_learning_rate: lrs["e"]
        }

        e_losses = sess.run(dcgan.e_losses, feed_dict=e_feed_dict)
        if train_iter + 1 > args.e_optimize_iter:
            sess.run(optimizer_dict["enc"], feed_dict=e_feed_dict)

        ### compute generative losses
        batch_z = np.random.uniform(
            -1, 1, [args.batch_size, args.z_dim, dcgan.num_gen])

        g_feed_dict = {
            dcgan.z: batch_z,
            dcgan.inputs: image_batch,
            dcgan.g_learning_rate: lrs["g"]
        }
        g_losses = sess.run(dcgan.g_losses, feed_dict=g_feed_dict)
        sess.run(optimizer_dict["gen"], feed_dict=g_feed_dict)

        recon_losses = sess.run(dcgan.recon_losses,
                                feed_dict={dcgan.inputs: image_batch})

        print("Iter %i" % train_iter)
        print_losses("Disc reals losses", d_losses_reals)
        print_losses("Disc fakes losses", d_losses_fakes)
        print_losses("Enc losses", e_losses)
        print_losses("Gen losses", g_losses)
        print_losses("Recon losses", recon_losses)
        print_losses("Disc acc real", d_real_acc)
        print_losses("Disc acc fake", d_fake_acc)
        print("D mean acc", d_mean_acc)
        print("D updated", d_updated)
        running_losses["g_losses"][train_iter] = np.mean(g_losses)
        running_losses["e_losses"][train_iter] = np.mean(e_losses)
        running_losses["d_real_losses"][train_iter] = np.mean(d_losses_reals)
        running_losses["d_fake_losses"][train_iter] = np.mean(d_losses_fakes)
        running_losses["recon_losses"][train_iter] = np.mean(recon_losses)
        running_losses["d_accuracy"][train_iter] = d_mean_acc
        if train_iter + 1 == num_train_iter or \
           (train_iter > 0 and train_iter % args.n_save == 0):
            """ print_losses("Raw Disc", raw_d_losses)
            print_losses("Raw Enc", raw_e_losses)
            print_losses("Raw Gen", raw_g_losses)
            """
            print("saving results and samples")

            results = {
                "disc_losses_reals": list(map(float, d_losses_reals)),
                "disc_losses_fakes": list(map(float, d_losses_fakes)),
                "enc_losses": list(map(float, e_losses)),
                "gen_losses": list(map(float, g_losses)),
                "timestamp": time.time()
            }
            res_path = os.path.join(args.out_dir,
                                    "results_%i.json" % train_iter)
            with open(res_path, 'w') as fp:
                json.dump(results, fp)

            if args.save_samples:
                for zi, gen_sampler in enumerate(dcgan.gen_samplers):
                    sampled_imgs = []
                    for _ in range(10):
                        z_sampler = np.random.uniform(-1,
                                                      1,
                                                      size=(args.batch_size,
                                                            args.z_dim))
                        img = sess.run(gen_sampler,
                                       feed_dict={dcgan.z_sampler: z_sampler})
                        sampled_imgs.append(img)
                    sampled_imgs = np.concatenate(sampled_imgs)
                    print_images(sampled_imgs,
                                 "B_DCGAN_g%i" % zi,
                                 train_iter,
                                 directory=args.out_dir)

                evaluate_recon(sess, dcgan, args, dataset, train_iter)
            if args.evaluate_latent:
                all_latent_encodings = evaluate_latent(sess, dcgan, args,
                                                       dataset)
                for ei, latent_encodings in enumerate(all_latent_encodings):
                    for r in range(1):
                        filename = "latent_encodings_e%d_r%d_%d.png" \
                                   % (ei, r, train_iter)
                        plot_latent_encodings(latent_encodings,
                                              savename=os.path.join(
                                                  args.out_dir, filename))

    save_path = saver.save(sess, os.path.join(args.out_dir, "model.ckpt"))
    print("Model saved to %s" % save_path)

    losses_file = os.path.join(args.out_dir, "running_losses.npz")
    np.savez(losses_file, **running_losses)
    print("Saved running losses to", losses_file)

    plot_losses(running_losses,
                savename=os.path.join(args.out_dir, "losses_plot.png"))

    results = evaluate_classification(sess, dcgan, args, dataset)
    with open(os.path.join(args.out_dir, "classification.json"), "w") as fp:
        json.dump(results, fp)
    print("done")