コード例 #1
0
    def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
        """Checks G/D call shapes when G samples jointly for all D steps.

        Builds a ModularGAN with `experimental_joint_gen_for_disc=True`,
        trains it for one TPU step with a batch size of 16, and inspects the
        arguments recorded in `call_arg_list` by the generator and the
        discriminator.

        Args:
          disc_iters: Discriminator iterations per generator step.
        """
        z_dim = 128
        batch_size = 16
        parameters = {
            "architecture": c.DUMMY_ARCH,
            "lambda": 1,
            "z_dim": z_dim,
            "disc_iters": disc_iters,
        }
        gan = ModularGAN(
            dataset=datasets.get_dataset("cifar10"),
            parameters=parameters,
            experimental_joint_gen_for_disc=True,
            model_dir=self.model_dir)
        estimator = gan.as_estimator(
            self.run_config, batch_size=batch_size, use_tpu=True)
        estimator.train(gan.input_fn, steps=1)

        gen_calls = gan.generator.call_arg_list
        disc_calls = gan.discriminator.call_arg_list
        # Two generator calls (the first sized for all disc_iters D steps)
        # and disc_iters + 1 discriminator calls are expected.
        self.assertLen(gen_calls, 2)
        self.assertLen(disc_calls, disc_iters + 1)

        # NOTE(review): the expected per-call z batch (8) is half of
        # batch_size; presumably this reflects the shard count configured in
        # self.run_config — confirm.
        joint_z_shape = gen_calls[0]["z"].shape.as_list()
        self.assertAllEqual(joint_z_shape, [8 * disc_iters, z_dim])
        self.assertAllEqual(gen_calls[1]["z"].shape.as_list(), [8, z_dim])
        expected_x_shape = [batch_size, 32, 32, 3]
        for call in disc_calls:
            self.assertAllEqual(call["x"].shape.as_list(), expected_x_shape)
コード例 #2
0
    def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
        """Verifies the D and G step counters across three training steps.

        A checkpoint is written after every step (save_checkpoints_steps=1).
        For checkpoint `model.ckpt-<step>`, the variable "global_step_disc"
        must equal step * disc_iters and "global_step" must equal step.

        Args:
          disc_iters: Discriminator iterations per generator step.
          use_tpu: Forwarded to `gan.as_estimator`.
        """
        num_steps = 3
        parameters = {
            "architecture": c.DUMMY_ARCH,
            "disc_iters": disc_iters,
            "lambda": 1,
            "z_dim": 128,
        }
        tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
        run_config = tf.contrib.tpu.RunConfig(
            model_dir=self.model_dir,
            save_checkpoints_steps=1,
            tpu_config=tpu_config)
        dataset = datasets.get_dataset("cifar10")
        gan = ModularGAN(dataset=dataset, parameters=parameters,
                         model_dir=self.model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=use_tpu)
        estimator.train(gan.input_fn, steps=num_steps)

        def read_counters(step):
            # Returns (disc_counter, gen_counter) stored in the given step's
            # checkpoint, after asserting that the checkpoint exists.
            prefix = os.path.join(self.model_dir,
                                  "model.ckpt-{}".format(step))
            self.assertTrue(tf.gfile.Exists(prefix + ".index"))
            reader = tf.train.load_checkpoint(prefix)
            return (reader.get_tensor("global_step_disc"),
                    reader.get_tensor("global_step"))

        counters = [read_counters(step) for step in range(num_steps + 1)]
        disc_step_values = [disc for disc, _ in counters]
        gen_step_values = [gen for _, gen in counters]

        self.assertAllEqual(disc_step_values,
                            np.arange(num_steps + 1) * disc_iters)
        self.assertAllEqual(gen_step_values, list(range(num_steps + 1)))
コード例 #3
0
    def testBatchSizeSplitDiscCalls(self, disc_iters):
        """Checks G/D call counts and shapes with split discriminator calls.

        Builds a ModularGAN with `deprecated_split_disc_calls=True`, trains it
        for one TPU step with a batch size of 16, and inspects the
        arguments recorded in `call_arg_list`.

        Args:
          disc_iters: Discriminator iterations per generator step.
        """
        batch_size = 16
        parameters = {
            "architecture": c.DUMMY_ARCH,
            "lambda": 1,
            "z_dim": 128,
            "disc_iters": disc_iters,
        }
        gan = ModularGAN(
            dataset=datasets.get_dataset("cifar10"),
            parameters=parameters,
            deprecated_split_disc_calls=True,
            model_dir=self.model_dir)
        estimator = gan.as_estimator(
            self.run_config, batch_size=batch_size, use_tpu=True)
        estimator.train(gan.input_fn, steps=1)

        # disc_iters D steps followed by one G step.
        num_steps = disc_iters + 1
        gen_calls = gan.generator.call_arg_list
        disc_calls = gan.discriminator.call_arg_list
        self.assertLen(gen_calls, num_steps)
        # Every D step and the G step query the discriminator twice: once for
        # real images and once for fake images.
        self.assertLen(disc_calls, 2 * num_steps)

        for call in gen_calls:
            self.assertAllEqual(call["z"].shape.as_list(), [8, 128])
        for call in disc_calls:
            self.assertAllEqual(call["x"].shape.as_list(), [8, 32, 32, 3])
