Example #1
0
    def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
        # Build one train op over every trainable variable plus one op each
        # restricted to the weights and to the biases, then check that each
        # op moves exactly the variables it was given.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            weights, biases = variables_lib.get_variables()

            train_all = training.create_train_op(total_loss, optimizer)
            train_weights_only = training.create_train_op(
                total_loss, optimizer, variables_to_train=[weights])
            train_biases_only = training.create_train_op(
                total_loss, optimizer, variables_to_train=[biases])

            with self.test_session() as session:

                def snapshot():
                    return session.run([weights, biases])

                def distance(before, after):
                    return np.linalg.norm(before - after)

                session.run(variables_lib2.global_variables_initializer())

                # Freshly initialized: non-zero weights, all-zero biases.
                w_before, b_before = snapshot()
                self.assertGreater(np.linalg.norm(w_before), 0)
                self.assertAlmostEqual(np.linalg.norm(b_before), 0)

                # Step 1: the unrestricted op moves weights and biases.
                self.assertGreater(session.run(train_all), .5)
                w_after, b_after = snapshot()
                self.assertGreater(distance(w_before, w_after), 0)
                self.assertGreater(distance(b_before, b_after), 0)
                w_before, b_before = w_after, b_after

                # Step 2: the weights-only op leaves the biases untouched.
                self.assertGreater(session.run(train_weights_only), .5)
                w_after, b_after = snapshot()
                self.assertGreater(distance(w_before, w_after), 0)
                self.assertAlmostEqual(distance(b_before, b_after), 0)
                w_before = w_after

                # Step 3: the biases-only op leaves the weights untouched.
                self.assertGreater(session.run(train_biases_only), .5)
                w_after, b_after = snapshot()
                self.assertAlmostEqual(distance(w_before, w_after), 0)
                self.assertGreater(distance(b_before, b_after), 0)
Example #2
0
  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # Three train ops share one loss: one over every trainable variable and
    # two restricted to a single variable each. Running them in turn checks
    # that a restricted op moves only the variables it was given.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = variables_lib.get_variables()

      # (op to run, should the weights move?, should the biases move?)
      schedule = [
          (training.create_train_op(total_loss, optimizer), True, True),
          (training.create_train_op(
              total_loss, optimizer, variables_to_train=[weights]),
           True, False),
          (training.create_train_op(
              total_loss, optimizer, variables_to_train=[biases]),
           False, True),
      ]

      with session_lib.Session() as sess:
        sess.run(variables_lib2.global_variables_initializer())

        # Starting point: non-zero weights, all-zero biases.
        weights_ref, biases_ref = sess.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_ref), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_ref), 0)

        for step_op, weights_move, biases_move in schedule:
          self.assertGreater(sess.run(step_op), .5)
          weights_now, biases_now = sess.run([weights, biases])
          checks = ((weights_move, weights_ref, weights_now),
                    (biases_move, biases_ref, biases_now))
          for should_move, ref, now in checks:
            delta = np.linalg.norm(ref - now)
            if should_move:
              self.assertGreater(delta, 0)
            else:
              self.assertAlmostEqual(delta, 0)
          # Only re-baseline variables that were expected to change; a frozen
          # variable keeps being compared against its older snapshot.
          if weights_move:
            weights_ref = weights_now
          if biases_move:
            biases_ref = biases_now
Example #3
0
  def testEmptyUpdateOps(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      loss_ops.log_loss(batchnorm_classifier(inputs), labels)
      total_loss = loss_ops.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      # An explicit empty update_ops list: the train op is built without the
      # batch-norm moving-statistics updates.
      train_op = training.create_train_op(total_loss, optimizer, update_ops=[])

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name(
          'moving_variance')[0]

      def assert_initial_statistics(mean, variance):
        # Initial values: moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

      with session_lib.Session() as sess:
        sess.run(variables_lib2.global_variables_initializer())
        assert_initial_statistics(*sess.run([moving_mean, moving_variance]))

        for _ in range(10):
          sess.run([train_op])

        # With update_ops skipped, ten steps leave the statistics unchanged.
        assert_initial_statistics(moving_mean.eval(), moving_variance.eval())
Example #4
0
    def _train_model(self, checkpoint_dir, num_steps):
        """Trains a simple classification model.

        Note that the data has been configured such that after around 300
        steps, the model has memorized the dataset (i.e. we can expect 100%
        accuracy).

        Args:
          checkpoint_dir: The directory where the checkpoint is written to.
          num_steps: The number of steps to train for.
        """
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = logistic_classifier(tf_inputs)
            loss = loss_ops.log_loss(tf_predictions, tf_labels)

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            train_op = training.create_train_op(loss, optimizer)

            # The loss returned by `training.train` is not needed here.
            training.train(
                train_op,
                checkpoint_dir,
                hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])
Example #5
0
  def testNoneGlobalStep(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      loss_ops.log_loss(batchnorm_classifier(inputs), labels)
      # global_step=None: the train op is built without a global step, so
      # running it must leave the global step at its initial value.
      train_op = training.create_train_op(
          loss_ops.get_total_loss(),
          gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
          global_step=None)

      step_var = variables_lib.get_or_create_global_step()

      with session_lib.Session() as sess:
        sess.run(variables_lib2.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])
        # Ten runs of a train op that does not use the global step: still 0.
        self.assertAllClose(step_var.eval(), 0)
