def test_output_string(self):
    """Check output_string / output_variables on a trainer whose mock GAN was not created.

    NOTE(review): another method named ``test_output_string`` is defined
    after this one in the same scope. Python keeps only the last definition
    in a class body, so this method is shadowed and never run by the test
    runner. Either rename it (e.g. ``test_output_string_without_create``) or
    delete it; confirm intent before doing either.
    """
    with self.test_session():
        gan = mock_gan()
        config = {'d_learn_rate': 1e-3, 'g_learn_rate': 1e-3, 'd_trainer': 'rmsprop', 'g_trainer': 'adam'}
        trainer = AlternatingTrainer(gan, config)
        c = tf.constant(1)
        # The formatted output is expected to mention each loss name passed in.
        self.assertTrue('d_loss' in trainer.output_string({'d_loss':c}))
        self.assertTrue('g_loss' in trainer.output_string({'g_loss':c}))
        # Two input keys are expected to yield two output variables.
        self.assertEqual(len(trainer.output_variables({'a': c, 'b': c})), 2)
def test_output_string(self):
    """Check that output_string names each loss and output_variables has one entry per key."""
    with self.test_session():
        gan = mock_gan()
        gan.create()
        config = {
            'd_learn_rate': 1e-3,
            'g_learn_rate': 1e-3,
            'd_trainer': 'rmsprop',
            'g_trainer': 'adam',
        }
        trainer = AlternatingTrainer(gan, config)
        c = tf.constant(1)
        # assertIn (unlike assertTrue(x in y)) shows the searched string on failure.
        self.assertIn('d_loss', trainer.output_string({'d_loss': c}))
        self.assertIn('g_loss', trainer.output_string({'g_loss': c}))
        # Two input keys are expected to yield two output variables.
        self.assertEqual(len(trainer.output_variables({'a': c, 'b': c})), 2)
def test_config(self):
    """Check that the d_learn_rate given in the config dict is exposed on trainer.config."""
    with self.test_session():
        trainer_config = dict(
            d_learn_rate=1e-3,
            g_learn_rate=1e-3,
            d_trainer='rmsprop',
            g_trainer='adam',
        )
        trainer = AlternatingTrainer(mock_gan(), trainer_config)
        self.assertEqual(trainer.config.d_learn_rate, 1e-3)
def test_validate(self):
    """Check that constructing an AlternatingTrainer with an empty config raises ValidationException."""
    # Build the fixture outside assertRaises so that only the trainer
    # constructor, not mock_gan(), can satisfy the expected exception.
    gan = mock_gan()
    with self.assertRaises(ValidationException):
        AlternatingTrainer(gan, {})