コード例 #4
0
 def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
     """Trains one step with generator EMA enabled and checks EMA variables.

     Binds `ModularGAN.g_use_ema` to True and trains the dummy architecture for
     a single CPU step. It then asserts that the latest checkpoint's
     ExponentialMovingAverage variables are exactly those for the generator's
     fc_noise kernel and bias.

     Args:
       disc_iters: Discriminator iterations per generator step, forwarded to
         ModularGAN via `parameters["disc_iters"]`.
     """
     parameters = {
         "architecture": c.DUMMY_ARCH,
         "lambda": 1,
         "z_dim": 128,
         # This key was previously misspelled "dics_iters", so the test's
         # disc_iters argument never reached the model.
         "disc_iters": disc_iters,
     }
     # Unlock the config, as the other gin bindings in these tests do, in case
     # it is locked.
     with gin.unlock_config():
         gin.bind_parameter("ModularGAN.g_use_ema", True)
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(dataset=dataset,
                      parameters=parameters,
                      model_dir=self.model_dir)
     estimator = gan.as_estimator(self.run_config,
                                  batch_size=2,
                                  use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
     # Check for moving average variables in checkpoint.
     checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
     ema_vars = sorted(
         v[0] for v in tf.train.list_variables(checkpoint_path)
         if v[0].endswith("ExponentialMovingAverage"))
     tf.logging.info("ema_vars=%s", ema_vars)
     expected_ema_vars = sorted([
         "generator/fc_noise/kernel/ExponentialMovingAverage",
         "generator/fc_noise/bias/ExponentialMovingAverage",
     ])
     self.assertAllEqual(ema_vars, expected_ema_vars)
コード例 #5
0
 def testSingleTrainingStepWithJointGenForDisc(self):
     """Smoke test: one CPU training step with joint G-for-D sampling.

     Uses the c.RESNET5_BIGGAN_ARCH architecture as a conditional ModularGAN
     with `experimental_joint_gen_for_disc=True` and disc_iters=2. It passes
     if training completes without raising.
     """
     gan = ModularGAN(
         dataset=datasets.get_dataset("cifar10"),
         parameters={
             "architecture": c.RESNET5_BIGGAN_ARCH,
             "lambda": 1,
             "z_dim": 120,
             "disc_iters": 2,
         },
         model_dir=self.model_dir,
         experimental_joint_gen_for_disc=True,
         conditional=True)
     estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
コード例 #6
0
    def test_end2end_checkpoint(self, architecture):
        """Trains a GAN for one step, exports it and evaluates the export.

        The model is trained on CIFAR-10 for a single step and exported as a
        TF-Hub module from the resulting checkpoint. The export is evaluated
        with the FID, fractal dimension, inception score and MS-SSIM tasks.
        The test asserts that every metric reports "mean", "std" and "list"
        entries.

        Args:
          architecture: Architecture constant to train. Architectures whose
            name contains "biggan" are built as conditional models.
        """
        if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
            # RESNET_STL_ARCH and RESNET30_ARCH do not support the CIFAR image
            # shape. Skip explicitly so the case is reported as skipped rather
            # than silently passing.
            self.skipTest(
                "%s does not support the CIFAR image shape." % architecture)
        train_steps = 1
        gin.bind_parameter("dataset.name", "cifar10")
        dataset = datasets.get_dataset("cifar10")
        options = {
            "architecture": architecture,
            "z_dim": 120,
            "disc_iters": 1,
            "lambda": 1,
        }
        model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
        tf.logging.info("model_dir: %s", model_dir)
        run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
        gan = ModularGAN(dataset=dataset,
                         parameters=options,
                         conditional="biggan" in architecture,
                         model_dir=model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
        estimator.train(input_fn=gan.input_fn, steps=train_steps)
        export_path = os.path.join(model_dir, "tfhub")
        # Export the checkpoint written after the last training step; derive
        # its name from train_steps so the two stay in sync.
        checkpoint_path = os.path.join(model_dir,
                                       "model.ckpt-%d" % train_steps)
        module_spec = gan.as_module_spec()
        module_spec.export(export_path, checkpoint_path=checkpoint_path)

        eval_tasks = [
            fid_score.FIDScoreTask(),
            fractal_dimension.FractalDimensionTask(),
            inception_score.InceptionScoreTask(),
            ms_ssim_score.MultiscaleSSIMTask()
        ]
        result_dict = eval_gan_lib.evaluate_tfhub_module(export_path,
                                                         eval_tasks,
                                                         use_tpu=False,
                                                         num_averaging_runs=1)
        tf.logging.info("result_dict: %s", result_dict)
        for score in ("fid_score", "fractal_dimension", "inception_score",
                      "ms_ssim"):
            for stats in ("mean", "std", "list"):
                required_key = "%s_%s" % (score, stats)
                self.assertIn(required_key, result_dict,
                              "Missing: %s." % required_key)
コード例 #7
0
 def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
     """Trains a ModularGAN on CIFAR-10 for a single CPU step.

     Args:
       architecture: Architecture constant. Models whose name contains
         "biggan" are built as conditional.
       loss_fn: Value bound to the gin parameter `loss.fn`.
       penalty_fn: Value bound to the gin parameter `penalty.fn`.
     """
     with gin.unlock_config():
         gin.bind_parameter("penalty.fn", penalty_fn)
         gin.bind_parameter("loss.fn", loss_fn)
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(
         dataset=dataset,
         parameters={"architecture": architecture, "lambda": 1, "z_dim": 128},
         model_dir=self.model_dir,
         conditional="biggan" in architecture)
     estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
コード例 #8
0
 def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                            labeled_dataset):
     """Trains a conditional ModularGAN on CIFAR-10 for a single CPU step.

     Args:
       architecture: Architecture constant passed in `parameters`.
       loss_fn: Value bound to the gin parameter `loss.fn`.
       penalty_fn: Value bound to the gin parameter `penalty.fn`.
       labeled_dataset: Accepted but not read by this helper.
         NOTE(review): confirm whether it was meant to be forwarded to the
         model.
     """
     with gin.unlock_config():
         gin.bind_parameter("penalty.fn", penalty_fn)
         gin.bind_parameter("loss.fn", loss_fn)
     tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
     run_config = tf.contrib.tpu.RunConfig(model_dir=self.model_dir,
                                           tpu_config=tpu_config)
     gan = ModularGAN(
         dataset=datasets.get_dataset("cifar10"),
         parameters={"architecture": architecture, "lambda": 1, "z_dim": 120},
         conditional=True,
         model_dir=self.model_dir)
     estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
コード例 #9
0
    def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
        """Checks on which training steps the generator weights change.

        Trains for three steps with a checkpoint after every step and compares
        consecutive checkpoints. The discriminator must change on every step.
        Whether any non-batch-norm generator variable changed must match
        GENERATOR_TRAINED_IN_STEPS[disc_iters - 1].

        NOTE: this test is currently disabled. It used to exit with a bare
        `return`, which made every parameterization silently report success.
        It now calls `self.skipTest` so the runner reports it as skipped. The
        body is kept for when the test is re-enabled.

        Args:
          disc_iters: Discriminator iterations per generator step.
          use_tpu: Forwarded to `gan.as_estimator`.
        """
        # NOTE(review): the reason this test was disabled is not recorded in
        # this file.
        self.skipTest("testDiscItersIsUsedCorrectly is currently disabled.")
        if disc_iters > 1 and use_tpu:
            # This combination was excluded even before the test was
            # disabled; report it as skipped as well.
            self.skipTest(
                "disc_iters > 1 with use_tpu=True is not covered by this test.")
        parameters = {
            # NOTE(review): other tests use `*_ARCH` constants (for example
            # c.RESNET_STL_ARCH); confirm c.RESNET_CIFAR exists before
            # re-enabling this test.
            "architecture": c.RESNET_CIFAR,
            "disc_iters": disc_iters,
            "lambda": 1,
            "z_dim": 128,
        }
        if not use_tpu:
            parameters["unroll_disc_iters"] = False
        run_config = tf.contrib.tpu.RunConfig(
            model_dir=self.model_dir,
            save_checkpoints_steps=1,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
        dataset = datasets.get_dataset("cifar10")
        gan = ModularGAN(dataset=dataset,
                         parameters=parameters,
                         model_dir=self.model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=use_tpu)
        estimator.train(gan.input_fn, steps=3)

        # Read checkpoints for each training step. If the weight in the generator
        # changed we trained the generator during that step.
        previous_values = {}
        generator_trained = []
        for step in range(0, 4):
            basename = os.path.join(self.model_dir,
                                    "model.ckpt-{}".format(step))
            self.assertTrue(tf.gfile.Exists(basename + ".index"))
            ckpt = tf.train.load_checkpoint(basename)

            if step == 0:
                # Initial checkpoint: only record the baseline values.
                for name in ckpt.get_variable_to_shape_map():
                    previous_values[name] = ckpt.get_tensor(name)
                continue

            d_trained = False
            g_trained = False
            for name in ckpt.get_variable_to_shape_map():
                t = ckpt.get_tensor(name)
                diff = np.abs(previous_values[name] - t).max()
                previous_values[name] = t
                if "discriminator" in name and diff > 1e-10:
                    d_trained = True
                elif "generator" in name and diff > 1e-10:
                    if name.endswith("moving_mean") or name.endswith(
                            "moving_variance"):
                        # Note: Even when we don't train the generator the batch norm
                        # values still get updated.
                        continue
                    tf.logging.info("step %d: %s changed up to %f", step, name,
                                    diff)
                    g_trained = True
            self.assertTrue(d_trained)  # Discriminator is trained every step.
            generator_trained.append(g_trained)

        self.assertEqual(generator_trained,
                         GENERATOR_TRAINED_IN_STEPS[disc_iters - 1])