Example #6
0
    def testGlobalStepIsIncrementedByDefault(self):
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
            labels = constant_op.constant(self._labels, dtype=dtypes.float32)

            predictions = batchnorm_classifier(inputs)
            train_op = training.create_train_op(
                losses.log_loss(labels, predictions),
                gradient_descent.GradientDescentOptimizer(learning_rate=1.0))
            global_step = variables_lib.get_or_create_global_step()

            num_runs = 10
            with self.test_session() as session:
                session.run(variables_lib2.global_variables_initializer())
                for _ in range(num_runs):
                    session.run(train_op)
                # No global_step argument was given, so every run of the
                # train op is expected to bump the global step by one.
                self.assertAllClose(global_step.eval(), num_runs)
Example #7
0
  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    """Repeated training runs sharing one `logdir` all end with a low loss.

    Each run rebuilds the graph (with a different random seed) and trains
    against the same `logdir` while saving checkpoints every 50 steps. The
    first run trains for 300 steps; the later short runs (1 and 5 steps) are
    expected to pick up from the checkpoint and still report a loss below
    0.015.
    """
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for i, num_steps in enumerate(number_of_steps):
      with ops.Graph().as_default():
        random_seed.set_random_seed(i)
        tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        loss_ops.log_loss(tf_predictions, tf_labels)
        total_loss = loss_ops.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

        train_op = training.create_train_op(total_loss, optimizer)

        saver = saver_lib.Saver()

        loss = training.train(
            train_op,
            logdir,
            hooks=[
                basic_session_run_hooks.StopAtStepHook(num_steps=num_steps),
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ])
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)
 def dis_train_op():
     """Builds the discriminator train op inside 'discriminator_train'.

     Minimizes `gan_loss.discriminator_loss` over
     `gan_model.discriminator_variables`, running `dis_update_ops` with it.
     """
     with ops.name_scope('discriminator_train'):
         train_op = training.create_train_op(
             total_loss=gan_loss.discriminator_loss,
             optimizer=discriminator_optimizer,
             variables_to_train=gan_model.discriminator_variables,
             update_ops=dis_update_ops)
     return train_op
 def gen_train_op():
     """Builds the generator train op inside 'generator_train'.

     Minimizes `gan_loss.generator_loss` over
     `gan_model.generator_variables`, running `gen_update_ops` with it.
     """
     with ops.name_scope('generator_train'):
         train_op = training.create_train_op(
             total_loss=gan_loss.generator_loss,
             optimizer=generator_optimizer,
             variables_to_train=gan_model.generator_variables,
             update_ops=gen_update_ops)
     return train_op
Example #10
0
def create_train_op_v2(total_loss,
                       optimizer,
                       global_step=_USE_GLOBAL_STEP,
                       update_ops_before_loss=None,
                       update_ops_after_loss=None,
                       variables_to_train=None,
                       transform_grads_fn=None,
                       summarize_gradients=False,
                       gate_gradients=tf.train.Optimizer.GATE_OP,
                       aggregation_method=None,
                       colocate_gradients_with_ops=False,
                       check_numerics=True):
    """Wraps `create_train_op`, optionally chaining extra ops after the step.

    Every argument except the two `update_ops_*` ones is forwarded unchanged
    to `create_train_op`.

    Args:
      update_ops_before_loss: forwarded as `create_train_op(update_ops=...)`.
      update_ops_after_loss: if not None, passed as the output argument of
        `with_dependencies([train_op], ...)`, making it depend on the train
        op built by `create_train_op`.

    Returns:
      The `create_train_op` result when `update_ops_after_loss` is None,
      otherwise the result of `with_dependencies`.

    NOTE(review): assuming `with_dependencies` is TensorFlow's
    `control_flow_ops.with_dependencies`, its second argument must be a single
    Tensor/IndexedSlices (a list would fail), and the returned value then
    evaluates to that tensor rather than to the loss. Confirm this is intended.
    """
    train_op = create_train_op(
        total_loss=total_loss,
        optimizer=optimizer,
        global_step=global_step,
        update_ops=update_ops_before_loss,
        variables_to_train=variables_to_train,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=summarize_gradients,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        check_numerics=check_numerics)
    if update_ops_after_loss is None:
        return train_op
    return with_dependencies([train_op], update_ops_after_loss)
 def dis_train_op():
   # Discriminator step: minimize the discriminator loss over the
   # discriminator variables only, running `dis_update_ops` alongside,
   # all inside the 'discriminator_train' name scope.
   with ops.name_scope('discriminator_train'):
     built_op = training.create_train_op(
         total_loss=gan_loss.discriminator_loss,
         optimizer=discriminator_optimizer,
         variables_to_train=gan_model.discriminator_variables,
         update_ops=dis_update_ops)
   return built_op
 def gen_train_op():
   # Generator step: minimize the generator loss over the generator
   # variables only, running `gen_update_ops` alongside, all inside the
   # 'generator_train' name scope.
   with ops.name_scope('generator_train'):
     built_op = training.create_train_op(
         total_loss=gan_loss.generator_loss,
         optimizer=generator_optimizer,
         variables_to_train=gan_model.generator_variables,
         update_ops=gen_update_ops)
   return built_op
Example #13
0
  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss = loss_ops.log_loss(tf_predictions, tf_labels)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # The loss returned by `training.train` is not needed here.
      training.train(
          train_op,
          checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])
    def testCanAchieveZeroLoss(self):
        logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')

        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
            labels = constant_op.constant(self._labels, dtype=dtypes.float32)

            loss_ops.log_loss(logistic_classifier(inputs), labels)
            train_op = training.create_train_op(
                loss_ops.get_total_loss(),
                gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

            # The test expects 300 steps to drive the loss below 0.015.
            stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=300)
            final_loss = training.train(train_op, logdir, hooks=[stop_hook])

            self.assertIsNotNone(final_loss)
            self.assertLess(final_loss, .015)
