Пример #1
0
def main(config):
    """Score batches of MNIST images and optionally persist the result.

    ``config.num_step`` batches of ``config.batch_size`` images are drawn
    from an ``MNISTLoader`` built on ``config.data_dir``.  Each batch is
    written to its own temporary ``.npy`` file, and the list of file names is
    handed (through ``images_iter``) to ``get_mnist_score`` together with
    ``config.model_path``.  The resulting mean and standard deviation are
    printed and, when ``config.save_path`` is not ``None``, pickled to that
    path along with the batch size.

    All temporary files are removed before returning, whether or not an
    error occurred.
    """
    num_step = config.num_step
    data_loader = MNISTLoader(config.data_dir)

    names = []
    try:
        for _ in xrange(num_step):
            fd, name = tempfile.mkstemp(suffix=".npy")
            # Register the path immediately so the cleanup below removes it
            # even if writing the batch fails.
            names.append(name)
            # The context manager closes the file before it is unlinked,
            # including when next_batch() or np.save() raises.
            with os.fdopen(fd, "wb+") as fobj:
                image_arr = data_loader.next_batch(config.batch_size)[0]
                np.save(fobj, image_arr, allow_pickle=False)

        mean_score, std_score = get_mnist_score(images_iter(names),
                                                config.model_path,
                                                batch_size=100,
                                                split=10)

        print("mean = %.4f, std = %.4f." % (mean_score, std_score))

        if config.save_path is not None:
            with open(config.save_path, "wb") as f:
                cPickle.dump(
                    dict(batch_size=config.batch_size,
                         scores=dict(mean=mean_score, std=std_score)), f)
    finally:
        # Best-effort cleanup: one file that cannot be removed must not stop
        # the remaining temporary files from being deleted, nor mask an
        # exception that is already propagating.
        for name in names:
            try:
                os.unlink(name)
            except OSError as exc:
                print("warning: could not remove temporary file %s: %s" %
                      (name, exc))
Пример #2
0
    # Run one training task per parameter set in GRID.  `config`, `save_dir`
    # and `log_dir` are defined above this excerpt.
    for params in GRID:
        # Overwrite the matching attributes on the shared config object.
        for key, value in params.items():
            setattr(config, key, value)
        # Run name derived from the parameter values; used for output paths.
        name = NAME_STYLE % params
        if save_dir is not None:
            config.save_dir = os.path.join(save_dir, name + "_models")
            os.makedirs(config.save_dir, exist_ok=True)
        if log_dir is not None:
            config.log_path = os.path.join(log_dir, name + ".log")

        print("config: %r" % config)
        print("resetting environment...")
        # Start every run from an empty default TensorFlow graph.
        tf.reset_default_graph()

        # NOTE(review): presumably `first`/`seed` restrict training to a
        # seeded subset of `config.public_num` samples -- confirm against
        # MNISTLoader.
        train_data_loader = MNISTLoader(config.data_dir,
                                        include_test=False,
                                        first=config.public_num,
                                        seed=config.public_seed)
        # Evaluation loader is built with the test split only.
        eval_data_loader = MNISTLoader(config.data_dir,
                                       include_test=True,
                                       include_train=False)

        # Fresh Adam optimizers (default hyper-parameters) are created for
        # each run because the graph was reset above.
        run_task(config,
                 train_data_loader,
                 eval_data_loader,
                 generator_forward,
                 code_classifier_forward,
                 image_classifier_forward,
                 image_classifier_optimizer=tf.train.AdamOptimizer(),
                 code_classifier_optimizer=tf.train.AdamOptimizer(),
                 model_path=config.model_path)
Пример #3
0
    # Index the labels with `indices` (both defined above this excerpt).
    expanded_labels = expanded_labels[indices]
    print(expanded_images.shape)
    print(expanded_labels.shape)

    # Pick the data source used for GAN training.
    if config.sample_ratio is not None:
        # NOTE(review): `kwargs` is not used anywhere in the visible code.
        kwargs = {}
        # Both loaders receive the same seed.
        # NOTE(review): presumably `first`/`last` select complementary parts
        # of the same shuffled data -- confirm against MNISTLoader_aug.
        gan_data_loader = MNISTLoader_aug(expanded_images, expanded_labels, 
                                  first=int(party_data_size *100 * (1 - config.sample_ratio)),
                                  seed=config.sample_seed
                                )
        sample_data_loader = MNISTLoader_aug(expanded_images, expanded_labels, 
                                  last=int(party_data_size *100 * config.sample_ratio),
                                  seed=config.sample_seed
                                )
    else:
        # No sampling requested: load MNIST from disk.  `sample_data_loader`
        # is not defined on this path.
        gan_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train,
                                  include_test=not config.exclude_test)

    if config.enable_accounting:
        accountant = GaussianMomentsAccountant(gan_data_loader.n, config.moment)
        if config.log_path:
            # Create the log file, or truncate an existing one.
            open(config.log_path, "w").close()
    else:
        accountant = None

    # With an adaptive rate the discriminator learning rate is a scalar
    # float32 placeholder (presumably fed at run time -- TODO confirm);
    # otherwise it is the fixed value from the config.
    if config.adaptive_rate:
        lr = tf.placeholder(tf.float32, shape=())
    else:
        lr = config.learning_rate

    # The generator always uses its own fixed learning rate.
    gen_optimizer = tf.train.AdamOptimizer(config.gen_learning_rate, beta1=0.5, beta2=0.9)
    disc_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
