def testBatchSize(self, disc_iters, use_tpu=True):
  parameters = {
      "architecture": c.RESNET5_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "disc_iters": disc_iters,
  }
  batch_size = 16
  dataset = datasets.get_dataset("cifar10")
  gan = MockModularGAN(dataset=dataset,
                       parameters=parameters,
                       model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
                               use_tpu=use_tpu)
  estimator.train(gan.input_fn, steps=1)
  logging.info("gen_args: %s", "\n".join(str(a) for a in gan.gen_args))
  logging.info("disc_args: %s", "\n".join(str(a) for a in gan.disc_args))
  num_shards = 2 if use_tpu else 1
  assert batch_size % num_shards == 0
  gen_bs = batch_size // num_shards
  disc_bs = gen_bs * 2  # Merged discriminator calls.
  self.assertLen(gan.gen_args, disc_iters + 2)
  for args in gan.gen_args:
    self.assertAllEqual(args["z"].shape.as_list(), [gen_bs, 128])
  self.assertLen(gan.disc_args, disc_iters + 1)
  for args in gan.disc_args:
    self.assertAllEqual(args["x"].shape.as_list(), [disc_bs, 32, 32, 3])
def testSingleTrainingStepArchitectures(self, use_predictor, project_y=True,
                                        self_supervision="none"):
  parameters = {
      "architecture": c.RESNET_BIGGAN_ARCH,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("ModularGAN.conditional", True)
    gin.bind_parameter("loss.fn", loss_lib.hinge)
    gin.bind_parameter("S3GAN.use_predictor", use_predictor)
    gin.bind_parameter("S3GAN.project_y", project_y)
    gin.bind_parameter("S3GAN.self_supervision", self_supervision)
  # Fake ImageNet dataset by overriding the properties.
  dataset = datasets.get_dataset("imagenet_128")
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  gan = S3GAN(dataset=dataset,
              parameters=parameters,
              model_dir=model_dir,
              g_optimizer_fn=tf.train.AdamOptimizer,
              g_lr=0.0002,
              rotated_batch_fraction=2)
  estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
def testBatchSizeSplitDiscCalls(self, disc_iters):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "disc_iters": disc_iters,
  }
  batch_size = 16
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   deprecated_split_disc_calls=True,
                   model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
                               use_tpu=True)
  estimator.train(gan.input_fn, steps=1)
  gen_args = gan.generator.call_arg_list
  disc_args = gan.discriminator.call_arg_list
  self.assertLen(gen_args, disc_iters + 1)  # D steps, G step.
  # Each D and G step calls discriminator twice: for real and fake images.
  self.assertLen(disc_args, 2 * (disc_iters + 1))
  for args in gen_args:
    self.assertAllEqual(args["z"].shape.as_list(), [8, 128])
  for args in disc_args:
    self.assertAllEqual(args["x"].shape.as_list(), [8, 32, 32, 3])
def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "disc_iters": disc_iters,
      "lambda": 1,
      "z_dim": 128,
  }
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=self.model_dir,
      save_checkpoints_steps=1,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   model_dir=self.model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=use_tpu)
  estimator.train(gan.input_fn, steps=3)

  disc_step_values = []
  gen_step_values = []
  for step in range(4):
    basename = os.path.join(self.model_dir, "model.ckpt-{}".format(step))
    self.assertTrue(tf.gfile.Exists(basename + ".index"))
    ckpt = tf.train.load_checkpoint(basename)
    disc_step_values.append(ckpt.get_tensor("global_step_disc"))
    gen_step_values.append(ckpt.get_tensor("global_step"))

  expected_disc_steps = np.arange(4) * disc_iters
  self.assertAllEqual(disc_step_values, expected_disc_steps)
  self.assertAllEqual(gen_step_values, [0, 1, 2, 3])
def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "disc_iters": disc_iters,
  }
  batch_size = 16
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   experimental_joint_gen_for_disc=True,
                   model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
                               use_tpu=True)
  estimator.train(gan.input_fn, steps=1)
  gen_args = gan.generator.call_arg_list
  disc_args = gan.discriminator.call_arg_list
  self.assertLen(gen_args, 2)
  self.assertLen(disc_args, disc_iters + 1)
  self.assertAllEqual(gen_args[0]["z"].shape.as_list(), [8 * disc_iters, 128])
  self.assertAllEqual(gen_args[1]["z"].shape.as_list(), [8, 128])
  for args in disc_args:
    self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
  parameters = {
      "architecture": c.DUMMY_ARCH,
      "lambda": 1,
      "z_dim": 128,
      "disc_iters": disc_iters,  # Fixed typo: was "dics_iters".
  }
  gin.bind_parameter("ModularGAN.g_use_ema", True)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   model_dir=self.model_dir)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
  # Check for moving average variables in the checkpoint.
  checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
  ema_vars = sorted([
      v[0] for v in tf.train.list_variables(checkpoint_path)
      if v[0].endswith("ExponentialMovingAverage")
  ])
  tf.logging.info("ema_vars=%s", ema_vars)
  expected_ema_vars = sorted([
      "generator/fc_noise/kernel/ExponentialMovingAverage",
      "generator/fc_noise/bias/ExponentialMovingAverage",
  ])
  self.assertAllEqual(ema_vars, expected_ema_vars)
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
    gin.bind_parameter("G.batch_norm_fn", evonorm_s0)
  model_dir = self._get_empty_model_dir()
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = CLGAN(
      dataset=dataset,
      parameters=parameters,
      model_dir=model_dir,
      g_optimizer_fn=tf.train.AdamOptimizer,
      g_lr=0.0002,
  )
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
def _run_train_input_fn(self, dataset_name, preprocess_fn):
  dataset = datasets.get_dataset(dataset_name)
  with tf.Graph().as_default():
    dataset = dataset.input_fn(params={"batch_size": 1},
                               preprocess_fn=preprocess_fn)
    iterator = dataset.make_initializable_iterator()
    with self.session() as sess:
      sess.run(iterator.initializer)
      next_batch = iterator.get_next()
      return [sess.run(next_batch) for _ in range(5)]
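# Example use of the helper above (a sketch, not part of the original tests):
# fetch a few CIFAR-10 training batches without extra preprocessing. Passing
# preprocess_fn=None assumes the dataset's input_fn treats None as "no
# preprocessing".
def test_train_input_fn_cifar10(self):
  batches = self._run_train_input_fn("cifar10", preprocess_fn=None)
  self.assertLen(batches, 5)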
def get_element_and_verify_shape(self, dataset_name, expected_shape):
  dataset = datasets.get_dataset(dataset_name)
  dataset = dataset.eval_input_fn()
  image, label = dataset.make_one_shot_iterator().get_next()
  # Check that the shape is known at compile time, required for TPUs.
  self.assertAllEqual(image.shape.as_list(), expected_shape)
  self.assertEqual(image.dtype, tf.float32)
  self.assertIn(label.dtype, _TPU_SUPPORTED_TYPES)
  with self.cached_session() as session:
    image = session.run(image)
    self.assertEqual(image.shape, expected_shape)
    self.assertGreaterEqual(image.min(), 0.0)
    self.assertLessEqual(image.max(), 1.0)
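# Example invocation of the helper above (illustrative sketch): CIFAR-10 eval
# images are 32x32x3, matching the image shapes asserted elsewhere in these
# tests.
def test_cifar10_eval_shape(self):
  self.get_element_and_verify_shape("cifar10", [32, 32, 3])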
def testUnlabeledDatasetRaisesError(self):
  parameters = {
      "architecture": c.RESNET_CIFAR,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("loss.fn", loss_lib.hinge)
  # Use a dataset without labels.
  dataset = datasets.get_dataset("celeb_a")
  with self.assertRaises(ValueError):
    gan = ModularGAN(dataset=dataset,
                     parameters=parameters,
                     conditional=True,
                     model_dir=self.model_dir)
    del gan
def testSingleTrainingStepWithJointGenForDisc(self):
  parameters = {
      "architecture": c.RESNET5_BIGGAN_ARCH,
      "lambda": 1,
      "z_dim": 120,
      "disc_iters": 2,
  }
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   model_dir=self.model_dir,
                   experimental_joint_gen_for_disc=True,
                   conditional=True)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
def test_end2end_checkpoint(self, architecture):
  """Takes a real GAN (trained for 1 step) and evaluates it."""
  if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
    # RESNET_STL_ARCH and RESNET30_ARCH do not support the CIFAR image shape.
    return
  gin.bind_parameter("dataset.name", "cifar10")
  dataset = datasets.get_dataset("cifar10")
  options = {
      "architecture": architecture,
      "z_dim": 120,
      "disc_iters": 1,
      "lambda": 1,
  }
  model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
  tf.logging.info("model_dir: %s" % model_dir)
  run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
  gan = ModularGAN(dataset=dataset,
                   parameters=options,
                   conditional="biggan" in architecture,
                   model_dir=model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(input_fn=gan.input_fn, steps=1)

  export_path = os.path.join(model_dir, "tfhub")
  checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
  module_spec = gan.as_module_spec()
  module_spec.export(export_path, checkpoint_path=checkpoint_path)

  eval_tasks = [
      fid_score.FIDScoreTask(),
      fractal_dimension.FractalDimensionTask(),
      inception_score.InceptionScoreTask(),
      ms_ssim_score.MultiscaleSSIMTask()
  ]
  result_dict = eval_gan_lib.evaluate_tfhub_module(
      export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
  tf.logging.info("result_dict: %s", result_dict)
  for score in ["fid_score", "fractal_dimension", "inception_score",
                "ms_ssim"]:
    for stats in ["mean", "std", "list"]:
      required_key = "%s_%s" % (score, stats)
      self.assertIn(required_key, result_dict, "Missing: %s." % required_key)
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   model_dir=self.model_dir,
                   conditional="biggan" in architecture)
  estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                           labeled_dataset):
  parameters = {
      "architecture": architecture,
      "lambda": 1,
      "z_dim": 120,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=self.model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   conditional=True,
                   model_dir=self.model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
  parameters = {
      "architecture": architecture,
      "discriminator_normalization": c.SPECTRAL_NORM,
      "lambda": 1,
      "z_dim": 128,
  }
  with gin.unlock_config():
    gin.bind_parameter("penalty.fn", penalty_fn)
    gin.bind_parameter("loss.fn", loss_fn)
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=self.model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = SSGAN(dataset=dataset,
              parameters=parameters,
              model_dir=self.model_dir,
              g_optimizer_fn=tf.train.AdamOptimizer,
              g_lr=0.0002,
              rotated_batch_size=4)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
  estimator.train(gan.input_fn, steps=1)
def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
                          num_averaging_runs, step):
  """Evaluates the model exported as a TF Hub module.

  Args:
    module_spec: string, path to a TF hub module.
    eval_tasks: List of objects that inherit from EvalTask.
    use_tpu: Whether to use TPUs.
    num_averaging_runs: Determines how many times each metric is computed.
    step: Name of the step being evaluated.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = dataset.eval_test_samples

  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  result_dict = {}
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def sample_from_generator():
        """Create graph for sampling images."""
        generator = hub.Module(
            module_spec,
            name="gen_module",
            tags={"gen", "bs{}".format(batch_size)})
        logging.info("Generator inputs: %s", generator.get_input_info_dict())
        z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
        z = z_generator(shape=[batch_size, z_dim])
        if "labels" in generator.get_input_info_dict():
          # Conditional GAN.
          assert dataset.num_classes
          if FLAGS.force_label is None:
            labels = tf.random.uniform(
                [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
          else:
            labels = tf.constant(FLAGS.force_label, shape=[batch_size],
                                 dtype=tf.int32)
          inputs = dict(z=z, labels=labels)
        else:
          # Unconditional GAN.
          assert "labels" not in generator.get_input_info_dict()
          inputs = dict(z=z)
        return generator(inputs=inputs, as_dict=True)["generated"]

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(sample_from_generator)
      else:
        generated = sample_from_generator()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      if not tf.io.gfile.exists(save_model_accu_path):
        if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
          saver = tf.train.Saver()
          checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
          logging.info("Exported generator with accumulated batch stats to "
                       "%s.", checkpoint_path)
      if not eval_tasks:
        logging.error("Task list is empty, returning.")
        return

      for i in range(num_averaging_runs):
        logging.info("Generating fake data set %d/%d.", i + 1,
                     num_averaging_runs)
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        fake_dsets.append(fake_dset)

        # Hacking this in here for speed for now.
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step)

        logging.info("Computing inception features for generated data %d/%d.",
                     i + 1, num_averaging_runs)
        activations, logits = eval_utils.inception_transform_np(
            fake_dset.images, batch_size)
        fake_dset.set_inception_features(
            activations=activations, logits=logits)
        fake_dset.set_num_examples(num_test_examples)
        if i != 0:
          # Free up some memory by releasing additional fake data samples.
          # For ImageNet128 50k images are ~9 GiB. This will blow up metrics
          # (such as fractal dimension) if num_averaging_runs > 1.
          fake_dset.discard_images()

  real_dset = eval_utils.EvalDataSample(
      eval_utils.get_real_images(
          dataset=dataset, num_examples=num_test_examples))
  logging.info("Getting Inception features for real images.")
  real_dset.activations, _ = eval_utils.inception_transform_np(
      real_dset.images, batch_size)
  real_dset.set_num_examples(num_test_examples)

  # Run all the tasks and update the result dictionary with the task
  # statistics.
  result_dict = {}
  for task in eval_tasks:
    task_results_dicts = [
        task.run_after_session(fake_dset, real_dset)
        for fake_dset in fake_dsets
    ]
    # Average the score for each key.
    result_statistics = {}
    for key in task_results_dicts[0].keys():
      scores_for_key = np.array([d[key] for d in task_results_dicts])
      mean, std = np.mean(scores_for_key), np.std(scores_for_key)
      scores_as_string = "_".join([str(x) for x in scores_for_key])
      result_statistics[key + "_mean"] = mean
      result_statistics[key + "_std"] = std
      result_statistics[key + "_list"] = scores_as_string
    logging.info("Computed results for task %s: %s", task, result_statistics)

    result_dict.update(result_statistics)
  return result_dict
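# Usage sketch for evaluate_tfhub_module as defined above (this variant takes
# an extra `step` argument). The export path and step value are placeholders;
# the task list mirrors the end-to-end checkpoint test earlier in this code,
# and "dataset.name" must be gin-bound because the function calls
# datasets.get_dataset() without arguments.
def _example_evaluate_module(export_path, step):
  eval_tasks = [
      fid_score.FIDScoreTask(),
      inception_score.InceptionScoreTask(),
  ]
  return evaluate_tfhub_module(export_path, eval_tasks, use_tpu=False,
                               num_averaging_runs=1, step=step)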
def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
  # The early return below currently disables this test.
  return
  if disc_iters > 1 and use_tpu:
    return
  parameters = {
      "architecture": c.RESNET_CIFAR,
      "disc_iters": disc_iters,
      "lambda": 1,
      "z_dim": 128,
  }
  if not use_tpu:
    parameters["unroll_disc_iters"] = False
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=self.model_dir,
      save_checkpoints_steps=1,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
  dataset = datasets.get_dataset("cifar10")
  gan = ModularGAN(dataset=dataset,
                   parameters=parameters,
                   model_dir=self.model_dir)
  estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=use_tpu)
  estimator.train(gan.input_fn, steps=3)

  # Read checkpoints for each training step. If a weight in the generator
  # changed, we trained the generator during that step.
  previous_values = {}
  generator_trained = []
  for step in range(0, 4):
    basename = os.path.join(self.model_dir, "model.ckpt-{}".format(step))
    self.assertTrue(tf.gfile.Exists(basename + ".index"))
    ckpt = tf.train.load_checkpoint(basename)
    if step == 0:
      for name in ckpt.get_variable_to_shape_map():
        previous_values[name] = ckpt.get_tensor(name)
      continue
    d_trained = False
    g_trained = False
    for name in ckpt.get_variable_to_shape_map():
      t = ckpt.get_tensor(name)
      diff = np.abs(previous_values[name] - t).max()
      previous_values[name] = t
      if "discriminator" in name and diff > 1e-10:
        d_trained = True
      elif "generator" in name and diff > 1e-10:
        if name.endswith("moving_mean") or name.endswith("moving_variance"):
          # Note: Even when we don't train the generator the batch norm
          # values still get updated.
          continue
        tf.logging.info("step %d: %s changed up to %f", step, name, diff)
        g_trained = True
    self.assertTrue(d_trained)  # Discriminator is trained every step.
    generator_trained.append(g_trained)

  self.assertEqual(generator_trained,
                   GENERATOR_TRAINED_IN_STEPS[disc_iters - 1])
def generate_tfhub_module(module_spec, use_tpu, step):
  """Generates samples from the model exported as a TF Hub module.

  Args:
    module_spec: string, path to a TF hub module.
    use_tpu: Whether to use TPUs.
    step: Name of the step being evaluated.

  Returns:
    Nothing.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = FLAGS.num_samples

  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  result_dict = {}
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def create_generator(force_label=None):
        def sample_from_generator():
          """Create graph for sampling images."""
          generator = hub.Module(
              module_spec,
              name="gen_module",
              tags={"gen", "bs{}".format(batch_size)})
          logging.info("Generator inputs: %s",
                       generator.get_input_info_dict())
          z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
          z = z_generator(shape=[batch_size, z_dim])
          if "labels" in generator.get_input_info_dict():
            # Conditional GAN.
            assert dataset.num_classes
            if force_label is None:
              labels = tf.random.uniform(
                  [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
            else:
              labels = tf.constant(force_label, shape=[batch_size],
                                   dtype=tf.int32)
            inputs = dict(z=z, labels=labels)
          else:
            # Unconditional GAN.
            assert "labels" not in generator.get_input_info_dict()
            inputs = dict(z=z)
          return generator(inputs=inputs, as_dict=True)["generated"]
        return sample_from_generator

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(create_generator())
      else:
        generated = create_generator()()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      saver = tf.train.Saver()
      if _update_bn_accumulators(sess, generated,
                                 num_accu_examples=FLAGS.num_accu_examples):
        checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
        logging.info("Exported generator with accumulated batch stats to "
                     "%s.", checkpoint_path)

      logging.info("Generating fake data set.")
      fake_dset = eval_utils.EvalDataSample(
          eval_utils.sample_fake_dataset(sess, generated, num_batches))
      save_examples_lib.SaveExamplesTask().run_after_session(
          fake_dset, None, step)

      if FLAGS.force_label is not None:
        if use_tpu:
          generated = tf.contrib.tpu.rewrite(
              create_generator(FLAGS.force_label))
        else:
          generated = create_generator(FLAGS.force_label)()
        tf.global_variables_initializer().run()
        _update_bn_accumulators(sess, generated,
                                num_accu_examples=FLAGS.num_accu_examples)
        logging.info("Generating fake data set with forced label.")
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step, force_label=FLAGS.force_label)
def run_with_schedule(schedule, run_config, task_manager, options, use_tpu,
                      num_eval_averaging_runs=1, eval_every_steps=-1):
  """Run the schedule with the given options.

  Available schedules:
  - train: Train up to options["training_steps"], continuing from existing
      checkpoints if available.
  - eval_after_train: First train up to options["training_steps"], then
      evaluate all checkpoints.
  - continuous_eval: Wait for new checkpoints and evaluate them as they
      become available. This is meant to run in parallel with a job running
      the training schedule but can also run after it.

  Args:
    schedule: Schedule to run. One of: train, eval_after_train,
        continuous_eval.
    run_config: `tf.contrib.tpu.RunConfig` to use.
    task_manager: `TaskManager` for this run.
    options: Python dictionary with run parameters.
    use_tpu: Boolean whether to use TPU.
    num_eval_averaging_runs: Determines how many times each metric is
        computed.
    eval_every_steps: Integer determining which checkpoints to evaluate.
  """
  logging.info("Running schedule '%s' with options: %s", schedule, options)
  if run_config.tf_random_seed:
    logging.info("Setting NumPy random seed to %s.",
                 run_config.tf_random_seed)
    np.random.seed(run_config.tf_random_seed)

  result_dir = os.path.join(run_config.model_dir, "result")
  utils.check_folder(result_dir)

  dataset = datasets.get_dataset()
  gan = options["gan_class"](dataset=dataset,
                             parameters=options,
                             model_dir=run_config.model_dir)

  if schedule not in {"train", "eval_after_train", "continuous_eval"}:
    raise ValueError("Schedule {} not supported.".format(schedule))
  if schedule in {"train", "eval_after_train"}:
    train_hooks = [
        gin.tf.GinConfigSaverHook(run_config.model_dir),
        hooks.ReportProgressHook(task_manager,
                                 max_steps=options["training_steps"]),
    ]
    if run_config.save_checkpoints_steps:
      # This replaces the default checkpoint saver hook in the estimator.
      logging.info("Using AsyncCheckpointSaverHook.")
      train_hooks.append(
          hooks.AsyncCheckpointSaverHook(
              checkpoint_dir=run_config.model_dir,
              save_steps=run_config.save_checkpoints_steps))
      # (b/122782388): Remove hotfix.
      run_config = run_config.replace(save_checkpoints_steps=1000000)
    estimator = gan.as_estimator(
        run_config, batch_size=options["batch_size"], use_tpu=use_tpu)
    estimator.train(
        input_fn=gan.input_fn,
        max_steps=options["training_steps"],
        hooks=train_hooks)
    task_manager.mark_training_done()

  if schedule == "continuous_eval":
    # Continuous eval with up to 24 hours between checkpoints.
    checkpoints = task_manager.unevaluated_checkpoints(
        timeout=24 * 3600, eval_every_steps=eval_every_steps)
  if schedule == "eval_after_train":
    checkpoints = task_manager.unevaluated_checkpoints(
        eval_every_steps=eval_every_steps)
  if schedule in {"continuous_eval", "eval_after_train"}:
    _run_eval(
        gan.as_module_spec(),
        checkpoints=checkpoints,
        task_manager=task_manager,
        run_config=run_config,
        use_tpu=use_tpu,
        num_averaging_runs=num_eval_averaging_runs)
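# Minimal sketch of driving run_with_schedule (illustration only). The option
# keys besides "gan_class", "training_steps" and "batch_size" follow the GAN
# parameters used in the tests above and are assumptions, not a definitive
# configuration; task_manager is a TaskManager instance as described in the
# docstring. "dataset.name" must be gin-bound because run_with_schedule calls
# datasets.get_dataset() without arguments.
def _example_run_with_schedule(model_dir, task_manager):
  gin.bind_parameter("dataset.name", "cifar10")
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
  options = {
      "gan_class": ModularGAN,
      "architecture": c.DUMMY_ARCH,
      "z_dim": 128,
      "disc_iters": 1,
      "lambda": 1,
      "training_steps": 100,
      "batch_size": 16,
  }
  run_with_schedule("eval_after_train", run_config, task_manager, options,
                    use_tpu=False)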
def train_gilbo(gan, sess, outdir, checkpoint_path, dataset, options):
  """Build and train GILBO model.

  Args:
    gan: GAN object.
    sess: tf.Session.
    outdir: Output directory. A pickle file will be written there.
    checkpoint_path: Path where gan's checkpoints are written. Only used to
        ensure that GILBO files are written to a unique subdirectory of
        outdir.
    dataset: Name of dataset used to train the GAN.
    options: Options dictionary.

  Returns:
    mean_eval_info: Mean GILBO computed over a large number of images
        generated by the trained GAN.
    mean_train_consistency: Mean consistency of the trained GILBO model with
        data from the training set.
    mean_eval_consistency: Same consistency measure for the trained model
        with data from the validation set.
    mean_self_consistency: Same consistency measure for the trained model
        with data generated by the trained model itself.
    See the GILBO paper for an explanation of these metrics.

  Raises:
    ValueError: If the GAN has uninitialized variables.
  """
  uninitialized = sess.run(tf.report_uninitialized_variables())
  if uninitialized:
    raise ValueError("Model has uninitialized variables!\n%r" % uninitialized)

  outdir = os.path.join(outdir, checkpoint_path.replace("/", "_"))
  tf.gfile.MakeDirs(outdir)
  with tf.variable_scope("gilbo"):
    ones = tf.ones((gan.batch_size, gan.z_dim))
    # Get a distribution for the prior.
    z_dist = ds.Independent(ds.Uniform(-ones, ones), 1)
    z_sample = z_dist.sample()

    epsneg = np.finfo("float32").epsneg

    # Clip samples from the GAN uniform prior because the Beta distribution
    # doesn't include the top endpoint and has issues with the bottom
    # endpoint.
    ganz_clip = tf.clip_by_value(gan.z, -(1 - epsneg), 1 - epsneg)

    # Get generated images from the model.
    fake_images = gan.fake_images

    # Build the regressor distribution that encodes images back to predicted
    # samples from the prior.
    with tf.variable_scope("regressor"):
      z_pred_dist = _build_regressor(fake_images, gan.z_dim)

    # Capture the parameters of the distributions for later analysis.
    dist_p1 = z_pred_dist.distribution.distribution.concentration0
    dist_p2 = z_pred_dist.distribution.distribution.concentration1

    # info and avg_info compute the GILBO.
    info = z_pred_dist.log_prob(ganz_clip) - z_dist.log_prob(ganz_clip)
    avg_info = tf.reduce_mean(info)

    # Set up training of the GILBO model.
    lr = options.get("gilbo_learning_rate", 4e-4)
    learning_rate = tf.get_variable(
        "learning_rate", initializer=lr, trainable=False)
    gilbo_step = tf.get_variable("gilbo_step", dtype=tf.int32, initializer=0,
                                 trainable=False)

    opt = tf.train.AdamOptimizer(learning_rate)
    regressor_vars = tf.contrib.framework.get_variables("gilbo/regressor")
    train_op = opt.minimize(-info, var_list=regressor_vars)

  # Initialize the variables we just created.
  uninitialized = plist(tf.report_uninitialized_variables().eval())
  uninitialized_vars = uninitialized.apply(
      tf.contrib.framework.get_variables_by_name)._[0]
  tf.variables_initializer(uninitialized_vars).run()

  saver = tf.train.Saver(uninitialized_vars, max_to_keep=1)
  try:
    checkpoint_path = tf.train.latest_checkpoint(outdir)
    saver.restore(sess, checkpoint_path)
  except ValueError:
    # Failing to restore just indicates that we don't have a valid
    # checkpoint, so we will just start training a fresh GILBO model.
    pass

  _train_gilbo(sess, gan, saver, learning_rate, gilbo_step, z_sample,
               avg_info, z_pred_dist, train_op, outdir, options)

  mean_eval_info = _eval_gilbo(sess, gan, z_sample, avg_info,
                               dist_p1, dist_p2, fake_images, outdir, options)

  # Collect encoded distributions on the training and eval set in order to do
  # kl-nearest-neighbors on generated samples and measure consistency.
  dataset = datasets.get_dataset(dataset)
  x_train = dataset.load_dataset(split_name="train", num_threads=1)
  x_train = x_train.batch(gan.batch_size, drop_remainder=True)
  x_train = x_train.make_one_shot_iterator().get_next()[0]
  x_train = tf.reshape(x_train, fake_images.shape)

  x_eval = dataset.load_dataset(split_name="test", num_threads=1)
  x_eval = x_eval.batch(gan.batch_size, drop_remainder=True)
  x_eval = x_eval.make_one_shot_iterator().get_next()[0]
  x_eval = tf.reshape(x_eval, fake_images.shape)

  mean_train_consistency = _run_gilbo_consistency(
      x_train, "train", extract_input_images=0,
      save_consistency_images=20, num_batches=5, **locals())
  mean_eval_consistency = _run_gilbo_consistency(
      x_eval, "eval", extract_input_images=0,
      save_consistency_images=20, num_batches=5, **locals())
  mean_self_consistency = _run_gilbo_consistency(
      fake_images, "self", extract_input_images=20,
      save_consistency_images=20, num_batches=5, **locals())
  return (mean_eval_info, mean_train_consistency, mean_eval_consistency,
          mean_self_consistency)
def compute_accuracy_loss(sess, gan, test_images, max_train_examples=50000,
                          num_repeat=5):
  """Compute discriminator's accuracy and loss on a given dataset.

  Args:
    sess: tf.Session object.
    gan: Any AbstractGAN instance.
    test_images: numpy array with test images.
    max_train_examples: How many "train" examples to get from the dataset.
        In each round, some of them will be randomly selected to evaluate
        train set accuracy.
    num_repeat: How many times to repeat the computation. The mean of all
        the results is reported.

  Returns:
    Dict[Text, float] with all the computed scores.

  Raises:
    ValueError: If the number of test_images is greater than the number of
        training images returned by the dataset.
  """
  logging.info("Evaluating training and test accuracy...")
  train_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=max_train_examples,
      split="train",
      failure_on_insufficient_examples=False)
  if train_images.shape[0] < test_images.shape[0]:
    raise ValueError("num_train %d must be larger than num_test %d." %
                     (train_images.shape[0], test_images.shape[0]))

  num_batches = int(np.floor(test_images.shape[0] / gan.batch_size))
  if num_batches * gan.batch_size < test_images.shape[0]:
    logging.error("Ignoring the last batch with %d samples / %d epoch size.",
                  test_images.shape[0] - num_batches * gan.batch_size,
                  gan.batch_size)

  ret = {
      "train_accuracy": [],
      "test_accuracy": [],
      "fake_accuracy": [],
      "train_d_loss": [],
      "test_d_loss": []
  }

  for _ in range(num_repeat):
    idx = np.random.choice(train_images.shape[0], test_images.shape[0])
    bs = gan.batch_size
    train_subset = [train_images[i] for i in idx]
    train_predictions, test_predictions, fake_predictions = [], [], []
    train_d_losses, test_d_losses = [], []

    for i in range(num_batches):
      z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
      start_idx = i * bs
      end_idx = start_idx + bs
      test_batch = test_images[start_idx:end_idx]
      train_batch = train_subset[start_idx:end_idx]

      test_prediction, test_d_loss, fake_images = sess.run(
          [gan.discriminator_output, gan.d_loss, gan.fake_images],
          feed_dict={
              gan.inputs: test_batch,
              gan.z: z_sample
          })
      train_prediction, train_d_loss = sess.run(
          [gan.discriminator_output, gan.d_loss],
          feed_dict={
              gan.inputs: train_batch,
              gan.z: z_sample
          })
      fake_prediction = sess.run(
          gan.discriminator_output,
          feed_dict={gan.inputs: fake_images})[0]

      train_predictions.append(train_prediction[0])
      test_predictions.append(test_prediction[0])
      fake_predictions.append(fake_prediction)

      train_d_losses.append(train_d_loss)
      test_d_losses.append(test_d_loss)

    train_predictions = [x >= 0.5 for x in train_predictions]
    test_predictions = [x >= 0.5 for x in test_predictions]
    fake_predictions = [x < 0.5 for x in fake_predictions]

    ret["train_accuracy"].append(np.array(train_predictions).mean())
    ret["test_accuracy"].append(np.array(test_predictions).mean())
    ret["fake_accuracy"].append(np.array(fake_predictions).mean())
    ret["train_d_loss"].append(np.mean(train_d_losses))
    ret["test_d_loss"].append(np.mean(test_d_losses))

  for key in ret:
    ret[key] = np.mean(ret[key])

  return ret
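# Usage sketch for compute_accuracy_loss, assuming an already-built GAN graph
# and a live session with restored variables. Fetching test images via
# eval_utils.get_real_images with split="test" mirrors the split="train" call
# inside the function above and is an assumption, not the only way to do it.
def _example_compute_accuracy_loss(sess, gan):
  test_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=10000,
      split="test",
      failure_on_insufficient_examples=False)
  scores = compute_accuracy_loss(sess, gan, test_images, num_repeat=1)
  # scores is a Dict[Text, float] with keys such as "train_accuracy",
  # "test_accuracy", "fake_accuracy", "train_d_loss" and "test_d_loss".
  return scores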