Example #15
0
    def testTrainWithLocalVariable(self):
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
            labels = constant_op.constant(self._labels, dtype=dtypes.float32)

            # Scale the predictions by a local variable (initial value 1.0)
            # so the trained graph contains a local variable.
            scale = variables_lib.local_variable(1.0)
            predictions = logistic_classifier(inputs) * scale
            losses.log_loss(labels, predictions)
            train_op = training.create_train_op(
                losses.get_total_loss(),
                gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

            # No logdir, and None for the summary/checkpoint save settings.
            final_loss = training.train(
                train_op,
                None,
                hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
                save_summaries_steps=None,
                save_checkpoint_secs=None)
            self.assertIsNotNone(final_loss)
            self.assertLess(final_loss, .015)
Example #16
0
    def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
        """Builds a train op for the logistic model on the fixture data.

        Args:
          learning_rate: learning rate of the gradient-descent optimizer.
          gradient_multiplier: factor applied to every trainable variable's
            gradient; 1.0 leaves the gradients untouched.

        Returns:
          The op built by `training.create_train_op`.
        """
        inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        labels = constant_op.constant(self._labels, dtype=dtypes.float32)

        losses.log_loss(labels, logistic_classifier(inputs))
        total_loss = losses.get_total_loss()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=learning_rate)

        def transform_grads_fn(grads):
            # Identity transform when no scaling is requested.
            if gradient_multiplier == 1.0:
                return grads
            multipliers = dict.fromkeys(
                variables_lib2.trainable_variables(), gradient_multiplier)
            with ops.name_scope('multiply_grads'):
                return training.multiply_gradients(grads, multipliers)

        return training.create_train_op(
            total_loss, optimizer, transform_grads_fn=transform_grads_fn)
Example #17
0
def gan_train_ops(
    self,
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
):
    """Builds the train ops for the encoder/generator and both discriminators.

    Three train ops are built, each inside its own name scope and fed the
    `UPDATE_OPS` registered under the matching scope of `model`:

    * 'EG_train': minimizes `self.loss_EG` with `self.EG_optimizer`.
    * 'Dz_train': minimizes `self.loss_Dz` with `self.D_z_optimizer`.
    * 'Di_train': minimizes `self.loss_Di` with `self.D_img_optimizer`.

    Args:
      self: Object supplying the losses, optimizers, `E_variables`,
        `G_variables` and `EG_global_step` used below.
      model: Object whose `EG_scope`, `Dz_scope` and `Di_scope` select the
        update ops for each train op.
      loss: Unused; kept for signature compatibility.
      generator_optimizer: Unused; the optimizers are read from `self`.
      discriminator_optimizer: Unused; the optimizers are read from `self`.
      check_for_unused_update_ops: Unused.

    Returns:
      `namedtuples.GANTrainOps` holding the 'EG_train' op as the generator
      train op, the tuple `(Dz_train op, Di_train op)` as the discriminator
      train op, and the global-step increment op.
    """
    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    def scoped_update_ops(scope):
        # Filtering UPDATE_OPS by scope already yields a subset of the full
        # collection, so intersecting with the unscoped collection adds
        # nothing; returning the list directly also keeps collection order.
        return ops.get_collection(ops.GraphKeys.UPDATE_OPS, scope.name)

    eg_update_ops = scoped_update_ops(model.EG_scope)
    with ops.name_scope('EG_train'):
        eg_train_op = training.create_train_op(
            total_loss=self.loss_EG,
            optimizer=self.EG_optimizer,
            variables_to_train=self.E_variables + self.G_variables,
            global_step=self.EG_global_step,
            update_ops=eg_update_ops)

    # NOTE(review): both discriminator ops below train
    # `self.E_variables + self.G_variables` and step `self.EG_global_step`,
    # exactly like 'EG_train'. This looks copied from the EG block; confirm
    # the intended discriminator variables and global step.
    dz_update_ops = scoped_update_ops(model.Dz_scope)
    with ops.name_scope('Dz_train'):
        dz_train_op = training.create_train_op(
            total_loss=self.loss_Dz,
            optimizer=self.D_z_optimizer,
            variables_to_train=self.E_variables + self.G_variables,
            global_step=self.EG_global_step,
            update_ops=dz_update_ops)

    di_update_ops = scoped_update_ops(model.Di_scope)
    with ops.name_scope('Di_train'):
        di_train_op = training.create_train_op(
            total_loss=self.loss_Di,
            optimizer=self.D_img_optimizer,
            variables_to_train=self.E_variables + self.G_variables,
            global_step=self.EG_global_step,
            update_ops=di_update_ops)

    # Distinct names keep the generator entry bound to the 'EG_train' op; the
    # two discriminator ops are bundled as a tuple, which `Session.run`
    # accepts as a nested fetch.
    disc_train_op = (dz_train_op, di_train_op)
    return namedtuples.GANTrainOps(eg_train_op, disc_train_op, global_step_inc)
Example #18
0
  def testTrainOpInCollection(self):
    """`create_train_op` registers its result in the TRAIN_OP collection."""
    with ops.Graph().as_default():
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # Make sure the training op was recorded in the proper collection;
      # assertIn reports the collection contents on failure.
      self.assertIn(train_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))
Example #19
0
  def testTrainOpInCollection(self):
    """`create_train_op` registers its result in the TRAIN_OP collection."""
    with ops.Graph().as_default():
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # Make sure the training op was recorded in the proper collection;
      # assertIn reports the collection contents on failure.
      self.assertIn(train_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))
