def testBatchSize(self, disc_iters, use_tpu=True):
     parameters = {
         "architecture": c.RESNET5_ARCH,
         "lambda": 1,
         "z_dim": 128,
         "disc_iters": disc_iters,
     }
     batch_size = 16
     dataset = datasets.get_dataset("cifar10")
     gan = MockModularGAN(dataset=dataset,
                          parameters=parameters,
                          model_dir=self.model_dir)
     estimator = gan.as_estimator(self.run_config,
                                  batch_size=batch_size,
                                  use_tpu=use_tpu)
     estimator.train(gan.input_fn, steps=1)
     logging.info("gen_args: %s", "\n".join(str(a) for a in gan.gen_args))
     logging.info("disc_args: %s", "\n".join(str(a) for a in gan.disc_args))
     num_shards = 2 if use_tpu else 1
     assert batch_size % num_shards == 0
     gen_bs = batch_size // num_shards
     disc_bs = gen_bs * 2  # merged discriminator calls.
     self.assertLen(gan.gen_args, disc_iters + 2)
     for args in gan.gen_args:
         self.assertAllEqual(args["z"].shape.as_list(), [gen_bs, 128])
     self.assertLen(gan.disc_args, disc_iters + 1)
     for args in gan.disc_args:
         self.assertAllEqual(args["x"].shape.as_list(),
                             [disc_bs, 32, 32, 3])
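The assertions above hinge on simple per-shard batch arithmetic: the generator batch is the global batch split across shards, and the merged discriminator call sees twice that. A minimal, self-contained sketch of the same arithmetic as an absl parameterized test; the class name and disc_iters values are illustrative, not taken from the original suite.

import tensorflow as tf
from absl.testing import parameterized


class BatchArithmeticTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(1, 2, 5)  # hypothetical disc_iters values
  def testShardedBatch(self, disc_iters):
    batch_size, num_shards = 16, 2
    self.assertEqual(batch_size % num_shards, 0)
    gen_bs = batch_size // num_shards  # per-shard generator batch
    disc_bs = gen_bs * 2               # merged real+fake discriminator batch
    self.assertEqual((gen_bs, disc_bs), (8, 16))
    self.assertGreaterEqual(disc_iters, 1)


if __name__ == "__main__":
  tf.test.main()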
Example #2
 def testSingleTrainingStepArchitectures(self,
                                         use_predictor,
                                         project_y=True,
                                         self_supervision="none"):
     parameters = {
         "architecture": c.RESNET_BIGGAN_ARCH,
         "lambda": 1,
         "z_dim": 120,
     }
     with gin.unlock_config():
         gin.bind_parameter("ModularGAN.conditional", True)
         gin.bind_parameter("loss.fn", loss_lib.hinge)
         gin.bind_parameter("S3GAN.use_predictor", use_predictor)
         gin.bind_parameter("S3GAN.project_y", project_y)
         gin.bind_parameter("S3GAN.self_supervision", self_supervision)
     # Fake ImageNet dataset by overriding the properties.
     dataset = datasets.get_dataset("imagenet_128")
     model_dir = self._get_empty_model_dir()
     run_config = tf.contrib.tpu.RunConfig(
         model_dir=model_dir,
         tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
     gan = S3GAN(dataset=dataset,
                 parameters=parameters,
                 model_dir=model_dir,
                 g_optimizer_fn=tf.train.AdamOptimizer,
                 g_lr=0.0002,
                 rotated_batch_fraction=2)
     estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
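The gin.unlock_config / gin.bind_parameter pattern used above can be exercised on its own. A self-contained sketch with an invented configurable (make_loss and its fn parameter are illustrative, not compare_gan symbols).

import gin


@gin.configurable
def make_loss(fn="hinge"):
  return fn


with gin.unlock_config():
  gin.bind_parameter("make_loss.fn", "wasserstein")

print(make_loss())  # -> "wasserstein"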
Example #3
    def testBatchSizeSplitDiscCalls(self, disc_iters):
        parameters = {
            "architecture": c.DUMMY_ARCH,
            "lambda": 1,
            "z_dim": 128,
            "disc_iters": disc_iters,
        }
        batch_size = 16
        dataset = datasets.get_dataset("cifar10")
        gan = ModularGAN(dataset=dataset,
                         parameters=parameters,
                         deprecated_split_disc_calls=True,
                         model_dir=self.model_dir)
        estimator = gan.as_estimator(self.run_config,
                                     batch_size=batch_size,
                                     use_tpu=True)
        estimator.train(gan.input_fn, steps=1)

        gen_args = gan.generator.call_arg_list
        disc_args = gan.discriminator.call_arg_list
        self.assertLen(gen_args, disc_iters + 1)  # D steps, G step.
        # Each D and G step calls discriminator twice: for real and fake images.
        self.assertLen(disc_args, 2 * (disc_iters + 1))

        for args in gen_args:
            self.assertAllEqual(args["z"].shape.as_list(), [8, 128])
        for args in disc_args:
            self.assertAllEqual(args["x"].shape.as_list(), [8, 32, 32, 3])
Example #4
    def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
        parameters = {
            "architecture": c.DUMMY_ARCH,
            "disc_iters": disc_iters,
            "lambda": 1,
            "z_dim": 128,
        }
        run_config = tf.contrib.tpu.RunConfig(
            model_dir=self.model_dir,
            save_checkpoints_steps=1,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
        dataset = datasets.get_dataset("cifar10")
        gan = ModularGAN(dataset=dataset,
                         parameters=parameters,
                         model_dir=self.model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=use_tpu)
        estimator.train(gan.input_fn, steps=3)

        disc_step_values = []
        gen_step_values = []

        for step in range(4):
            basename = os.path.join(self.model_dir,
                                    "model.ckpt-{}".format(step))
            self.assertTrue(tf.gfile.Exists(basename + ".index"))
            ckpt = tf.train.load_checkpoint(basename)

            disc_step_values.append(ckpt.get_tensor("global_step_disc"))
            gen_step_values.append(ckpt.get_tensor("global_step"))

        expected_disc_steps = np.arange(4) * disc_iters
        self.assertAllEqual(disc_step_values, expected_disc_steps)
        self.assertAllEqual(gen_step_values, [0, 1, 2, 3])
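As a concrete reading of the assertions above: with disc_iters=2, the discriminator step counter stored in the four checkpoints advances by two per generator step.

import numpy as np

disc_iters = 2
print(np.arange(4) * disc_iters)  # expected global_step_disc: [0 2 4 6]
print(list(range(4)))             # expected global_step:      [0, 1, 2, 3]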
Example #5
    def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
        parameters = {
            "architecture": c.DUMMY_ARCH,
            "lambda": 1,
            "z_dim": 128,
            "disc_iters": disc_iters,
        }
        batch_size = 16
        dataset = datasets.get_dataset("cifar10")
        gan = ModularGAN(dataset=dataset,
                         parameters=parameters,
                         experimental_joint_gen_for_disc=True,
                         model_dir=self.model_dir)
        estimator = gan.as_estimator(self.run_config,
                                     batch_size=batch_size,
                                     use_tpu=True)
        estimator.train(gan.input_fn, steps=1)

        gen_args = gan.generator.call_arg_list
        disc_args = gan.discriminator.call_arg_list
        self.assertLen(gen_args, 2)
        self.assertLen(disc_args, disc_iters + 1)

        self.assertAllEqual(gen_args[0]["z"].shape.as_list(),
                            [8 * disc_iters, 128])
        self.assertAllEqual(gen_args[1]["z"].shape.as_list(), [8, 128])
        for args in disc_args:
            self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])
Example #6
 def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
     parameters = {
         "architecture": c.DUMMY_ARCH,
         "lambda": 1,
         "z_dim": 128,
         "dics_iters": disc_iters,
     }
     gin.bind_parameter("ModularGAN.g_use_ema", True)
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(dataset=dataset,
                      parameters=parameters,
                      model_dir=self.model_dir)
     estimator = gan.as_estimator(self.run_config,
                                  batch_size=2,
                                  use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
     # Check for moving average variables in checkpoint.
     checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
     ema_vars = sorted([
         v[0] for v in tf.train.list_variables(checkpoint_path)
         if v[0].endswith("ExponentialMovingAverage")
     ])
     tf.logging.info("ema_vars=%s", ema_vars)
     expected_ema_vars = sorted([
         "generator/fc_noise/kernel/ExponentialMovingAverage",
         "generator/fc_noise/bias/ExponentialMovingAverage",
     ])
     self.assertAllEqual(ema_vars, expected_ema_vars)
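The expected variable names follow TF1's ExponentialMovingAverage convention: the shadow variable for a variable is stored under the variable's name plus an /ExponentialMovingAverage suffix. A small self-contained sketch, written against tf.compat.v1 so it also runs under TF2; the variable name is chosen to match the test.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

kernel = tf.get_variable("generator/fc_noise/kernel", shape=[128, 64])
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
update_op = ema.apply([kernel])  # creates the shadow (moving-average) variable
print(ema.average_name(kernel))
# -> generator/fc_noise/kernel/ExponentialMovingAverage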
Example #7
  def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    parameters = {
        "architecture": architecture,
        "lambda": 1,
        "z_dim": 128,
    }
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
      gin.bind_parameter("G.batch_norm_fn", evonorm_s0)

    model_dir = self._get_empty_model_dir()
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
    dataset = datasets.get_dataset("cifar10")

    gan = CLGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
    )
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
Example #8
 def _run_train_input_fn(self, dataset_name, preprocess_fn):
     dataset = datasets.get_dataset(dataset_name)
     with tf.Graph().as_default():
         dataset = dataset.input_fn(params={"batch_size": 1},
                                    preprocess_fn=preprocess_fn)
         iterator = dataset.make_initializable_iterator()
         with self.session() as sess:
             sess.run(iterator.initializer)
             next_batch = iterator.get_next()
             return [sess.run(next_batch) for _ in range(5)]
Example #9
 def get_element_and_verify_shape(self, dataset_name, expected_shape):
     dataset = datasets.get_dataset(dataset_name)
     dataset = dataset.eval_input_fn()
     image, label = dataset.make_one_shot_iterator().get_next()
     # Check if shape is known at compile time, required for TPUs.
     self.assertAllEqual(image.shape.as_list(), expected_shape)
     self.assertEqual(image.dtype, tf.float32)
     self.assertIn(label.dtype, _TPU_SUPPORTED_TYPES)
     with self.cached_session() as session:
         image = session.run(image)
         self.assertEqual(image.shape, expected_shape)
         self.assertGreaterEqual(image.min(), 0.0)
         self.assertLessEqual(image.max(), 1.0)
Example #10
 def testUnlabledDatasetRaisesError(self):
     parameters = {
         "architecture": c.RESNET_CIFAR,
         "lambda": 1,
         "z_dim": 120,
     }
     with gin.unlock_config():
         gin.bind_parameter("loss.fn", loss_lib.hinge)
     # Use dataset without labels.
     dataset = datasets.get_dataset("celeb_a")
     with self.assertRaises(ValueError):
         gan = ModularGAN(dataset=dataset,
                          parameters=parameters,
                          conditional=True,
                          model_dir=self.model_dir)
         del gan
Example #11
 def testSingleTrainingStepWithJointGenForDisc(self):
     parameters = {
         "architecture": c.RESNET5_BIGGAN_ARCH,
         "lambda": 1,
         "z_dim": 120,
         "disc_iters": 2,
     }
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(dataset=dataset,
                      parameters=parameters,
                      model_dir=self.model_dir,
                      experimental_joint_gen_for_disc=True,
                      conditional=True)
     estimator = gan.as_estimator(self.run_config,
                                  batch_size=2,
                                  use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
Example #12
    def test_end2end_checkpoint(self, architecture):
        """Takes real GAN (trained for 1 step) and evaluate it."""
        if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
            # RESNET_STL_ARCH and RESNET30_ARCH do not support the CIFAR image shape.
            return
        gin.bind_parameter("dataset.name", "cifar10")
        dataset = datasets.get_dataset("cifar10")
        options = {
            "architecture": architecture,
            "z_dim": 120,
            "disc_iters": 1,
            "lambda": 1,
        }
        model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
        tf.logging.info("model_dir: %s" % model_dir)
        run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
        gan = ModularGAN(dataset=dataset,
                         parameters=options,
                         conditional="biggan" in architecture,
                         model_dir=model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
        estimator.train(input_fn=gan.input_fn, steps=1)
        export_path = os.path.join(model_dir, "tfhub")
        checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
        module_spec = gan.as_module_spec()
        module_spec.export(export_path, checkpoint_path=checkpoint_path)

        eval_tasks = [
            fid_score.FIDScoreTask(),
            fractal_dimension.FractalDimensionTask(),
            inception_score.InceptionScoreTask(),
            ms_ssim_score.MultiscaleSSIMTask()
        ]
        result_dict = eval_gan_lib.evaluate_tfhub_module(export_path,
                                                         eval_tasks,
                                                         use_tpu=False,
                                                         num_averaging_runs=1)
        tf.logging.info("result_dict: %s", result_dict)
        for score in [
                "fid_score", "fractal_dimension", "inception_score", "ms_ssim"
        ]:
            for stats in ["mean", "std", "list"]:
                required_key = "%s_%s" % (score, stats)
                self.assertIn(required_key, result_dict,
                              "Missing: %s." % required_key)
Example #13
 def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
     parameters = {
         "architecture": architecture,
         "lambda": 1,
         "z_dim": 128,
     }
     with gin.unlock_config():
         gin.bind_parameter("penalty.fn", penalty_fn)
         gin.bind_parameter("loss.fn", loss_fn)
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(dataset=dataset,
                      parameters=parameters,
                      model_dir=self.model_dir,
                      conditional="biggan" in architecture)
     estimator = gan.as_estimator(self.run_config,
                                  batch_size=2,
                                  use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
Example #14
 def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                            labeled_dataset):
     parameters = {
         "architecture": architecture,
         "lambda": 1,
         "z_dim": 120,
     }
     with gin.unlock_config():
         gin.bind_parameter("penalty.fn", penalty_fn)
         gin.bind_parameter("loss.fn", loss_fn)
     run_config = tf.contrib.tpu.RunConfig(
         model_dir=self.model_dir,
         tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(dataset=dataset,
                      parameters=parameters,
                      conditional=True,
                      model_dir=self.model_dir)
     estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
Example #15
 def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
     parameters = {
         "architecture": architecture,
         "discriminator_normalization": c.SPECTRAL_NORM,
         "lambda": 1,
         "z_dim": 128,
     }
     with gin.unlock_config():
         gin.bind_parameter("penalty.fn", penalty_fn)
         gin.bind_parameter("loss.fn", loss_fn)
     run_config = tf.contrib.tpu.RunConfig(
         model_dir=self.model_dir,
         tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
     dataset = datasets.get_dataset("cifar10")
     gan = SSGAN(dataset=dataset,
                 parameters=parameters,
                 model_dir=self.model_dir,
                 g_optimizer_fn=tf.train.AdamOptimizer,
                 g_lr=0.0002,
                 rotated_batch_size=4)
     estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
Example #16
def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
                          num_averaging_runs, step):
  """Evaluate model at given checkpoint_path.

  Args:
    module_spec: string, path to a TF hub module.
    eval_tasks: List of objects that inherit from EvalTask.
    use_tpu: Whether to use TPUs.
    num_averaging_runs: Determines how many times each metric is computed.
    step: Name of the step being evaluated

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = dataset.eval_test_samples

  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  result_dict = {}
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())
      def sample_from_generator():
        """Create graph for sampling images."""
        generator = hub.Module(
            module_spec,
            name="gen_module",
            tags={"gen", "bs{}".format(batch_size)})
        logging.info("Generator inputs: %s", generator.get_input_info_dict())
        z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
        z = z_generator(shape=[batch_size, z_dim])
        if "labels" in generator.get_input_info_dict():
          # Conditional GAN.
          assert dataset.num_classes

          if FLAGS.force_label is None:
            labels = tf.random.uniform(
                [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
          else:
            labels = tf.constant(FLAGS.force_label, shape=[batch_size], dtype=tf.int32)

          inputs = dict(z=z, labels=labels)
        else:
          # Unconditional GAN.
          assert "labels" not in generator.get_input_info_dict()
          inputs = dict(z=z)
        return generator(inputs=inputs, as_dict=True)["generated"]
      
      if use_tpu:
        generated = tf.contrib.tpu.rewrite(sample_from_generator)
      else:
        generated = sample_from_generator()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")

      if not tf.io.gfile.exists(save_model_accu_path):
        if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
          saver = tf.train.Saver()
          checkpoint_path = saver.save(
              sess,
              save_path=save_model_accu_path)
          logging.info("Exported generator with accumulated batch stats to "
                       "%s.", checkpoint_path)
      if not eval_tasks:
        logging.error("Task list is empty, returning.")
        return

      for i in range(num_averaging_runs):
        logging.info("Generating fake data set %d/%d.", i+1, num_averaging_runs)
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        fake_dsets.append(fake_dset)

        # Hacking this in here for speed for now
        save_examples_lib.SaveExamplesTask().run_after_session(fake_dset, None, step)

        logging.info("Computing inception features for generated data %d/%d.",
                     i+1, num_averaging_runs)
        activations, logits = eval_utils.inception_transform_np(
            fake_dset.images, batch_size)
        fake_dset.set_inception_features(
            activations=activations, logits=logits)
        fake_dset.set_num_examples(num_test_examples)
        if i != 0:
          # Free up some memory by releasing additional fake data samples.
          # For ImageNet128 50k images are ~9 GiB. This will blow up metrics
          # (such as fractal dimension) if num_averaging_runs > 1.
          fake_dset.discard_images()

  real_dset = eval_utils.EvalDataSample(
      eval_utils.get_real_images(
          dataset=dataset, num_examples=num_test_examples))
  logging.info("Getting Inception features for real images.")
  real_dset.activations, _ = eval_utils.inception_transform_np(
      real_dset.images, batch_size)
  real_dset.set_num_examples(num_test_examples)

  # Run all the tasks and update the result dictionary with the task statistics.
  result_dict = {}
  for task in eval_tasks:
    task_results_dicts = [
        task.run_after_session(fake_dset, real_dset)
        for fake_dset in fake_dsets
    ]
    # Average the score for each key.
    result_statistics = {}
    for key in task_results_dicts[0].keys():
      scores_for_key = np.array([d[key] for d in task_results_dicts])
      mean, std = np.mean(scores_for_key), np.std(scores_for_key)
      scores_as_string = "_".join([str(x) for x in scores_for_key])
      result_statistics[key + "_mean"] = mean
      result_statistics[key + "_std"] = std
      result_statistics[key + "_list"] = scores_as_string
    logging.info("Computed results for task %s: %s", task, result_statistics)

    result_dict.update(result_statistics)
  return result_dict
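The per-key aggregation at the end of evaluate_tfhub_module is what produces the *_mean, *_std and *_list entries asserted on in Example #12. Illustrated in isolation with made-up task results.

import numpy as np

task_results_dicts = [{"fid_score": 30.1}, {"fid_score": 29.7}]
result_statistics = {}
for key in task_results_dicts[0]:
  scores_for_key = np.array([d[key] for d in task_results_dicts])
  result_statistics[key + "_mean"] = np.mean(scores_for_key)
  result_statistics[key + "_std"] = np.std(scores_for_key)
  result_statistics[key + "_list"] = "_".join(str(x) for x in scores_for_key)
print(result_statistics)
# -> {'fid_score_mean': 29.9..., 'fid_score_std': 0.2..., 'fid_score_list': '30.1_29.7'}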
Example #17
    def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
        if disc_iters > 1 and use_tpu:
            return
        parameters = {
            "architecture": c.RESNET_CIFAR,
            "disc_iters": disc_iters,
            "lambda": 1,
            "z_dim": 128,
        }
        if not use_tpu:
            parameters["unroll_disc_iters"] = False
        run_config = tf.contrib.tpu.RunConfig(
            model_dir=self.model_dir,
            save_checkpoints_steps=1,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
        dataset = datasets.get_dataset("cifar10")
        gan = ModularGAN(dataset=dataset,
                         parameters=parameters,
                         model_dir=self.model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=use_tpu)
        estimator.train(gan.input_fn, steps=3)

        # Read the checkpoint for each training step. If the generator weights
        # changed, the generator was trained during that step.
        previous_values = {}
        generator_trained = []
        for step in range(0, 4):
            basename = os.path.join(self.model_dir,
                                    "model.ckpt-{}".format(step))
            self.assertTrue(tf.gfile.Exists(basename + ".index"))
            ckpt = tf.train.load_checkpoint(basename)

            if step == 0:
                for name in ckpt.get_variable_to_shape_map():
                    previous_values[name] = ckpt.get_tensor(name)
                continue

            d_trained = False
            g_trained = False
            for name in ckpt.get_variable_to_shape_map():
                t = ckpt.get_tensor(name)
                diff = np.abs(previous_values[name] - t).max()
                previous_values[name] = t
                if "discriminator" in name and diff > 1e-10:
                    d_trained = True
                elif "generator" in name and diff > 1e-10:
                    if name.endswith("moving_mean") or name.endswith(
                            "moving_variance"):
                        # Note: Even when we don't train the generator the batch norm
                        # values still get updated.
                        continue
                    tf.logging.info("step %d: %s changed up to %f", step, name,
                                    diff)
                    g_trained = True
            self.assertTrue(d_trained)  # Discriminator is trained every step.
            generator_trained.append(g_trained)

        self.assertEqual(generator_trained,
                         GENERATOR_TRAINED_IN_STEPS[disc_iters - 1])
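The checkpoint-diff technique above decides that a sub-network was trained in a step when any of its variables moved by more than a small tolerance between consecutive checkpoints. The idea in miniature, with made-up tensors.

import numpy as np

previous = {"generator/w": np.zeros(3), "discriminator/w": np.zeros(3)}
current = {"generator/w": np.zeros(3),
           "discriminator/w": np.array([0.1, 0.0, 0.0])}
trained = {name: bool(np.abs(previous[name] - current[name]).max() > 1e-10)
           for name in previous}
print(trained)  # -> {'generator/w': False, 'discriminator/w': True}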
Example #18
def generate_tfhub_module(module_spec, use_tpu, step):
  """Generate from model at given checkpoint_path.

  Args:
    module_spec: string, path to a TF hub module.
    use_tpu: Whether to use TPUs.
    step: Name of the step being evaluated

  Returns:
    Nothing

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = FLAGS.num_samples

  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  result_dict = {}
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def create_generator(force_label=None):
        def sample_from_generator():
          """Create graph for sampling images."""
          generator = hub.Module(
              module_spec,
              name="gen_module",
              tags={"gen", "bs{}".format(batch_size)})
          logging.info("Generator inputs: %s", generator.get_input_info_dict())
          z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
          z = z_generator(shape=[batch_size, z_dim])
          if "labels" in generator.get_input_info_dict():
            # Conditional GAN.
            assert dataset.num_classes

            if force_label is None:
              labels = tf.random.uniform(
                  [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
            else:
              labels = tf.constant(force_label, shape=[batch_size], dtype=tf.int32)

            inputs = dict(z=z, labels=labels)
          else:
            # Unconditional GAN.
            assert "labels" not in generator.get_input_info_dict()
            inputs = dict(z=z)
          return generator(inputs=inputs, as_dict=True)["generated"]
        return sample_from_generator

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(create_generator())
      else:
        generated = create_generator()()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      saver = tf.train.Saver()
    
      if _update_bn_accumulators(sess, generated, num_accu_examples=FLAGS.num_accu_examples):
        checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
        logging.info("Exported generator with accumulated batch stats to "
                     "%s.", checkpoint_path)

      logging.info("Generating fake data set")
      fake_dset = eval_utils.EvalDataSample(
          eval_utils.sample_fake_dataset(sess, generated, num_batches))

      save_examples_lib.SaveExamplesTask().run_after_session(fake_dset, None, step)

      if FLAGS.force_label is not None:
        if use_tpu:
          generated = tf.contrib.tpu.rewrite(create_generator(FLAGS.force_label))
        else:
          generated = create_generator(FLAGS.force_label)()

        tf.global_variables_initializer().run()
        
        _update_bn_accumulators(sess, generated, num_accu_examples=FLAGS.num_accu_examples)

        logging.info("Generating fake data set with forced label")
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))

        save_examples_lib.SaveExamplesTask().run_after_session(fake_dset, None, step, force_label=FLAGS.force_label)
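The create_generator closure above either samples a random class label per image or forces one constant label for the whole batch. That branch in isolation, written against tf.compat.v1; the batch size and class count are placeholders.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

batch_size, num_classes, force_label = 4, 10, None
if force_label is None:
  labels = tf.random.uniform([batch_size], maxval=num_classes, dtype=tf.int32)
else:
  labels = tf.constant(force_label, shape=[batch_size], dtype=tf.int32)
with tf.Session() as sess:
  print(sess.run(labels))  # e.g. [7 2 9 4]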
Example #19
def run_with_schedule(schedule, run_config, task_manager, options, use_tpu,
                      num_eval_averaging_runs=1, eval_every_steps=-1):
  """Run the schedule with the given options.

  Available schedules:
  - train: Train up to options["training_steps"], continuing from existing
      checkpoints if available.
  - eval_after_train: First train up to options["training_steps"] then
      evaluate all checkpoints.
  - continuous_eval: Waiting for new checkpoints and evaluate them as they
      become available. This is meant to run in parallel with a job running
      the training schedule but can also run after it.

  Args:
    schedule: Schedule to run. One of: train, eval_after_train, continuous_eval.
    run_config: `tf.contrib.tpu.RunConfig` to use.
    task_manager: `TaskManager` for this run.
    options: Python dictionary with run parameters.
    use_tpu: Boolean whether to use TPU.
    num_eval_averaging_runs: Determines how many times each metric is computed.
    eval_every_steps: Integer determining which checkpoints to evaluate.
  """
  logging.info("Running schedule '%s' with options: %s", schedule, options)
  if run_config.tf_random_seed:
    logging.info("Setting NumPy random seed to %s.", run_config.tf_random_seed)
    np.random.seed(run_config.tf_random_seed)

  result_dir = os.path.join(run_config.model_dir, "result")
  utils.check_folder(result_dir)

  dataset = datasets.get_dataset()
  gan = options["gan_class"](dataset=dataset,
                             parameters=options,
                             model_dir=run_config.model_dir)

  if schedule not in {"train", "eval_after_train", "continuous_eval"}:
    raise ValueError("Schedule {} not supported.".format(schedule))
  if schedule in {"train", "eval_after_train"}:
    train_hooks = [
        gin.tf.GinConfigSaverHook(run_config.model_dir),
        hooks.ReportProgressHook(task_manager,
                                 max_steps=options["training_steps"]),
    ]
    if run_config.save_checkpoints_steps:
      # This replaces the default checkpoint saver hook in the estimator.
      logging.info("Using AsyncCheckpointSaverHook.")
      train_hooks.append(
          hooks.AsyncCheckpointSaverHook(
              checkpoint_dir=run_config.model_dir,
              save_steps=run_config.save_checkpoints_steps))
      # (b/122782388): Remove hotfix.
      run_config = run_config.replace(save_checkpoints_steps=1000000)
    estimator = gan.as_estimator(
        run_config, batch_size=options["batch_size"], use_tpu=use_tpu)
    estimator.train(
        input_fn=gan.input_fn,
        max_steps=options["training_steps"],
        hooks=train_hooks)
    task_manager.mark_training_done()

  if schedule == "continuous_eval":
    # Continuous eval with up to 24 hours between checkpoints.
    checkpoints = task_manager.unevaluated_checkpoints(
        timeout=24 * 3600, eval_every_steps=eval_every_steps)
  if schedule == "eval_after_train":
    checkpoints = task_manager.unevaluated_checkpoints(
        eval_every_steps=eval_every_steps)
  if schedule in {"continuous_eval", "eval_after_train"}:
    _run_eval(
        gan.as_module_spec(),
        checkpoints=checkpoints,
        task_manager=task_manager,
        run_config=run_config,
        use_tpu=use_tpu,
        num_averaging_runs=num_eval_averaging_runs)
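run_with_schedule dispatches on the schedule name: train and eval_after_train trigger training, while eval_after_train and continuous_eval trigger evaluation. The dispatch in miniature.

schedule = "eval_after_train"
assert schedule in {"train", "eval_after_train", "continuous_eval"}
do_train = schedule in {"train", "eval_after_train"}
do_eval = schedule in {"continuous_eval", "eval_after_train"}
print(do_train, do_eval)  # -> True True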
Example #20
def train_gilbo(gan, sess, outdir, checkpoint_path, dataset, options):
    """Build and train GILBO model.

  Args:
    gan: GAN object.
    sess: tf.Session.
    outdir: Output directory. A pickle file will be written there.
    checkpoint_path: Path where gan's checkpoints are written. Only used to
                     ensure that GILBO files are written to a unique
                     subdirectory of outdir.
    dataset: Name of dataset used to train the GAN.
    options: Options dictionary.

  Returns:
    mean_eval_info: Mean GILBO computed over a large number of images generated
                    by the trained GAN
    mean_train_consistency: Mean consistency of the trained GILBO model with
                            data from the training set.
    mean_eval_consistency: Same consistency measure for the trained model with
                           data from the validation set.
    mean_self_consistency: Same consistency measure for the trained model with
                           data generated by the trained model itself.
    See the GILBO paper for an explanation of these metrics.

  Raises:
    ValueError: If the GAN has uninitialized variables.
  """
    uninitialized = sess.run(tf.report_uninitialized_variables())
    if uninitialized:
        raise ValueError("Model has uninitialized variables!\n%r" %
                         uninitialized)

    outdir = os.path.join(outdir, checkpoint_path.replace("/", "_"))

    tf.gfile.MakeDirs(outdir)
    with tf.variable_scope("gilbo"):
        ones = tf.ones((gan.batch_size, gan.z_dim))
        # Get a distribution for the prior.
        z_dist = ds.Independent(ds.Uniform(-ones, ones), 1)
        z_sample = z_dist.sample()
        epsneg = np.finfo("float32").epsneg
        # Clip samples from the GAN uniform prior because the Beta distribution
        # doesn"t include the top endpoint and has issues with the bottom endpoint.
        ganz_clip = tf.clip_by_value(gan.z, -(1 - epsneg), 1 - epsneg)

        # Get generated images from the model.
        fake_images = gan.fake_images

        # Build the regressor distribution that encodes images back to predicted
        # samples from the prior.
        with tf.variable_scope("regressor"):
            z_pred_dist = _build_regressor(fake_images, gan.z_dim)
        # Capture the parameters of the distributions for later analysis.
        dist_p1 = z_pred_dist.distribution.distribution.concentration0
        dist_p2 = z_pred_dist.distribution.distribution.concentration1

        # info and avg_info compute the GILBO.
        info = z_pred_dist.log_prob(ganz_clip) - z_dist.log_prob(ganz_clip)
        avg_info = tf.reduce_mean(info)

        # Set up training of the GILBO model.
        lr = options.get("gilbo_learning_rate", 4e-4)
        learning_rate = tf.get_variable("learning_rate",
                                        initializer=lr,
                                        trainable=False)
        gilbo_step = tf.get_variable("gilbo_step",
                                     dtype=tf.int32,
                                     initializer=0,
                                     trainable=False)
        opt = tf.train.AdamOptimizer(learning_rate)

        regressor_vars = tf.contrib.framework.get_variables("gilbo/regressor")
        train_op = opt.minimize(-info, var_list=regressor_vars)

    # Initialize the variables we just created.
    uninitialized = plist(tf.report_uninitialized_variables().eval())
    uninitialized_vars = uninitialized.apply(
        tf.contrib.framework.get_variables_by_name)._[0]
    tf.variables_initializer(uninitialized_vars).run()

    saver = tf.train.Saver(uninitialized_vars, max_to_keep=1)
    try:
        checkpoint_path = tf.train.latest_checkpoint(outdir)
        saver.restore(sess, checkpoint_path)
    except ValueError:
        # Failing to restore just indicates that we don't have a valid checkpoint,
        # so we will just start training a fresh GILBO model.
        pass
    _train_gilbo(sess, gan, saver, learning_rate, gilbo_step, z_sample,
                 avg_info, z_pred_dist, train_op, outdir, options)

    mean_eval_info = _eval_gilbo(sess, gan, z_sample, avg_info, dist_p1,
                                 dist_p2, fake_images, outdir, options)
    # Collect encoded distributions on the training and eval set in order to do
    # kl-nearest-neighbors on generated samples and measure consistency.
    dataset = datasets.get_dataset(dataset)
    x_train = dataset.load_dataset(split_name="train", num_threads=1)
    x_train = x_train.batch(gan.batch_size, drop_remainder=True)
    x_train = x_train.make_one_shot_iterator().get_next()[0]
    x_train = tf.reshape(x_train, fake_images.shape)

    x_eval = dataset.load_dataset(split_name="test", num_threads=1)
    x_eval = x_eval.batch(gan.batch_size, drop_remainder=True)
    x_eval = x_eval.make_one_shot_iterator().get_next()[0]
    x_eval = tf.reshape(x_eval, fake_images.shape)

    mean_train_consistency = _run_gilbo_consistency(x_train,
                                                    "train",
                                                    extract_input_images=0,
                                                    save_consistency_images=20,
                                                    num_batches=5,
                                                    **locals())
    mean_eval_consistency = _run_gilbo_consistency(x_eval,
                                                   "eval",
                                                   extract_input_images=0,
                                                   save_consistency_images=20,
                                                   num_batches=5,
                                                   **locals())
    mean_self_consistency = _run_gilbo_consistency(fake_images,
                                                   "self",
                                                   extract_input_images=20,
                                                   save_consistency_images=20,
                                                   num_batches=5,
                                                   **locals())
    return (mean_eval_info, mean_train_consistency, mean_eval_consistency,
            mean_self_consistency)
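The prior clipping in train_gilbo pulls the uniform samples just inside (-1, 1) so the Beta-based regressor can evaluate log-probabilities at the endpoints. The numerical detail in isolation.

import numpy as np

epsneg = np.finfo("float32").epsneg  # 1 - epsneg is the largest float32 below 1.0
z = np.array([-1.0, 0.3, 1.0], dtype=np.float32)
print(np.clip(z, -(1 - epsneg), 1 - epsneg))  # endpoints pulled inside (-1, 1)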
def compute_accuracy_loss(sess,
                          gan,
                          test_images,
                          max_train_examples=50000,
                          num_repeat=5):
  """Compute discriminator's accuracy and loss on a given dataset.

  Args:
    sess: Tf.Session object.
    gan: Any AbstractGAN instance.
    test_images: numpy array with test images.
    max_train_examples: How many "train" examples to get from the dataset.
                        In each round, some of them will be randomly selected
                        to evaluate train set accuracy.
    num_repeat: How many times to repeat the computation.
                The mean of all the results is reported.
  Returns:
    Dict[Text, float] with all the computed scores.

  Raises:
    ValueError: If the number of test_images is greater than the number of
                training images returned by the dataset.
  """
  logging.info("Evaluating training and test accuracy...")
  train_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=max_train_examples,
      split="train",
      failure_on_insufficient_examples=False)
  if train_images.shape[0] < test_images.shape[0]:
    raise ValueError("num_train %d must be larger than num_test %d." %
                     (train_images.shape[0], test_images.shape[0]))

  num_batches = int(np.floor(test_images.shape[0] / gan.batch_size))
  if num_batches * gan.batch_size < test_images.shape[0]:
    logging.error("Ignoring the last batch with %d samples / %d epoch size.",
                  test_images.shape[0] - num_batches * gan.batch_size,
                  gan.batch_size)

  ret = {
      "train_accuracy": [],
      "test_accuracy": [],
      "fake_accuracy": [],
      "train_d_loss": [],
      "test_d_loss": []
  }

  for _ in range(num_repeat):
    idx = np.random.choice(train_images.shape[0], test_images.shape[0])
    bs = gan.batch_size
    train_subset = [train_images[i] for i in idx]
    train_predictions, test_predictions, fake_predictions = [], [], []
    train_d_losses, test_d_losses = [], []

    for i in range(num_batches):
      z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
      start_idx = i * bs
      end_idx = start_idx + bs
      test_batch = test_images[start_idx : end_idx]
      train_batch = train_subset[start_idx : end_idx]

      test_prediction, test_d_loss, fake_images = sess.run(
          [gan.discriminator_output, gan.d_loss, gan.fake_images],
          feed_dict={
              gan.inputs: test_batch, gan.z: z_sample
          })
      train_prediction, train_d_loss = sess.run(
          [gan.discriminator_output, gan.d_loss],
          feed_dict={
              gan.inputs: train_batch,
              gan.z: z_sample
          })
      fake_prediction = sess.run(
          gan.discriminator_output,
          feed_dict={gan.inputs: fake_images})[0]

      train_predictions.append(train_prediction[0])
      test_predictions.append(test_prediction[0])
      fake_predictions.append(fake_prediction)
      train_d_losses.append(train_d_loss)
      test_d_losses.append(test_d_loss)

    train_predictions = [x >= 0.5 for x in train_predictions]
    test_predictions = [x >= 0.5 for x in test_predictions]
    fake_predictions = [x < 0.5 for x in fake_predictions]

    ret["train_accuracy"].append(np.array(train_predictions).mean())
    ret["test_accuracy"].append(np.array(test_predictions).mean())
    ret["fake_accuracy"].append(np.array(fake_predictions).mean())
    ret["train_d_loss"].append(np.mean(train_d_losses))
    ret["test_d_loss"].append(np.mean(test_d_losses))

  for key in ret:
    ret[key] = np.mean(ret[key])

  return ret
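The accuracy bookkeeping above thresholds the discriminator output at 0.5: real images count as correctly classified when the prediction is at least 0.5, fake images when it is below. Illustrated with made-up outputs.

import numpy as np

real_preds = np.array([0.8, 0.55, 0.4])  # discriminator outputs on real images
fake_preds = np.array([0.2, 0.6])        # discriminator outputs on fake images
print((real_preds >= 0.5).mean())  # train/test accuracy: 0.666...
print((fake_preds < 0.5).mean())   # fake accuracy: 0.5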