예제 #1
0
    def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
        """Check the loss from one `slim_learning.train` step over GAN ops.

        The generator and discriminator train ops are the constants 3.0 and
        2.0, configured for 3 and 4 steps respectively, and training is
        limited to a single outer step (`number_of_steps=1`). The returned
        loss must be a scalar equal to 17.0.
        """
        global_step = training_util.create_global_step()

        # Constant tensors stand in for real train ops so the reported loss
        # is fully deterministic.
        gan_train_ops = namedtuples.GANTrainOps(
            generator_train_op=constant_op.constant(3.0),
            discriminator_train_op=constant_op.constant(2.0),
            global_step_inc_op=global_step.assign_add(1))
        steps_per_model = namedtuples.GANTrainSteps(
            generator_train_steps=3,
            discriminator_train_steps=4)
        sequential_step_fn = train.get_sequential_train_steps(steps_per_model)

        # NOTE(review): 17.0 == 3 * 3.0 + 4 * 2.0, which presumably reflects
        # the step function summing each op's value once per configured step;
        # confirm against get_sequential_train_steps.
        expected_loss = 17.0

        final_loss = slim_learning.train(
            train_op=gan_train_ops,
            logdir='',
            global_step=global_step,
            number_of_steps=1,
            train_step_fn=sequential_step_fn)

        self.assertTrue(np.isscalar(final_loss))
        self.assertEqual(expected_loss, final_loss)
예제 #2
0
  def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
    """Check the loss from one `slim_learning.train` step over GAN ops.

    The generator and discriminator train ops are the constants 3.0 and 2.0,
    configured for 3 and 4 steps respectively, and training is limited to a
    single outer step (`number_of_steps=1`). The returned loss must be a
    scalar equal to 17.0.
    """
    global_step = training_util.create_global_step()

    # Constant tensors stand in for real train ops so the reported loss is
    # fully deterministic.
    gan_train_ops = namedtuples.GANTrainOps(
        generator_train_op=constant_op.constant(3.0),
        discriminator_train_op=constant_op.constant(2.0),
        global_step_inc_op=global_step.assign_add(1))
    steps_per_model = namedtuples.GANTrainSteps(
        generator_train_steps=3,
        discriminator_train_steps=4)
    sequential_step_fn = train.get_sequential_train_steps(steps_per_model)

    # NOTE(review): 17.0 == 3 * 3.0 + 4 * 2.0, which presumably reflects the
    # step function summing each op's value once per configured step; confirm
    # against get_sequential_train_steps.
    expected_loss = 17.0

    final_loss = slim_learning.train(
        train_op=gan_train_ops,
        logdir='',
        global_step=global_step,
        number_of_steps=1,
        train_step_fn=sequential_step_fn)

    self.assertTrue(np.isscalar(final_loss))
    self.assertEqual(expected_loss, final_loss)