Example #20
0
  def testCanAchieveZeroLoss(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      losses.log_loss(labels, logistic_classifier(inputs))
      train_op = training.create_train_op(
          losses.get_total_loss(),
          gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

      # No logdir, and None for the summary/checkpoint save settings.
      final_loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertIsNotNone(final_loss)
      self.assertLess(final_loss, .015)
Example #21
0
def make_train_op(scope,
                  optimizer,
                  global_step,
                  total_loss,
                  clip_gradient_norm=0,
                  clip_gradient_value=0,
                  summarize_gradients=False):
    """Builds a train op limited to the variables and update ops of `scope`.

    Args:
      scope: scope used to select both the trainable variables and the
        `UPDATE_OPS` collection entries.
      optimizer: optimizer passed to `create_train_op`.
      global_step: global step passed to `create_train_op`.
      total_loss: loss passed to `create_train_op`.
      clip_gradient_norm: forwarded to `make_transform_grads_fn`.
      clip_gradient_value: forwarded to `make_transform_grads_fn`.
      summarize_gradients: forwarded to `create_train_op`.

    Returns:
      The result of `create_train_op`.
    """
    grads_fn = make_transform_grads_fn(
        clip_gradient_norm=clip_gradient_norm,
        clip_gradient_value=clip_gradient_value)
    scope_variables = tf.trainable_variables(scope=scope)
    scope_updates = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS, scope=scope)
    return create_train_op(total_loss=total_loss,
                           optimizer=optimizer,
                           global_step=global_step,
                           update_ops=scope_updates,
                           variables_to_train=scope_variables,
                           transform_grads_fn=grads_fn,
                           summarize_gradients=summarize_gradients)
Example #22
0
  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      loss_ops.log_loss(batchnorm_classifier(inputs), labels)
      train_op = training.create_train_op(
          loss_ops.get_total_loss(),
          gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

      stop_after_300 = basic_session_run_hooks.StopAtStepHook(num_steps=300)
      final_loss = training.train(
          train_op, self._logdir, hooks=[stop_after_300])
      # Note the looser 0.1 bound for this batch-norm model.
      self.assertLess(final_loss, .1)
Example #23
0
  def testCanAchieveZeroLoss(self):
    logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')

    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      loss_ops.log_loss(logistic_classifier(inputs), labels)
      total_loss = loss_ops.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      # 300 steps are expected to bring the loss below 0.015.
      hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=300)]
      final_loss = training.train(train_op, logdir, hooks=hooks)
      self.assertIsNotNone(final_loss)
      self.assertLess(final_loss, .015)
