def _test_complete_flow(self, train_input_fn, eval_input_fn,
                        predict_input_fn, prediction_size, lr_decay=False):
    """Run a GANEstimator through train, evaluate and predict, checking each.

    The estimator is built from the module's `generator_fn` and
    `discriminator_fn` with Wasserstein losses and writes to
    `self._model_dir`.

    Args:
      train_input_fn: Input function handed to `est.train`.
      eval_input_fn: Input function handed to `est.evaluate`.
      predict_input_fn: Input function handed to `est.predict`.
      prediction_size: Expected shape of the array built from all
        predictions.
      lr_decay: If true, both optimizers are supplied as a callable that
        builds a gradient-descent optimizer with an exponentially decayed
        learning rate tied to the global step. Otherwise two separate
        gradient-descent optimizer instances with learning rate 1.0 are
        used.
    """

    def decayed_sgd():
        # Invoked lazily by the estimator, so the global step is looked up
        # (or created) at call time rather than here.
        global_step = training_util.get_or_create_global_step()
        decayed_rate = learning_rate_decay.exponential_decay(
            1.0, global_step, 10, 0.9)
        return training.GradientDescentOptimizer(decayed_rate)

    if lr_decay:
        generator_optimizer = decayed_sgd
        discriminator_optimizer = decayed_sgd
    else:
        # Two distinct optimizer objects, one per sub-network.
        generator_optimizer = training.GradientDescentOptimizer(1.0)
        discriminator_optimizer = training.GradientDescentOptimizer(1.0)

    gan_estimator = estimator.GANEstimator(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        generator_loss_fn=losses.wasserstein_generator_loss,
        discriminator_loss_fn=losses.wasserstein_discriminator_loss,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        model_dir=self._model_dir)

    # TRAIN
    train_steps = 10
    gan_estimator.train(train_input_fn, steps=train_steps)

    # EVALUATE
    eval_metrics = gan_estimator.evaluate(eval_input_fn)
    # The reported global step must equal the number of training steps run.
    self.assertEqual(train_steps, eval_metrics[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(eval_metrics))

    # PREDICT
    collected = list(gan_estimator.predict(predict_input_fn))
    predictions = np.array(collected)
    self.assertAllEqual(prediction_size, predictions.shape)
def _test_warm_start(self, warm_start_from=None):
    """Tests whether WarmStartSettings work as intended.

    A baseline GANEstimator is trained for one step into `self._model_dir`.
    A second estimator is then built whose generator also creates a
    trainable variable named `self.new_variable_name`, initialized to
    `self.new_variable_value`. It is trained for one step and returned.

    Args:
      warm_start_from: Forwarded as `warm_start_from` to the second
        estimator. When truthy, the second estimator gets `model_dir=None`;
        otherwise it reuses `self._model_dir`.

    Returns:
      The second (warm-started) estimator after one training step.
    """

    def generator_plus_extra_variable(noise_dict, mode):
        # Create the extra variable before delegating to the stock
        # generator; the variable's value is not used in the output.
        variable_scope.get_variable(
            name=self.new_variable_name,
            initializer=self.new_variable_value,
            trainable=True)
        return generator_fn(noise_dict, mode)

    def zeros_input_fn():
        zeros = np.zeros([3, 4])
        return {'x': zeros}, zeros

    def build_estimator(gen_fn, **extra_kwargs):
        # Fresh optimizer instances are created for every estimator.
        return estimator.GANEstimator(
            generator_fn=gen_fn,
            discriminator_fn=discriminator_fn,
            generator_loss_fn=losses.wasserstein_generator_loss,
            discriminator_loss_fn=losses.wasserstein_discriminator_loss,
            generator_optimizer=training.GradientDescentOptimizer(1.0),
            discriminator_optimizer=training.GradientDescentOptimizer(1.0),
            **extra_kwargs)

    # Baseline run: leaves a checkpoint in self._model_dir.
    baseline = build_estimator(generator_fn, model_dir=self._model_dir)
    baseline.train(zeros_input_fn, steps=1)

    if warm_start_from:
        warm_model_dir = None
    else:
        warm_model_dir = self._model_dir

    warm_estimator = build_estimator(
        generator_plus_extra_variable,
        model_dir=warm_model_dir,
        warm_start_from=warm_start_from)
    warm_estimator.train(zeros_input_fn, steps=1)
    return warm_estimator