Пример #4
0
    # Results are written to config.save_path when it is given, otherwise
    # they are printed to stdout.
    if config.save_path is not None:
        fobj = open(config.save_path, "w")
    else:
        fobj = None

    # try/finally guarantees the results file is closed (and whatever was
    # written so far is flushed) even if one of the evaluation runs raises.
    try:
        for params in chain(GRID1, GRID2, GRID3):
            # Overwrite the matching attributes on the shared config object.
            for key, value in params.items():
                setattr(config, key, value)
            # Run name derived from the parameter values; it selects the
            # "<name>_models" directory to evaluate.
            name = NAME_STYLE % params
            print("config: %r" % config)
            print("resetting environment...")
            # Start every evaluation from an empty default TensorFlow graph.
            tf.reset_default_graph()

            eval_data_loader = MNISTLoader(config.data_dir,
                                           include_test=True,
                                           include_train=False)

            mean_accuracy = run_task_eval(config,
                                          eval_data_loader,
                                          image_classifier_forward,
                                          model_dir=os.path.join(
                                              model_dir, name + "_models"))
            if fobj is not None:
                fobj.write("%s: %.4f\n" % (name, mean_accuracy))
            else:
                print("%s: %.4f\n" % (name, mean_accuracy))
    finally:
        if fobj is not None:
            fobj.close()
Пример #5
0
    # Flags that drop the train / test split from the data loaders below.
    parser.add_argument("--exclude-train", dest="exclude_train", action="store_true")
    parser.add_argument("--exclude-test", dest="exclude_test", action="store_true")

    config = parser.parse_args()
    # The dataset name is fixed here, not taken from the command line.
    config.dataset = "mnist"

    # Called without an argument, so no fixed seed is applied.
    np.random.seed()
    if config.enable_accounting:
        # sigma = sqrt(2 * ln(1.25 / delta)) / epsilon, overriding any sigma
        # set earlier.
        # NOTE(review): this matches the usual Gaussian-mechanism noise
        # calibration -- confirm that is the intent.
        config.sigma = np.sqrt(2.0 * np.log(1.25 / config.delta)) / config.epsilon
        print("Now with new sigma: %.4f" % config.sigma)

    # Pick the data source used for GAN training.
    if config.sample_ratio is not None:
        # NOTE(review): `kwargs` is not used anywhere in the visible code.
        kwargs = {}
        # Both loaders receive the same seed; the 50000 is hard-coded.
        # NOTE(review): presumably `first`/`last` select complementary parts
        # of the same shuffled data -- confirm against MNISTLoader.
        gan_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train,
                                  include_test=not config.exclude_test,
                                  first=int(50000 * (1 - config.sample_ratio)),
                                  seed=config.sample_seed
                                )
        sample_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train,
                                  include_test=not config.exclude_test,
                                  last=int(50000 * config.sample_ratio),
                                  seed=config.sample_seed
                                )
    else:
        # No sampling requested; `sample_data_loader` is not defined on this
        # path.
        gan_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train,
                                  include_test=not config.exclude_test)

    if config.enable_accounting:
        accountant = GaussianMomentsAccountant(gan_data_loader.n, config.moment)
        if config.log_path:
            # Create the log file, or truncate an existing one.
            open(config.log_path, "w").close()