Example #24
0
  def testGlobalStepIsIncrementedByDefault(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      loss = losses.log_loss(labels, batchnorm_classifier(inputs))
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      # No global_step argument: create_train_op uses its default.
      train_op = training.create_train_op(loss, optimizer)
      step = variables_lib.get_or_create_global_step()

      with self.test_session() as session:
        session.run(variables_lib2.global_variables_initializer())
        expected_steps = 10
        for _ in range(expected_steps):
          session.run(train_op)
        # One increment per run of the train op.
        self.assertAllClose(step.eval(), expected_steps)
Example #25
0
    def testUseUpdateOps(self):
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
            labels = constant_op.constant(self._labels, dtype=dtypes.float32)

            # Per-feature statistics of the (constant) input batch.
            expected_mean = np.mean(self._inputs, axis=0)
            expected_var = np.var(self._inputs, axis=0)

            predictions = batchnorm_classifier(inputs)
            loss = losses.log_loss(labels, predictions)
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            # Default update_ops, so the moving-statistics updates are part
            # of each training step.
            train_op = training.create_train_op(loss, optimizer)

            moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
            moving_variance = variables_lib.get_variables_by_name(
                'moving_variance')[0]

            with self.test_session() as session:
                session.run(variables_lib2.global_variables_initializer())
                initial_mean, initial_variance = session.run(
                    [moving_mean, moving_variance])
                # Freshly initialized: moving_mean == 0, moving_variance == 1.
                self.assertAllClose(initial_mean, [0] * 4)
                self.assertAllClose(initial_variance, [1] * 4)

                for _ in range(10):
                    session.run(train_op)

                final_mean = moving_mean.eval()
                final_variance = moving_variance.eval()
                # After ten steps the moving statistics should match the
                # input batch statistics.
                self.assertAllClose(final_mean, expected_mean)
                self.assertAllClose(final_variance, expected_var)
Example #26
0
  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    """Returns a train op for the logistic model on the fixture data.

    Args:
      learning_rate: learning rate of the gradient-descent optimizer.
      gradient_multiplier: scale applied to every trainable variable's
        gradient; the default of 1.0 leaves gradients unchanged.
    """
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    loss_ops.log_loss(logistic_classifier(inputs), labels)
    total_loss = loss_ops.get_total_loss()
    optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def scale_gradients(grads):
      if gradient_multiplier == 1.0:
        return grads
      per_variable = dict.fromkeys(variables_lib2.trainable_variables(),
                                   gradient_multiplier)
      with ops.name_scope('multiply_grads'):
        return training.multiply_gradients(grads, per_variable)

    return training.create_train_op(
        total_loss, optimizer, transform_grads_fn=scale_gradients)
Example #27
0
  def testUseUpdateOps(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      # Target values: per-feature statistics of the constant input batch.
      target_mean = np.mean(self._inputs, axis=0)
      target_var = np.var(self._inputs, axis=0)

      loss = losses.log_loss(labels, batchnorm_classifier(inputs))
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      mean_var = variables_lib.get_variables_by_name('moving_mean')[0]
      variance_var = variables_lib.get_variables_by_name('moving_variance')[0]

      with self.test_session() as session:
        session.run(variables_lib2.global_variables_initializer())
        start_mean, start_variance = session.run([mean_var, variance_var])
        # Initial values: moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(start_mean, [0] * 4)
        self.assertAllClose(start_variance, [1] * 4)

        for _ in range(10):
          session.run(train_op)

        end_mean, end_variance = mean_var.eval(), variance_var.eval()
        # Ten steps with the default update ops should bring the moving
        # statistics to the input batch statistics.
        self.assertAllClose(end_mean, target_mean)
        self.assertAllClose(end_variance, target_var)
Example #28
0
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None,
                    check_numerics=True):
    """Creates an `Operation` that evaluates the gradients and returns the loss.

    Thin wrapper around `training.create_train_op`. The optional gradient
    scaling (`gradient_multipliers`) and norm clipping (`clip_gradient_norm`)
    are folded into a single `transform_grads_fn`; scaling happens first.

    Args:
      total_loss: A `Tensor` representing the total loss.
      optimizer: A tf.Optimizer to use for computing the gradients.
      global_step: A `Tensor` representing the global step variable. If left as
        `_USE_GLOBAL_STEP`, then slim.variables.global_step() is used.
      update_ops: An optional list of updates to execute. If `update_ops` is
        `None`, then the update ops are set to the contents of the
        `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`,
        but it doesn't contain all of the update ops in
        `tf.GraphKeys.UPDATE_OPS`, a warning will be displayed.
      variables_to_train: An optional list of variables to train. If None, it
        will default to all tf.trainable_variables().
      clip_gradient_norm: If greater than 0 then the gradients would be clipped
        by it.
      summarize_gradients: Whether or not add summaries for each gradient.
      gate_gradients: How to gate the computation of gradients. See
        tf.Optimizer.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: Whether or not to try colocating the
        gradients with the ops that generated them.
      gradient_multipliers: A dictionary of either `Variables` or `Variable` op
        names to the coefficient by which the associated gradient should be
        scaled.
      check_numerics: Whether or not we apply check_numerics.

    Returns:
      A `Tensor` that when evaluated, computes the gradients and returns the
      total loss value.
    """

    def _scale_then_clip(grads):
        # Each optional stage gets its own name scope in the graph.
        scaled = grads
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                scaled = multiply_gradients(scaled, gradient_multipliers)
        if not clip_gradient_norm > 0:
            return scaled
        with ops.name_scope('clip_grads'):
            return clip_gradient_norms(scaled, clip_gradient_norm)

    return training.create_train_op(
        total_loss=total_loss,
        optimizer=optimizer,
        transform_grads_fn=_scale_then_clip,
        global_step=global_step,
        variables_to_train=variables_to_train,
        update_ops=update_ops,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        summarize_gradients=summarize_gradients,
        check_numerics=check_numerics)
Example #29
0
def gan_train_ops(
        model,  # GANModel
        loss,  # GANLoss
        generator_optimizer,
        discriminator_optimizer,
        # Optional check flags.
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Returns GAN train ops.

    The highest-level call in TFGAN. It is composed of functions that can also
    be called, should a user require more control over some part of the GAN
    training process.

    Args:
      model: A GANModel.
      loss: A GANLoss.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: The optimizer for the discriminator updates.
      check_for_unused_update_ops: If `True`, throws an exception if there are
        update ops outside of the generator or discriminator scopes.
      **kwargs: Keyword args to pass directly to
        `training.create_train_op` for both the generator and
        discriminator train op.

    Returns:
      A GANTrainOps tuple of
      (generator_train_op, discriminator_train_op, global_step_inc_op) that can
      be used to train a generator/discriminator pair. `global_step_inc_op`
      adds one to the global step. The two train ops are built with
      `global_step=None` (or, for a `SyncReplicasOptimizer`, a dummy
      per-network step variable) rather than the real global step.
    """
    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    def _maybe_dummy_global_step(optimizer, name):
        """Returns a dummy step variable for sync replicas, otherwise None."""
        if not isinstance(optimizer,
                          sync_replicas_optimizer.SyncReplicasOptimizer):
            return None
        # TODO(joelshor): Figure out a way to get this work without including
        # the dummy global step in the checkpoint.
        # WARNING: Making this variable a local variable causes sync replicas
        # to hang forever.
        return variable_scope.get_variable(
            name,
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs, model.generator_scope.name, model.discriminator_scope.name,
        check_for_unused_update_ops)

    generator_global_step = _maybe_dummy_global_step(
        generator_optimizer, 'dummy_global_step_generator')
    if generator_global_step is not None:
        # Copy the real global step into the dummy step as an update op.
        gen_update_ops += [generator_global_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        gen_train_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=generator_global_step,
            update_ops=gen_update_ops,
            **kwargs)

    # Created only after the generator train op, as before, so graph
    # construction order is unchanged.
    discriminator_global_step = _maybe_dummy_global_step(
        discriminator_optimizer, 'dummy_global_step_discriminator')
    if discriminator_global_step is not None:
        dis_update_ops += [discriminator_global_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        disc_train_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=discriminator_global_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(gen_train_op, disc_train_op,
                                   global_step_inc)
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None):
    """Creates an `Operation` that evaluates the gradients and returns the loss.

    Delegates to `training.create_train_op`, passing a gradient transform that
    optionally rescales the gradients and then optionally clips their norms.

    Args:
      total_loss: A `Tensor` representing the total loss.
      optimizer: A tf.Optimizer to use for computing the gradients.
      global_step: A `Tensor` representing the global step variable. If left as
        `_USE_GLOBAL_STEP`, then slim.variables.global_step() is used.
      update_ops: An optional list of updates to execute. If `update_ops` is
        `None`, then the update ops are set to the contents of the
        `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`,
        but it doesn't contain all of the update ops in
        `tf.GraphKeys.UPDATE_OPS`, a warning will be displayed.
      variables_to_train: An optional list of variables to train. If None, it
        will default to all tf.trainable_variables().
      clip_gradient_norm: If greater than 0 then the gradients would be clipped
        by it.
      summarize_gradients: Whether or not add summaries for each gradient.
      gate_gradients: How to gate the computation of gradients. See
        tf.Optimizer.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: Whether or not to try colocating the
        gradients with the ops that generated them.
      gradient_multipliers: A dictionary of either `Variables` or `Variable` op
        names to the coefficient by which the associated gradient should be
        scaled.

    Returns:
      A `Tensor` that when evaluated, computes the gradients and returns the
      total loss value.
    """
    def apply_grad_transforms(grads_and_vars):
        result = grads_and_vars
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                result = multiply_gradients(result, gradient_multipliers)
        # Clipping sees the already-rescaled gradients.
        if clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                result = clip_gradient_norms(result, clip_gradient_norm)
        return result

    forwarded = dict(
        total_loss=total_loss,
        optimizer=optimizer,
        global_step=global_step,
        update_ops=update_ops,
        variables_to_train=variables_to_train,
        transform_grads_fn=apply_grad_transforms,
        summarize_gradients=summarize_gradients,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    return training.create_train_op(**forwarded)
Example #31
0
def gan_train_ops(
        models,
        generator_scope,
        discriminator_scopes,
        losses,
        generator_optimizer,
        discriminator_optimizers,
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Returns GAN train ops for one generator and several discriminators.

    The highest-level call in TFGAN. It is composed of functions that can also
    be called, should a user require more control over some part of the GAN
    training process. This variant trains a single generator (taken from
    `models[0]`) against one discriminator per entry of `models`.

    Args:
      models: A sequence of GANModels. `models[0].generator_variables` is used
        for the generator train op (the generator variables are the same in
        every model); `models[i].discriminator_variables` for discriminator i.
      generator_scope: Name of the generator scope. Used as the name scope of
        the generator train op, to select its summaries, and passed to
        `_get_update_ops`.
      discriminator_scopes: A sequence of discriminator scope names, indexed in
        parallel with `models`.
      losses: An object exposing `gen_nongan_losses` (the generator loss) and
        `discriminator_losses` (indexed in parallel with `models`).
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizers: A sequence of optimizers, one per
        discriminator, indexed in parallel with `models`.
      check_for_unused_update_ops: If `True`, throws an exception if there are
        update ops outside of the generator or discriminator scopes.
      **kwargs: Keyword args to pass directly to
        `training.create_train_op` for the generator and every
        discriminator train op.

    Returns:
      A 3-tuple `(train_ops, gen_summaries, disc_summaries_dict)`:
        train_ops: A GANTrainOps of
          (generator_train_op, disc_train_ops, global_step_inc_op), where
          `disc_train_ops` maps the integer index `i` to the train op of
          discriminator `i`.
        gen_summaries: The `SUMMARIES` collection filtered by
          `generator_scope`.
        disc_summaries_dict: A dict mapping each discriminator scope name to
          its filtered `SUMMARIES` collection.
    """
    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops_list = _get_update_ops(
        kwargs, generator_scope, discriminator_scopes,
        check_for_unused_update_ops)

    with tf.name_scope(generator_scope):
        gen_train_op = training.create_train_op(
            total_loss=losses.gen_nongan_losses,
            optimizer=generator_optimizer,
            # The generator variables are always the same.
            variables_to_train=models[0].generator_variables,
            global_step=None,
            update_ops=gen_update_ops,
            summarize_gradients=True,
            **kwargs)

        gen_summaries = ops.get_collection(ops.GraphKeys.SUMMARIES,
                                           generator_scope)

    disc_train_ops = {}
    disc_summaries_dict = {}
    # `enumerate` yields plain Python ints, so `disc_train_ops` is keyed by
    # `int` rather than by numpy integers as with `np.arange`.
    for i, model in enumerate(models):
        disc_scope = discriminator_scopes[i]
        with tf.name_scope(disc_scope):
            # TODO: annotate the train_ops with meaningful names?
            disc_train_ops[i] = training.create_train_op(
                total_loss=losses.discriminator_losses[i],
                optimizer=discriminator_optimizers[i],
                variables_to_train=model.discriminator_variables,
                global_step=None,
                update_ops=dis_update_ops_list[i],
                summarize_gradients=True,
                **kwargs)

            disc_summaries_dict[disc_scope] = ops.get_collection(
                ops.GraphKeys.SUMMARIES, disc_scope)

    return (GANTrainOps(gen_train_op, disc_train_ops, global_step_inc),
            gen_summaries, disc_summaries_dict)
Example #32
0
  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    """Training every variable beats training only weights or only biases.

    Three phases each build a fresh graph and write checkpoints to the same
    `logdir` (presumably so each phase resumes where the previous one
    stopped). The phases train the weights only, then the biases only, and
    finally every trainable variable.
    """
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if gfile.Exists(logdir):  # For running on jenkins.
      gfile.DeleteRecursively(logdir)

    def train_phase(seed, num_steps, var_name=None):
      """Trains a fresh graph until `num_steps` and returns the final loss.

      When `var_name` is given only the variables with that name are trained;
      otherwise all trainable variables are.
      """
      with ops.Graph().as_default():
        random_seed.set_random_seed(seed)
        total_loss = self.ModelLoss()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=1.0)
        if var_name is None:
          train_op = training.create_train_op(total_loss, optimizer)
        else:
          subset = variables_lib.get_variables_by_name(var_name)
          train_op = training.create_train_op(
              total_loss, optimizer, variables_to_train=subset)
        saver = saver_lib.Saver()
        hooks = [
            basic_session_run_hooks.CheckpointSaverHook(
                logdir, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=num_steps),
        ]
        return training.train(train_op, logdir, hooks=hooks)

    # First, train only the weights of the model.
    loss = train_phase(seed=0, num_steps=200, var_name='weights')
    self.assertGreater(loss, .015)
    self.assertLess(loss, .05)

    # Next, train only the biases of the model.
    loss = train_phase(seed=1, num_steps=300, var_name='biases')
    self.assertGreater(loss, .015)
    self.assertLess(loss, .05)

    # Finally, train both weights and biases to get a lower loss.
    loss = train_phase(seed=2, num_steps=400)
    self.assertIsNotNone(loss)
    self.assertLess(loss, .015)
Example #33
0
def gan_train_ops(
        model,
        loss,
        generator_optimizer,
        discriminator_optimizer,
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Returns GAN train ops.

    The highest-level call in TFGAN. It is composed of functions that can also
    be called, should a user require more control over some part of the GAN
    training process.

    Args:
      model: A GANModel, or a CycleGANModel (handled by recursing on its
        `model_x2y` and `model_y2x` sub-models).
      loss: A GANLoss (or, for a CycleGANModel, an object with `loss_x2y` and
        `loss_y2x`).
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: The optimizer for the discriminator updates.
      check_for_unused_update_ops: If `True`, throws an exception if there are
        update ops outside of the generator or discriminator scopes.
      **kwargs: Keyword args to pass directly to
        `training.create_train_op` for both the generator and
        discriminator train op.

    Returns:
      A GANTrainOps tuple of
      (generator_train_op, discriminator_train_op, global_step_inc_op) that can
      be used to train a generator/discriminator pair. `global_step_inc_op`
      adds one to the global step. For a CycleGANModel, the first two
      elements are `(x2y_op, y2x_op)` pairs.
    """
    if isinstance(model, namedtuples.CycleGANModel):
        # Get and store all arguments other than model and loss from locals.
        # Contents of locals should not be modified, may not affect values. So make
        # a copy. https://docs.python.org/2/library/functions.html#locals.
        # This must happen before any other local is bound in this function.
        saved_params = dict(locals())
        saved_params.pop('model', None)
        saved_params.pop('loss', None)
        kwargs = saved_params.pop('kwargs', {})
        saved_params.update(kwargs)
        with ops.name_scope('cyclegan_x2y_train'):
            train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                          **saved_params)
        with ops.name_scope('cyclegan_y2x_train'):
            train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                          **saved_params)
        return namedtuples.GANTrainOps(
            (train_ops_x2y.generator_train_op,
             train_ops_y2x.generator_train_op),
            (train_ops_x2y.discriminator_train_op,
             train_ops_y2x.discriminator_train_op),
            training_util.get_or_create_global_step().assign_add(1))

    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    def _maybe_dummy_global_step(optimizer, name):
        """Returns a dummy step variable for sync replicas, otherwise None."""
        if not isinstance(optimizer,
                          sync_replicas_optimizer.SyncReplicasOptimizer):
            return None
        # TODO(joelshor): Figure out a way to get this work without including
        # the dummy global step in the checkpoint.
        # WARNING: Making this variable a local variable causes sync replicas
        # to hang forever.
        return variable_scope.get_variable(
            name,
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs, model.generator_scope.name, model.discriminator_scope.name,
        check_for_unused_update_ops)

    generator_global_step = _maybe_dummy_global_step(
        generator_optimizer, 'dummy_global_step_generator')
    if generator_global_step is not None:
        # Copy the real global step into the dummy step as an update op.
        gen_update_ops += [generator_global_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        gen_train_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=generator_global_step,
            update_ops=gen_update_ops,
            **kwargs)

    # Created only after the generator train op, as before, so graph
    # construction order is unchanged.
    discriminator_global_step = _maybe_dummy_global_step(
        discriminator_optimizer, 'dummy_global_step_discriminator')
    if discriminator_global_step is not None:
        dis_update_ops += [discriminator_global_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        disc_train_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=discriminator_global_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(gen_train_op, disc_train_op,
                                   global_step_inc)
Example #34
0
def ctc_estimator(tokens,
                  token_lengths,
                  logits,
                  glogits,
                  sequence_mask,
                  sequence_length_ctc,
                  vocab,
                  run_config,
                  params,
                  mode,
                  model_scope,
                  training_hooks=None):
    """Builds an `EstimatorSpec` for a model trained with a CTC loss.

    The train op is built regardless of `mode`.

    Args:
      tokens: Target token ids. They are shifted by +1 before being used as
        CTC labels (presumably to keep id 0 free for the CTC blank -- confirm
        against `ctc_loss_dense`).
      token_lengths: Length of each target sequence. Used as the CTC label
        lengths and reported as the `token_lengths_eval` metric.
      logits: Logits scored by the CTC loss and, when not None, decoded by the
        "Autoencoded" evaluation hook.
      glogits: Optional logits decoded by the "Generated" evaluation hook.
      sequence_mask: Mask passed to `sparsify` together with the labels.
      sequence_length_ctc: Length of each logit sequence.
      vocab: Vocabulary passed to the evaluation hooks.
      run_config: Run configuration; evaluation hooks write CSV files under
        `run_config.model_dir`.
      params: Hyperparameters. `params.lr` is the Adam learning rate, and the
        whole object is passed to `make_transform_grads_fn`.
      mode: Estimator mode, forwarded to the `EstimatorSpec`.
      model_scope: Scope used to collect this model's losses, update ops and
        trainable variables.
      training_hooks: Optional list of hooks run during training. Defaults to
        an empty list.

    Returns:
      An `EstimatorSpec` whose loss is the sum of every loss and
      regularization loss registered under `model_scope`.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if training_hooks is None:
        training_hooks = []

    with tf.name_scope(model_scope + "/"):
        tok_1 = tokens + 1
        ctc_labels_sparse = sparsify(tf.cast(tok_1, tf.int32), sequence_mask)
        # Dense labels (padded with -1) feed the evaluation hooks' `true`.
        ctc_labels = tf.sparse_tensor_to_dense(ctc_labels_sparse,
                                               default_value=-1)
        print("Labels: {}".format(ctc_labels))
        print("logits: {}".format(logits))
        print("glogits: {}".format(glogits))
        print("CTC: {}, {}, {}".format(ctc_labels, logits,
                                       sequence_length_ctc))
        ctc_inputs = dict(labels=tok_1,
                          label_length=token_lengths,
                          logits=logits,
                          logit_length=sequence_length_ctc)
        if tf.flags.FLAGS.gpu_ctc:
            ctc_loss_raw = ctc_loss_dense(**ctc_inputs)
        else:
            # Pin the CTC loss to the CPU unless the GPU path was requested.
            with tf.device("/cpu:0"):
                ctc_loss_raw = ctc_loss_dense(**ctc_inputs)
        ctc_loss = tf.reduce_mean(ctc_loss_raw, name='ctc_loss')
        tf.losses.add_loss(ctc_loss)

    losses = tf.losses.get_losses(scope=model_scope)
    print("Estimator losses: {}".format(losses))
    losses += tf.losses.get_regularization_losses(scope=model_scope)
    total_loss = tf.add_n(losses)
    updates = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS, scope=model_scope)

    evaluation_hooks = []
    if logits is not None:
        autoencode_hook = CTCHook(logits=logits,
                                  lengths=sequence_length_ctc,
                                  vocab=vocab,
                                  path=os.path.join(run_config.model_dir,
                                                    "autoencoded",
                                                    "autoencoded-{:08d}.csv"),
                                  true=ctc_labels,
                                  name="Autoencoded",
                                  merge_repeated=True)
        evaluation_hooks.append(autoencode_hook)
    if glogits is not None:
        generate_hook = CTCHook(logits=glogits,
                                lengths=sequence_length_ctc,
                                vocab=vocab,
                                path=os.path.join(run_config.model_dir,
                                                  "generated",
                                                  "generated-{:08d}.csv"),
                                true=ctc_labels,
                                name="Generated",
                                merge_repeated=True)
        evaluation_hooks.append(generate_hook)

    tf.summary.scalar('ctc_loss', ctc_loss)
    tf.summary.scalar('total_loss', total_loss)

    # Train
    optimizer = tf.train.AdamOptimizer(params.lr)
    variables = tf.trainable_variables(scope=model_scope)
    transform_grads_fn = make_transform_grads_fn(params=params)

    train_op = create_train_op(total_loss=total_loss,
                               optimizer=optimizer,
                               update_ops=updates,
                               variables_to_train=variables,
                               transform_grads_fn=transform_grads_fn,
                               summarize_gradients=False,
                               aggregation_method=None,
                               check_numerics=True)
    eval_metric_ops = {
        'ctc_loss_eval': tf.metrics.mean(ctc_loss_raw),
        'token_lengths_eval': tf.metrics.mean(token_lengths)
    }

    return EstimatorSpec(mode=mode,
                         loss=total_loss,
                         eval_metric_ops=eval_metric_ops,
                         evaluation_hooks=evaluation_hooks,
                         training_hooks=training_hooks,
                         train_op=train_op)
Example #35
0
def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel, or a CycleGANModel (handled by recursing on its
      `model_x2y` and `model_y2x` sub-models).
    loss: A GANLoss (or, for a CycleGANModel, an object with `loss_x2y` and
      `loss_y2x`).
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of
    (generator_train_op, discriminator_train_op, global_step_inc_op) that can
    be used to train a generator/discriminator pair. `global_step_inc_op` adds
    one to the global step. For a CycleGANModel, the first two elements are
    `(x2y_op, y2x_op)` pairs.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Get and store all arguments other than model and loss from locals.
    # The dict returned by `locals()` must not be modified (changes may not
    # affect, or may be overwritten by, the interpreter's view of the frame),
    # so take a copy. This must happen before any other local is bound here.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with ops.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  def _maybe_dummy_global_step(optimizer, name):
    """Returns a dummy step variable for sync replicas, otherwise None."""
    if not isinstance(optimizer,
                      sync_replicas_optimizer.SyncReplicasOptimizer):
      return None
    # TODO(joelshor): Figure out a way to get this work without including the
    # dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    return variable_scope.get_variable(
        name,
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  generator_global_step = _maybe_dummy_global_step(
      generator_optimizer, 'dummy_global_step_generator')
  if generator_global_step is not None:
    # Copy the real global step into the dummy step as an update op.
    gen_update_ops += [generator_global_step.assign(global_step)]
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  # Created only after the generator train op, as before, so graph
  # construction order is unchanged.
  discriminator_global_step = _maybe_dummy_global_step(
      discriminator_optimizer, 'dummy_global_step_discriminator')
  if discriminator_global_step is not None:
    dis_update_ops += [discriminator_global_step.assign(global_step)]
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)