コード例 #1
0
    def _test_complete_flow(self,
                            train_input_fn,
                            eval_input_fn,
                            predict_input_fn,
                            prediction_size,
                            lr_decay=False):
        """Train, evaluate and predict with a Wasserstein-loss GANEstimator.

        Trains for 10 steps, then checks that evaluation reports global step
        10 and a 'loss' metric. Finally checks that the stacked predictions
        have the expected shape.

        Args:
          train_input_fn: Input function passed to `est.train`.
          eval_input_fn: Input function passed to `est.evaluate`.
          predict_input_fn: Input function passed to `est.predict`.
          prediction_size: Expected shape of the array built from all
            predictions yielded by `est.predict`.
          lr_decay: If True, both optimizers are given as the callable
            `make_opt`. It builds a GradientDescentOptimizer whose learning
            rate is `exponential_decay(1.0, global_step, 10, 0.9)`. Otherwise
            each optimizer is a GradientDescentOptimizer with a fixed rate of
            1.0.
        """
        def make_opt():
            # Uses the global step, so it must run where that step can be
            # created/fetched. Passing a callable presumably lets the estimator
            # build it inside its own graph -- TODO confirm with GANEstimator.
            gstep = training_util.get_or_create_global_step()
            lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
            return training.GradientDescentOptimizer(lr)

        gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
        dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
        est = estimator.GANEstimator(
            generator_fn=generator_fn,
            discriminator_fn=discriminator_fn,
            generator_loss_fn=losses.wasserstein_generator_loss,
            discriminator_loss_fn=losses.wasserstein_discriminator_loss,
            generator_optimizer=gopt,
            discriminator_optimizer=dopt,
            model_dir=self._model_dir)

        # TRAIN
        num_steps = 10
        est.train(train_input_fn, steps=num_steps)

        # EVALUATE
        scores = est.evaluate(eval_input_fn)
        self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
        # `in` on a dict already tests its keys; passing the dict (rather than
        # a key iterator) also gives a readable message on failure.
        self.assertIn('loss', scores)

        # PREDICT
        predictions = np.array(list(est.predict(predict_input_fn)))

        self.assertAllEqual(prediction_size, predictions.shape)
コード例 #2
0
    def _test_warm_start(self, warm_start_from=None):
        """Train one GANEstimator for a step, then train a second one after it.

        The second estimator's generator additionally creates a trainable
        variable named `self.new_variable_name`, initialized to
        `self.new_variable_value`. That variable does not exist in the first
        estimator's graph (presumably so callers can inspect how warm starting
        treats it -- confirm against the calling tests).

        Args:
          warm_start_from: Forwarded as `warm_start_from` to the second
            GANEstimator. When truthy, the second estimator gets
            `model_dir=None`; otherwise it reuses `self._model_dir`.

        Returns:
          The second GANEstimator, after one training step.
        """
        def input_fn():
            zeros = np.zeros([3, 4])
            return {'x': zeros}, zeros

        def generator_plus_extra_variable(noise_dict, mode):
            # Extra variable absent from the first estimator's model.
            variable_scope.get_variable(
                name=self.new_variable_name,
                initializer=self.new_variable_value,
                trainable=True)
            return generator_fn(noise_dict, mode)

        def build(gen_fn, **extra_kwargs):
            # Each call constructs its own optimizer instances.
            return estimator.GANEstimator(
                generator_fn=gen_fn,
                discriminator_fn=discriminator_fn,
                generator_loss_fn=losses.wasserstein_generator_loss,
                discriminator_loss_fn=losses.wasserstein_discriminator_loss,
                generator_optimizer=training.GradientDescentOptimizer(1.0),
                discriminator_optimizer=training.GradientDescentOptimizer(1.0),
                **extra_kwargs)

        base = build(generator_fn, model_dir=self._model_dir)
        base.train(input_fn, steps=1)

        target_dir = None if warm_start_from else self._model_dir
        warmed = build(generator_plus_extra_variable,
                       model_dir=target_dir,
                       warm_start_from=warm_start_from)
        warmed.train(input_fn, steps=1)
        return warmed