def evaluate_classifier(train_dir, cfg):
    """Evaluate a trained ResNet classifier checkpoint on the evaluation split.

    Restores the latest checkpoint from ``train_dir``, builds the network in
    EVAL mode, streams ``cfg.evaluation.test_set_size`` examples through the
    streaming accuracy metric, and returns the final accuracy.

    Args:
        train_dir: directory containing the checkpoint to restore.
        cfg: config object providing ``resnet_size``, ``network_fn``,
            ``dataset.*`` and ``evaluation.*`` sub-configs.

    Returns:
        The final streaming accuracy (float) over the evaluation split.
    """
    log.info("Loading classifier network from the following config: %s", str(cfg))
    model_params = {
        'resnet_size': cfg.resnet_size,
        'data_format': 'channels_first',
        'batch_size': cfg.evaluation.batch_size,
        'num_classes': cfg.dataset.num_classes,
    }
    # BUG FIX: batch the dataset with the *evaluation* batch size.  The
    # original passed cfg.training.batch_size here, which contradicts both the
    # model built above (batch_size = cfg.evaluation.batch_size) and the
    # num_test_batches computation below whenever the two sizes differ.
    images, labels, iter_fn = matcher.load_dataset(
        cfg.evaluation.split, cfg.evaluation.batch_size, cfg.dataset.name,
        cfg.dataset.image_size, augmentation=False, shuffle=False,
        normalize=True, onehot=True)
    with tf.variable_scope("resnet", reuse=None):
        model = cfg.network_fn(images, labels, tf.estimator.ModeKeys.EVAL,
                               model_params)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    local_init_op = tf.local_variables_initializer()
    # Freeze the graph so that no op is accidentally added during evaluation.
    tf.get_default_graph().finalize()
    log.info("Creating the session...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)
        sess.run(local_init_op)  # reset counters in metrics
        num_test_batches = (cfg.evaluation.test_set_size
                            // cfg.evaluation.batch_size)
        # The accuracy metric is streaming: each run() folds one more batch
        # into the running value, so the value after the last batch is the
        # accuracy over the whole split.
        for _ in range(num_test_batches):
            eval_acc = sess.run(model.eval_metric_ops)['accuracy'][1]
        log.info("Final evaluation accuracy on split %s: %.4f",
                 cfg.evaluation.split, eval_acc)
        return eval_acc
def inception_score(train_dir):
    """Restore a checkpoint and report Inception score and FID for a GAN run.

    Generated samples are read from a pre-built dataset split (the default
    "gan_100_<run_name>" split, or the one named by ``args.inception_file``)
    and compared against cached real-image Inception activations.
    """
    # Choose which split of generated images to score.
    eval_split = "gan_100_%s" % args.run_name
    if args.inception_file != "":
        eval_split = args.inception_file
    samples, _, feed_iterator = matcher.load_dataset(
        eval_split, 100, args.dataset, args.image_size,
        augmentation=False, shuffle=True, classes=None, normalize=False)
    scorer = Inception(samples)
    restore_op, restore_feed = utils.restore_ckpt(train_dir, log)
    # Freeze the graph before the session starts so nothing is added later.
    tf.get_default_graph().finalize()
    reference_activations = matcher.load_inception_activations(
        args.dataset, args.image_size)
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(restore_op, feed_dict=restore_feed)
        feed_iterator(sess)
        # compute_inception_score_and_fid returns ((IS mean, IS std),
        # (FID mean, FID std)); unpack both pairs directly.
        (is_mean, is_std), (fid_mean, fid_std) = \
            scorer.compute_inception_score_and_fid(
                reference_activations, sess, splits=args.inception_splits)
        log.info("Final Inception score over 50K images: %f +- %f"
                 % (is_mean, is_std))
        log.info("Final Frechet Inception distance over 50K images of %s: %f +- %f"
                 % (eval_split, fid_mean, fid_std))
def train_classifier(train_dir, cfg):
    """Train a ResNet classifier and periodically evaluate it.

    Builds a TRAIN-mode model on the training split and an EVAL-mode model
    (sharing variables via the reused "resnet" scope) on the evaluation split,
    restores any checkpoint found in ``train_dir``, runs the training loop
    with periodic logging / summaries / checkpoints / evaluations, and
    finishes with a full evaluation pass.

    Args:
        train_dir: directory for checkpoints and TensorBoard summaries.
        cfg: config object providing ``resnet_size``, ``network_fn``,
            ``dataset.*``, ``training.*`` and ``evaluation.*`` sub-configs.

    Returns:
        The final streaming accuracy (float) over the evaluation split.
    """
    log.info("Training classifier network from the following config: %s", str(cfg))
    # The evaluation loop below counts whole batches, so the test set must
    # divide evenly into evaluation batches.
    assert cfg.evaluation.test_set_size % cfg.evaluation.batch_size == 0
    num_test_batches = cfg.evaluation.test_set_size // cfg.evaluation.batch_size
    model_params = {
        'resnet_size': cfg.resnet_size,
        'data_format': 'channels_first',
        'batch_size': cfg.training.batch_size,
        'num_classes': cfg.dataset.num_classes,
    }
    log.info("Creating the graph...")
    # Training pipeline: augmented + shuffled.
    images, labels, iter_fn = matcher.load_dataset(
        cfg.training.split, cfg.training.batch_size, cfg.dataset.name,
        cfg.dataset.image_size, augmentation=True, shuffle=True,
        normalize=True, onehot=True)
    # Evaluation pipeline: no augmentation, no dequantization.
    val_images, val_labels, val_iter_fn = matcher.load_dataset(
        cfg.evaluation.split, cfg.evaluation.batch_size, cfg.dataset.name,
        cfg.dataset.image_size, augmentation=False, shuffle=True,
        normalize=True, onehot=True, dequantize=False)
    # Two graphs over the same variables: reuse=True makes the EVAL model
    # share weights with the TRAIN model.
    with tf.variable_scope("resnet", reuse=None):
        model = cfg.network_fn(images, labels,
                               tf.estimator.ModeKeys.TRAIN, model_params)
    with tf.variable_scope("resnet", reuse=True):
        val_model = cfg.network_fn(val_images, val_labels,
                                   tf.estimator.ModeKeys.EVAL, model_params)
    # local_init_op alone resets the streaming-metric counters; clean_init_op
    # additionally initializes all model variables (fresh start).
    local_init_op = tf.local_variables_initializer()
    clean_init_op = tf.group(tf.global_variables_initializer(), local_init_op)
    global_step = tf.train.get_or_create_global_step()
    # No-op if train_dir has no checkpoint; otherwise resumes from it.
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    saver = tf.train.Saver(max_to_keep=1)
    summary_op = tf.summary.merge_all()
    # Freeze the graph: any later attempt to add ops raises immediately.
    tf.get_default_graph().finalize()
    log.info("Creating the session...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        summary_writer = tf.summary.FileWriter(train_dir)
        # Fresh init first, then (possibly) overwrite from checkpoint.
        sess.run(clean_init_op)
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)
        val_iter_fn(sess)
        # Resuming: global_step was restored from the checkpoint, if any.
        starting_step = sess.run(global_step)
        starting_time = time.time()
        log.info("Starting training from step %i..." % starting_step)
        for step in range(starting_step, cfg.training.max_iterations + 1):
            start_time = time.time()
            try:
                _, train_loss = sess.run([model.train_op, model.loss])
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # Input pipeline exhausted or session shut down: stop cleanly.
                break
            except KeyboardInterrupt:
                log.info("Killed by ^C")
                break
            if step % cfg.training.print_step == 0:
                duration = float(time.time() - start_time)
                examples_per_sec = cfg.training.batch_size / duration
                # Average over the whole run so the ETA smooths out jitter.
                avg_speed = (time.time() - starting_time) / (step - starting_step + 1)
                time_to_finish = datetime.timedelta(
                    seconds=(avg_speed * (cfg.training.max_iterations - step)))
                end_date = datetime.datetime.now() + time_to_finish
                format_str = (
                    'step %d, %.3f (%.1f examples/sec; %.3f sec/batch)')
                log.info(format_str % (step, train_loss,
                                       examples_per_sec, duration))
                log.info(
                    "%i iterations left expected to finish after %s, thus at %s (avg speed: %.3f sec/batch)"
                    % (cfg.training.max_iterations - step, str(time_to_finish),
                       end_date.strftime("%Y-%m-%d %H:%M:%S"), avg_speed))
            if step % cfg.training.summary_step == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
            if step % cfg.training.ckpt_step == 0 and step > 0:
                summary_writer.flush()
                log.debug("Saving checkpoint...")
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step,
                           write_meta_graph=False)
            if step % cfg.training.eval_step == 0:
                sess.run(local_init_op)  # reset counters in metrics
                # Streaming metric: the value after the last batch covers the
                # whole test set.
                for _ in range(num_test_batches):
                    eval_acc = sess.run(
                        val_model.eval_metric_ops)['accuracy'][1]
                log.info("Intermediate evaluation accuracy: %.4f" % eval_acc)
                # Hand-built summary so the eval accuracy lands on the same
                # TensorBoard timeline as the training summaries.
                summary = tf.Summary()
                summary.value.add(tag='accuracy/test', simple_value=eval_acc)
                summary_writer.add_summary(summary, step)
        summary_writer.close()
        # Final full evaluation pass after training finishes.
        sess.run(local_init_op)  # reset counters in metrics
        for i in range(num_test_batches):
            eval_metrics = sess.run(val_model.eval_metric_ops)
            eval_acc = eval_metrics['accuracy'][1]
            # NOTE(review): 'top5_accuracy' is read unconditionally here but
            # only logged for imagenet below — presumably the metric is always
            # present in eval_metric_ops; verify against cfg.network_fn.
            eval_acc_top5 = eval_metrics['top5_accuracy'][1]
        log.info("Final evaluation accuracy on split %s: %.4f",
                 cfg.evaluation.split, eval_acc)
        if cfg.dataset.name == 'imagenet':
            log.info("Final evaluation top-5 accuracy on split %s: %.4f",
                     cfg.evaluation.split, eval_acc_top5)
        return eval_acc
def train(train_dir):
    """Train a (conditional) GAN with TFGAN, with periodic IS/FID evaluation.

    Builds one of three model variants depending on flags — unconditional,
    projection-conditioned, or AC-GAN — wires up the selected GAN loss and
    per-network optimizers, then runs the alternating generator/discriminator
    training loop with logging, summaries, checkpoints and (optionally)
    Inception-score / FID evaluation.

    Args:
        train_dir: directory for checkpoints and TensorBoard summaries.

    Side effects: writes checkpoints and summaries; reads the global ``args``.
    """
    # XXX subsampling support is dropped silently
    assert abs(args.subsampling - 1) < 0.01
    # NOTE(review): target_classes is never used below — candidate for removal.
    target_classes = list(range(args.num_classes))
    # XXX classes support is dropped, UPDATE: retrofitted for class splits
    classes = None
    split_training_mode = args.total_class_splits > 0
    num_classes = args.num_classes
    if split_training_mode:
        # Partition the label space into equal contiguous splits and train
        # only on the split selected by --active_split_num.
        assert args.num_classes % args.total_class_splits == 0
        split_sz = args.num_classes // args.total_class_splits
        classes_a = args.active_split_num * split_sz
        classes_b = classes_a + split_sz
        classes = list(range(classes_a, classes_b))
        num_classes = split_sz
        log.info("Class split training mode is activated: "
                 "this run chooses %i split out of %i in total, "
                 "thus classes=%s",
                 args.active_split_num, args.total_class_splits, str(classes))
    images, labels, iter_fn = matcher.load_dataset(
        args.train_split, args.batch_size, args.dataset, args.image_size,
        augmentation=False, shuffle=True, classes=classes, normalize=True)
    if split_training_mode:
        # Remap the selected class range onto [0, split_sz).
        labels -= classes_a
    noise = tf.random_normal([args.batch_size, args.noise_dims])
    # Discriminator may take several steps per generator step.
    discriminator_train_steps = args.num_discriminator_steps
    generator_train_steps = 1

    # Adapters matching the signatures TFGAN expects; the architecture's
    # discriminator returns (gan_logits, class_logits, extra) and we pick what
    # each variant needs.
    def conditional_generator(x):
        return archs[args.arch].generator(x, True, num_classes=num_classes)[0]

    def conditional_discriminator(x, conditioning):
        gan_logits, class_logits, _ = archs[args.arch].discriminator(
            x, True, gen_input=conditioning, num_classes=num_classes)
        return gan_logits, class_logits

    def unconditional_discriminator(x, conditioning):
        gan_logits, _, _ = archs[args.arch].discriminator(
            x, True, gen_input=conditioning, num_classes=num_classes)
        return gan_logits

    one_hot_labels = tf.one_hot(labels, num_classes)
    if args.unconditional:
        # Plain GAN: labels are not fed to the generator at all.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,
            discriminator_fn=unconditional_discriminator,
            real_data=images,
            generator_inputs=noise)
    elif args.projection:
        # Projection conditioning: labels go into the generator input tuple,
        # discriminator exposes only the GAN logits.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,
            discriminator_fn=unconditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    else:
        # AC-GAN: discriminator additionally classifies, so it returns class
        # logits and the model takes the one-hot labels.
        gan_model = tfgan.acgan_model(
            generator_fn=conditional_generator,
            discriminator_fn=conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels),
            one_hot_labels=one_hot_labels)
    # Near-zero flag values mean "disabled"; TFGAN interprets None as off.
    gp = None if abs(args.gradient_penalty) < 0.01 else args.gradient_penalty
    acgan_gw = None if (args.unconditional or abs(args.acgan_gw) < 0.001) else args.acgan_gw
    acgan_dw = None if (args.unconditional or abs(args.acgan_dw) < 0.001) else args.acgan_dw
    if args.gan_loss == 'hinge':
        model_gen_loss = gan_losses.hinge_generator_loss
        model_dis_loss = gan_losses.hinge_discriminator_loss
    elif args.gan_loss == 'wasserstein':
        model_gen_loss = tfgan.losses.wasserstein_generator_loss
        model_dis_loss = tfgan.losses.wasserstein_discriminator_loss
    elif args.gan_loss == 'classical':
        model_gen_loss = tfgan.losses.modified_generator_loss
        model_dis_loss = tfgan.losses.modified_discriminator_loss
    else:
        raise ValueError("Unsupported GAN loss")
    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=model_gen_loss,
        discriminator_loss_fn=model_dis_loss,
        aux_cond_generator_weight=acgan_gw,
        aux_cond_discriminator_weight=acgan_dw,
        gradient_penalty_weight=gp,
    )
    global_step = tf.train.get_or_create_global_step()
    train_ops = tfgan.gan_train_ops(
        gan_model, gan_loss,
        generator_optimizer=get_optimizer("generator"),
        discriminator_optimizer=get_optimizer("discriminator"))
    if args.inception_step > 0:
        # Build the IS/FID evaluation graph only when it will be used.
        # NOTE(review): the loop guard below is `args.inception_step != 0`,
        # which would hit an undefined `inception` for negative values —
        # presumably inception_step is never negative; confirm flag parsing.
        real_activations = matcher.load_inception_activations(args.dataset,
                                                              args.image_size)
        inception = Inception(init_generator(args.eval_batch_size, reuse=True,
                                             denormalize=True)[0])
    tfgan.eval.add_gan_model_image_summaries(
        gan_model, grid_size=int(np.sqrt(args.batch_size)))
    # No-op if train_dir has no checkpoint; otherwise resumes from it.
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    summary_op = tf.summary.merge_all()
    clean_init_op = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=1)
    # Freeze the graph: any later attempt to add ops raises immediately.
    tf.get_default_graph().finalize()
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        summary_writer = tf.summary.FileWriter(train_dir)
        # Fresh init first, then (possibly) overwrite from checkpoint.
        sess.run(clean_init_op)
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)
        starting_step = sess.run(global_step)
        starting_time = time.time()
        log.info("Starting training from step %i..." % starting_step)
        for step in range(starting_step, args.max_iterations+1):
            start_time = time.time()
            try:
                # Alternate: generator step(s), then discriminator step(s),
                # then bump the global step once per outer iteration.
                gen_loss = 0
                for _ in range(generator_train_steps):
                    cur_gen_loss = sess.run(train_ops.generator_train_op)
                    gen_loss += cur_gen_loss
                dis_loss = 0
                for _ in range(discriminator_train_steps):
                    cur_dis_loss = sess.run(train_ops.discriminator_train_op)
                    dis_loss += cur_dis_loss
                sess.run(train_ops.global_step_inc_op)
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # Input pipeline exhausted or session shut down: stop cleanly.
                break
            except KeyboardInterrupt:
                log.info("Killed by ^C")
                break
            if step % args.print_step == 0:
                duration = float(time.time() - start_time)
                examples_per_sec = args.batch_size / duration
                log.info("step %i: gen loss = %f, dis loss = %f (%.1f examples/sec; %.3f sec/batch)"
                         % (step, gen_loss, dis_loss, examples_per_sec,
                            duration))
                # Average over the whole run so the ETA smooths out jitter.
                avg_speed = (time.time() - starting_time)/(step - starting_step + 1)
                time_to_finish = avg_speed * (args.max_iterations - step)
                end_date = datetime.datetime.now() + datetime.timedelta(
                    seconds=time_to_finish)
                log.info("%i iterations left expected to finish at %s (avg speed: %.3f sec/batch)"
                         % (args.max_iterations - step,
                            end_date.strftime("%Y-%m-%d %H:%M:%S"), avg_speed))
            if step % args.summary_step == 0:
                # Parse the merged summaries so extra scalars can be appended
                # to the same proto before writing.
                summary = tf.Summary()
                summary.ParseFromString(sess.run(summary_op))
                if args.inception_step != 0 and step % args.inception_step == 0 and step > 0:
                    scores, fids = inception.compute_inception_score_and_fid(
                        real_activations, sess)
                    is_mean, is_std = scores
                    fid_mean, fid_std = fids
                    summary.value.add(tag='Inception_50K_mean',
                                      simple_value=is_mean)
                    summary.value.add(tag='Inception_50K_std',
                                      simple_value=is_std)
                    summary.value.add(tag='FID_50K_mean',
                                      simple_value=fid_mean)
                    summary.value.add(tag='FID_50K_std',
                                      simple_value=fid_std)
                    log.info("Inception score over 50K images: %f +- %f"
                             % (is_mean, is_std))
                    log.info("Frechet Inception distance over 50K images: %f +- %f"
                             % (fid_mean, fid_std))
                summary_writer.add_summary(summary, step)
            if step % args.ckpt_step == 0 and step >= 0:
                summary_writer.flush()
                log.debug("Saving checkpoint...")
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step,
                           write_meta_graph=False)
        summary_writer.close()
from paths import DATASETS import argparse from data import matcher parser = argparse.ArgumentParser( description='Generate and cache Inception activations') parser.add_argument("--dataset", default='cifar10', choices=['imagenet', 'cifar10', 'cifar100']) parser.add_argument("--image_size", default=32, type=int) args = parser.parse_args() inception = Inception(None) images, _, iter_fn = matcher.load_dataset('train', 500, normalize=False, dataset=args.dataset, size=args.image_size) with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: iter_fn(sess) def _sample_images(x): return sess.run(images) print("Loading images...") dset = split_apply_concat(np.zeros(50000), _sample_images, 100) print("Images are loaded: ", dset.shape) print("Computing inception activations...") activations = inception._compute_inception_activations(dset, sess,