def main(config):
    num_step = config.num_step
    data_loader = MNISTLoader(config.data_dir)
    names = []
    fobjs = []
    try:
        # Dump `num_step` batches to temp .npy files so the scorer can
        # stream them instead of holding everything in memory.
        for _ in range(num_step):
            fd, name = tempfile.mkstemp(suffix=".npy")
            fobj = os.fdopen(fd, "wb+")
            names.append(name)
            fobjs.append(fobj)
            image_arr = data_loader.next_batch(config.batch_size)[0]
            np.save(fobj, image_arr, allow_pickle=False)
            fobj.close()
        mean_score, std_score = get_mnist_score(
            images_iter(names), config.model_path, batch_size=100, split=10)
        print("mean = %.4f, std = %.4f." % (mean_score, std_score))
        if config.save_path is not None:
            with open(config.save_path, "wb") as f:
                pickle.dump(
                    dict(batch_size=config.batch_size,
                         scores=dict(mean=mean_score, std=std_score)),
                    f)
    finally:
        # Best-effort cleanup: close any still-open file objects (close() on an
        # already-closed file is a no-op), then remove the temp files.
        for fobj in fobjs:
            fobj.close()
        for name in names:
            os.unlink(name)
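# `images_iter` is defined elsewhere in this repo; a minimal sketch of the
# behavior `get_mnist_score` relies on here, assuming each temp file holds one
# np.save'd batch (the body below is an illustration, not the repo's actual code):
import numpy as np

def images_iter(names):
    """Lazily yield the image batches back from the .npy temp files."""
    for name in names:
        yield np.load(name, allow_pickle=False)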
for params in GRID:
    # Apply this grid point's hyper-parameters onto the shared config.
    for key, value in params.items():
        setattr(config, key, value)
    name = NAME_STYLE % params
    if save_dir is not None:
        config.save_dir = os.path.join(save_dir, name + "_models")
        os.makedirs(config.save_dir, exist_ok=True)
    if log_dir is not None:
        config.log_path = os.path.join(log_dir, name + ".log")
    print("config: %r" % config)
    print("resetting environment...")
    tf.reset_default_graph()  # start each grid point from a fresh graph
    train_data_loader = MNISTLoader(config.data_dir, include_test=False,
                                    first=config.public_num,
                                    seed=config.public_seed)
    eval_data_loader = MNISTLoader(config.data_dir, include_test=True,
                                   include_train=False)
    run_task(config, train_data_loader, eval_data_loader,
             generator_forward, code_classifier_forward,
             image_classifier_forward,
             image_classifier_optimizer=tf.train.AdamOptimizer(),
             code_classifier_optimizer=tf.train.AdamOptimizer(),
             model_path=config.model_path)
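# GRID and NAME_STYLE are defined elsewhere in the script; a minimal sketch of
# the shapes the loop above requires (the concrete keys and format string below
# are illustrative assumptions, not the repo's actual values):
GRID = [dict(public_num=n, public_seed=s)
        for n in (1000, 5000)
        for s in (1024, 2048)]
NAME_STYLE = "public%(public_num)d_seed%(public_seed)d"  # %-formatted with each params dict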
expanded_labels = expanded_labels[indices]
print(expanded_images.shape)
print(expanded_labels.shape)

if config.sample_ratio is not None:
    # Split the augmented pool: the front part trains the GAN, the held-out
    # tail is kept for sampling.
    gan_data_loader = MNISTLoader_aug(
        expanded_images, expanded_labels,
        first=int(party_data_size * 100 * (1 - config.sample_ratio)),
        seed=config.sample_seed)
    sample_data_loader = MNISTLoader_aug(
        expanded_images, expanded_labels,
        last=int(party_data_size * 100 * config.sample_ratio),
        seed=config.sample_seed)
else:
    gan_data_loader = MNISTLoader(config.data_dir,
                                  include_train=not config.exclude_train,
                                  include_test=not config.exclude_test)

if config.enable_accounting:
    accountant = GaussianMomentsAccountant(gan_data_loader.n, config.moment)
    if config.log_path:
        open(config.log_path, "w").close()  # truncate any previous log
else:
    accountant = None

if config.adaptive_rate:
    # Placeholder so the discriminator learning rate can be fed per step.
    lr = tf.placeholder(tf.float32, shape=())
else:
    lr = config.learning_rate

gen_optimizer = tf.train.AdamOptimizer(config.gen_learning_rate, beta1=0.5, beta2=0.9)
disc_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
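# When config.adaptive_rate is set, `lr` above is a placeholder, so every
# discriminator step must feed a concrete value for it. A minimal sketch of
# such a step; the decay schedule here is an illustrative assumption, not the
# repo's actual policy:
def current_lr(config, step):
    """Example schedule: halve the base rate every 10000 steps."""
    return config.learning_rate * (0.5 ** (step // 10000))

# ... then, inside the training loop, something like:
# sess.run(disc_train_ops, feed_dict={real_inputs: bx, lr: current_lr(config, step)})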
if config.save_path is not None:
    fobj = open(config.save_path, "w")
else:
    fobj = None

for params in chain(GRID1, GRID2, GRID3):
    for key, value in params.items():
        setattr(config, key, value)
    name = NAME_STYLE % params
    print("config: %r" % config)
    print("resetting environment...")
    tf.reset_default_graph()
    eval_data_loader = MNISTLoader(config.data_dir, include_test=True,
                                   include_train=False)
    mean_accuracy = run_task_eval(
        config, eval_data_loader, image_classifier_forward,
        model_dir=os.path.join(model_dir, name + "_models"))
    # Record one "name: accuracy" line per grid point.
    if fobj is not None:
        fobj.write("%s: %.4f\n" % (name, mean_accuracy))
    else:
        print("%s: %.4f" % (name, mean_accuracy))

if fobj is not None:
    fobj.close()
parser.add_argument("--exclude-train", dest="exclude_train", action="store_true") parser.add_argument("--exclude-test", dest="exclude_test", action="store_true") config = parser.parse_args() config.dataset = "mnist" np.random.seed() if config.enable_accounting: config.sigma = np.sqrt(2.0 * np.log(1.25 / config.delta)) / config.epsilon print("Now with new sigma: %.4f" % config.sigma) if config.sample_ratio is not None: kwargs = {} gan_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train, include_test=not config.exclude_test, first=int(50000 * (1 - config.sample_ratio)), seed=config.sample_seed ) sample_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train, include_test=not config.exclude_test, last=int(50000 * config.sample_ratio), seed=config.sample_seed ) else: gan_data_loader = MNISTLoader(config.data_dir, include_train=not config.exclude_train, include_test=not config.exclude_test) if config.enable_accounting: accountant = GaussianMomentsAccountant(gan_data_loader.n, config.moment) if config.log_path: open(config.log_path, "w").close()
parser.add_argument("--data-dir", default="./data/mnist_data", dest="data_dir") parser.add_argument("--learning-rate", default=4e-4, type=float, dest="learning_rate") parser.add_argument("--gen-learning-rate", default=4e-4, type=float, dest="gen_learning_rate") config = parser.parse_args() np.random.seed() data_loader = MNISTLoader(config.data_dir) gen_optimizer = tf.train.AdamOptimizer(config.gen_learning_rate, beta1=0.5, beta2=0.9) disc_optimizer = tf.train.AdamOptimizer(config.learning_rate, beta1=0.5, beta2=0.9) train(config, data_loader, mnist.generator_forward, mnist.discriminator_forward, gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)
def train(config):
    data_loader = MNISTLoader(config.data_dir)
    real_labels, fake_labels, real_inputs, fake_inputs = build_graph(config)
    global_step = tf.Variable(0, trainable=False)
    gen_train_ops, disc_train_ops, gen_loss, disc_loss = create_train_ops(
        config, global_step, real_labels, fake_labels, real_inputs, fake_inputs)
    saver = tf.train.Saver(max_to_keep=20)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    num_steps = data_loader.num_steps(config.batch_size)

    if config.save_dir:
        os.makedirs(config.save_dir, exist_ok=True)
    if config.image_dir:
        os.makedirs(config.image_dir, exist_ok=True)

    total_step = 0
    for epoch in range(config.epoch):
        bar = trange(num_steps, leave=False)
        for _ in bar:
            disc_loss_value, gen_loss_value = 0.0, 0.0
            tflearn.is_training(True, sess)
            if total_step > 0:
                # One generator update per iteration (skipped on the very first
                # step so the discriminator sees data first).
                gen_loss_value, _ = sess.run(
                    [gen_loss, gen_train_ops],
                    feed_dict={fake_labels: sample_labels(config.batch_size)})
            # WGAN-style schedule: five discriminator updates per iteration,
            # feeding the real batch's labels to both label placeholders.
            for _ in range(5):
                bx, by = data_loader.next_batch(config.batch_size)
                disc_loss_value, _ = sess.run(
                    [disc_loss, disc_train_ops],
                    feed_dict={real_labels: by, fake_labels: by, real_inputs: bx})
            bar.set_description("epoch %d, gen loss %.4f, disc loss %.4f"
                                % (epoch, gen_loss_value, disc_loss_value))
            tflearn.is_training(False, sess)

            if total_step % 20 == 0 and config.image_dir:
                # Periodically dump a grid of generated and real samples.
                sampled_labels = regular_labels()
                generated = sess.run(fake_inputs, feed_dict={fake_labels: sampled_labels})
                generate_images(generated, data_loader.mode(),
                                os.path.join(config.image_dir, "gen_step_%d.jpg" % total_step))
                generate_images(data_loader.next_batch(config.batch_size)[0],
                                data_loader.mode(),
                                os.path.join(config.image_dir, "real_step_%d.jpg" % total_step))
            total_step += 1
        bar.close()

    if config.save_dir is not None:
        saver.save(sess, os.path.join(config.save_dir, "model"),
                   global_step=global_step, write_meta_graph=False)
    sess.close()
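# `sample_labels` and `regular_labels` are helpers from elsewhere in the repo;
# minimal sketches of what the loop above needs from them, assuming one-hot
# labels over the 10 MNIST classes (an assumption, not the verified interface):
import numpy as np

def sample_labels(batch_size, num_classes=10):
    """Uniformly random one-hot labels for the generator updates."""
    labels = np.zeros((batch_size, num_classes), dtype=np.float32)
    labels[np.arange(batch_size), np.random.randint(num_classes, size=batch_size)] = 1.0
    return labels

def regular_labels(num_classes=10, per_class=10):
    """A fixed, evenly spread label grid for the periodic image dumps."""
    n = num_classes * per_class
    labels = np.zeros((n, num_classes), dtype=np.float32)
    labels[np.arange(n), np.repeat(np.arange(num_classes), per_class)] = 1.0
    return labels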