Пример #6
0
    # Directory passed to MNISTLoader below.
    parser.add_argument("--data-dir",
                        default="./data/mnist_data",
                        dest="data_dir")
    # Learning rate used for the discriminator optimizer.
    parser.add_argument("--learning-rate",
                        default=4e-4,
                        type=float,
                        dest="learning_rate")
    # Learning rate used for the generator optimizer.
    parser.add_argument("--gen-learning-rate",
                        default=4e-4,
                        type=float,
                        dest="gen_learning_rate")

    config = parser.parse_args()

    # Called without an argument, so no fixed seed is applied.
    np.random.seed()

    data_loader = MNISTLoader(config.data_dir)

    # Generator and discriminator share the Adam betas but may use different
    # learning rates.
    gen_optimizer = tf.train.AdamOptimizer(config.gen_learning_rate,
                                           beta1=0.5,
                                           beta2=0.9)
    disc_optimizer = tf.train.AdamOptimizer(config.learning_rate,
                                            beta1=0.5,
                                            beta2=0.9)

    # Hand everything to the training routine with the MNIST network
    # definitions from the `mnist` module.
    train(config,
          data_loader,
          mnist.generator_forward,
          mnist.discriminator_forward,
          gen_optimizer=gen_optimizer,
          disc_optimizer=disc_optimizer)
Пример #7
0
def train(config):
    """Train the generator/discriminator pair on MNIST.

    Reads from ``config``: ``data_dir`` (MNIST location), ``batch_size``,
    ``epoch`` (number of epochs), ``save_dir`` (checkpoint directory) and
    ``image_dir`` (directory for sample images).

    Every step runs one generator update (skipped on the very first step)
    followed by five discriminator updates.  Every 20 steps, when
    ``config.image_dir`` is set, a grid of generated images and a grid of
    real images are written there.  A checkpoint is saved after each epoch
    when ``config.save_dir`` is not ``None``.

    The TensorFlow session is closed on exit, including when training is
    interrupted by an exception.
    """
    data_loader = MNISTLoader(config.data_dir)
    real_labels, fake_labels, real_inputs, fake_inputs = build_graph(config)
    global_step = tf.Variable(0, False)
    gen_train_ops, disc_train_ops, gen_loss, disc_loss = create_train_ops(
        config, global_step, real_labels, fake_labels, real_inputs,
        fake_inputs)
    saver = tf.train.Saver(max_to_keep=20)
    sess = tf.Session()
    # try/finally releases the session's resources even if a training step,
    # image dump or checkpoint save raises.
    try:
        sess.run(tf.global_variables_initializer())
        num_steps = data_loader.num_steps(config.batch_size)

        if config.save_dir:
            os.makedirs(config.save_dir, exist_ok=True)
        if config.image_dir:
            os.makedirs(config.image_dir, exist_ok=True)

        total_step = 0
        for epoch in xrange(config.epoch):
            bar = trange(num_steps, leave=False)
            for _ in bar:
                disc_loss_value, gen_loss_value = 0.0, 0.0
                tflearn.is_training(True, sess)
                if total_step == 0:
                    # NOTE(review): the fetch list is empty and a feed only
                    # overrides the value for this one call, so this looks
                    # like a no-op -- confirm the intent.
                    sess.run([], feed_dict={global_step: 1})
                else:
                    gen_loss_value, _ = sess.run(
                        [gen_loss, gen_train_ops],
                        feed_dict={
                            fake_labels: sample_labels(config.batch_size)
                        })
                # Five discriminator updates per generator update.
                for i in xrange(5):
                    bx, by = data_loader.next_batch(config.batch_size)
                    # The real batch's labels are fed for both the real and
                    # the fake label inputs.
                    disc_loss_value, _ = sess.run([disc_loss, disc_train_ops],
                                                  feed_dict={
                                                      real_labels: by,
                                                      fake_labels: by,
                                                      real_inputs: bx
                                                  })
                bar.set_description("epoch %d, gen loss %.4f, disc loss %.4f" %
                                    (epoch, gen_loss_value, disc_loss_value))
                tflearn.is_training(False, sess)
                if total_step % 20 == 0 and config.image_dir:
                    sampled_labels = regular_labels()
                    generated = sess.run(
                        fake_inputs, feed_dict={fake_labels: sampled_labels})
                    generate_images(
                        generated, data_loader.mode(),
                        os.path.join(config.image_dir,
                                     "gen_step_%d.jpg" % total_step))
                    generate_images(
                        data_loader.next_batch(config.batch_size)[0],
                        data_loader.mode(),
                        os.path.join(config.image_dir,
                                     "real_step_%d.jpg" % total_step))

                total_step += 1
            bar.close()
            # NOTE(review): this checks `is not None` while the directory is
            # only created when `save_dir` is truthy; an empty string would
            # save into the current directory.  Kept as in the original.
            if config.save_dir is not None:
                saver.save(sess,
                           os.path.join(config.save_dir, "model"),
                           global_step=global_step,
                           write_meta_graph=False)
    finally:
        sess.close()