Ejemplo n.º 1
0
    def testL2Regularization(self, architecture):
        """Smoke-tests building a GAN_PENALTY model with the L2 penalty.

        Starts from the test's base parameters plus the "celeba" dataset
        parameters, selects the given architecture with the L2 penalty and no
        discriminator normalization, and feeds a fake dataset. The test passes
        if `build_model()` completes without raising.

        Args:
          architecture: value for the "architecture" parameter under test.
        """
        parameters = self.parameters.copy()
        parameters.update(params.GetDatasetParameters("celeba"))
        parameters.update({
            "architecture": architecture,
            "penalty_type": consts.L2_PENALTY,
            "discriminator_normalization": consts.NO_NORMALIZATION,
        })
        fake_data = test_utils.load_fake_dataset(parameters).repeat()

        session_config = tf.ConfigProto(allow_soft_placement=True)
        # Clear the global default graph before building this model into it.
        tf.reset_default_graph()
        with tf.Session(config=session_config):
            gan = GAN_PENALTY(runtime_info=FakeRuntimeInfo(),
                              dataset_content=fake_data,
                              parameters=parameters)
            gan.build_model()
Ejemplo n.º 2
0
    def testDiscItersIsUsedCorrectly(self, disc_iters):
        """Checks that the generator is only updated every `disc_iters` steps.

        Trains GAN_PENALTY in a fresh graph/session, saving a checkpoint after
        every step. Consecutive checkpoints are then compared:
          * at every step some discriminator variable must have changed;
          * the per-step list of "a generator variable changed" flags
            (ignoring batch-norm moving statistics) must equal
            GENERATOR_TRAINED_IN_STEPS[disc_iters - 1].

        NOTE(review): another method named `testDiscItersIsUsedCorrectly`
        follows this one in the same listing; if both are defined in one
        class, the later definition replaces this one.

        Args:
          disc_iters: value for the "disc_iters" parameter under test.
        """
        parameters = self.parameters.copy()
        parameters.update(params.GetDatasetParameters("cifar10"))
        overrides = {
            "batch_size": 2,
            "training_steps": 10,
            "save_checkpoint_steps": 1,
            "disc_iters": disc_iters,
            "architecture": consts.RESNET_CIFAR,
            "penalty_type": consts.NO_PENALTY,
            "discriminator_normalization": consts.NO_NORMALIZATION,
        }
        parameters.update(overrides)
        fake_data = test_utils.load_fake_dataset(parameters).repeat()

        workdir = os.path.join(FLAGS.test_tmpdir, str(disc_iters))
        with tf.Graph().as_default(), tf.Session() as sess:
            gan = GAN_PENALTY(
                runtime_info=FakeRuntimeInfo(workdir),
                dataset_content=fake_data,
                parameters=parameters)
            gan.build_model()
            gan.train(sess)

        # Walk the per-step checkpoints; a generator weight that differs from
        # the previous step's checkpoint means the generator was trained then.
        batch_norm_stats = ("moving_mean", "moving_variance")
        tolerance = 1e-10
        last_seen = {}
        generator_trained = []
        for step in range(parameters["training_steps"] + 1):
            prefix = os.path.join(workdir, "GAN_PENALTY.model-{}".format(step))
            self.assertTrue(tf.gfile.Exists(prefix + ".index"))
            reader = tf.train.load_checkpoint(prefix)
            var_names = reader.get_variable_to_shape_map()

            if not step:
                # First checkpoint only provides the baseline values.
                for var_name in var_names:
                    last_seen[var_name] = reader.get_tensor(var_name)
                continue

            disc_changed = gen_changed = False
            for var_name in var_names:
                value = reader.get_tensor(var_name)
                max_delta = np.abs(last_seen[var_name] - value).max()
                last_seen[var_name] = value
                # Written as `not >` (rather than `<=`) so a NaN delta is
                # treated as "unchanged", exactly like a failed `>` test.
                if not max_delta > tolerance:
                    continue
                if "discriminator" in var_name:
                    disc_changed = True
                elif ("generator" in var_name
                      and not var_name.endswith(batch_norm_stats)):
                    # Batch-norm moving statistics are excluded: they are
                    # updated even on steps where the generator is not trained.
                    tf.logging.info(
                        "step %d: %s changed up to %f", step, var_name,
                        max_delta)
                    gen_changed = True
            self.assertTrue(disc_changed)  # Discriminator trains every step.
            generator_trained.append(gen_changed)

        self.assertEqual(generator_trained,
                         GENERATOR_TRAINED_IN_STEPS[disc_iters - 1])
    def testDiscItersIsUsedCorrectly(self, disc_iters, use_tpu):
        """Checks that the generator is only updated every `disc_iters` steps.

        Trains GAN_PENALTY via `train_with_estimator`, saving a checkpoint
        after every step. Consecutive checkpoints are then compared:
          * at every step some discriminator variable must have changed;
          * the per-step list of "a generator variable changed" flags
            (ignoring batch-norm moving statistics) must equal
            GENERATOR_TRAINED_IN_STEPS[disc_iters - 1].

        Args:
          disc_iters: value for the "disc_iters" parameter under test.
          use_tpu: value for the "use_tpu" parameter under test.
        """
        if disc_iters > 1 and use_tpu:
            # This combination is not exercised by this test. Skip explicitly
            # rather than returning early, which would be reported as a pass.
            self.skipTest("disc_iters > 1 is not tested with use_tpu=True.")
        parameters = self.parameters.copy()
        parameters.update(params.GetDatasetParameters("cifar10"))
        parameters.update({
            "use_tpu": use_tpu,
            "discriminator_normalization": consts.NO_NORMALIZATION,
            "architecture": consts.RESNET_CIFAR,
            "penalty_type": consts.NO_PENALTY,
            "disc_iters": disc_iters,
            "training_steps": 5,
        })
        dataset_content = test_utils.load_fake_dataset(parameters).repeat()

        model_dir = os.path.join(FLAGS.test_tmpdir, str(disc_iters))

        config = tf.contrib.tpu.RunConfig(
            model_dir=model_dir,
            # Save after every step and keep far more checkpoints than steps,
            # so each step can be compared with its predecessor below.
            save_checkpoints_steps=1,
            keep_checkpoint_max=1000,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
        gan = GAN_PENALTY(runtime_info=FakeRuntimeInfo(model_dir),
                          dataset_content=dataset_content,
                          parameters=parameters)
        gan.train_with_estimator(config)

        # Read checkpoints for each training step. If the weight in the generator
        # changed we trained the generator during that step.
        batch_norm_stats = ("moving_mean", "moving_variance")
        previous_values = {}
        generator_trained = []
        for step in range(parameters["training_steps"] + 1):
            basename = os.path.join(model_dir, "model.ckpt-{}".format(step))
            self.assertTrue(tf.gfile.Exists(basename + ".index"),
                            msg="Missing checkpoint: %s" % basename)
            ckpt = tf.train.load_checkpoint(basename)

            if step == 0:
                # Baseline checkpoint: nothing to compare against yet.
                for name in ckpt.get_variable_to_shape_map():
                    previous_values[name] = ckpt.get_tensor(name)
                continue

            d_trained = False
            g_trained = False
            for name in ckpt.get_variable_to_shape_map():
                t = ckpt.get_tensor(name)
                diff = np.abs(previous_values[name] - t).max()
                previous_values[name] = t
                if "discriminator" in name and diff > 1e-10:
                    d_trained = True
                elif "generator" in name and diff > 1e-10:
                    if name.endswith(batch_norm_stats):
                        # Note: Even when we don't train the generator the batch
                        # norm values still get updated.
                        continue
                    tf.logging.info("step %d: %s changed up to %f", step, name,
                                    diff)
                    g_trained = True
            # Discriminator is trained every step.
            self.assertTrue(
                d_trained,
                msg="No discriminator variable changed at step %d" % step)
            generator_trained.append(g_trained)

        self.assertEqual(generator_trained,
                         GENERATOR_TRAINED_IN_STEPS[disc_iters - 1])