Example #1
  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with tf.Graph().as_default():
      tf.compat.v1.random.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)

      optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
      train_op = contrib.create_train_op(loss, optimizer)

      with tf.compat.v1.train.MonitoredTrainingSession(
          hooks=[tf.estimator.StopAtStepHook(num_steps)],
          checkpoint_dir=checkpoint_dir) as sess:
        # The op returned by create_train_op computes gradients, applies
        # them, and returns the loss value when run.
        loss = None
        while not sess.should_stop():
          loss = sess.run(train_op)
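
This example assumes a logistic_classifier helper defined elsewhere in the test module. A minimal sketch of what it might look like, assuming a single dense unit (the layer choice is an assumption, not the original definition):

import tensorflow as tf

def logistic_classifier(inputs):
  # A single dense unit with a sigmoid yields per-example probabilities,
  # matching the log loss used above. (Hypothetical helper.)
  return tf.compat.v1.layers.dense(inputs, 1, activation=tf.sigmoid)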
Example #2
  def testGlobalStepNotIncrementedWhenSetToNone(self):
    with tf.Graph().as_default():
      tf.compat.v1.random.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      self.assertNotEmpty(
          tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
      loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
      optimizer = tf.compat.v1.train.GradientDescentOptimizer(
          learning_rate=1.0)
      train_op = contrib_utils.create_train_op(
          loss, optimizer, global_step=None)

      global_step = tf.compat.v1.train.get_or_create_global_step()

      with self.cached_session() as sess:
        # Initialize all variables
        sess.run(tf.compat.v1.global_variables_initializer())

        for _ in range(10):
          sess.run(train_op)

        # Since train_op doesn't use global_step, it shouldn't change.
        self.assertAllClose(sess.run(global_step), 0)
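
This test and the two that follow assume a batchnorm_classifier helper whose batch-normalization layer populates the UPDATE_OPS collection checked by assertNotEmpty. A plausible sketch (the exact layers are an assumption):

import tensorflow as tf

def batchnorm_classifier(inputs):
  # tf.compat.v1.layers.batch_normalization registers its moving_mean and
  # moving_variance updates in GraphKeys.UPDATE_OPS. (Hypothetical helper.)
  inputs = tf.compat.v1.layers.batch_normalization(inputs, training=True)
  return tf.compat.v1.layers.dense(inputs, 1, activation=tf.sigmoid)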
Example #3
  def testEmptyUpdateOps(self):
    with tf.Graph().as_default():
      tf.compat.v1.random.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      self.assertNotEmpty(
          tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
      loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
      optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
      train_op = contrib_utils.create_train_op(loss, optimizer, update_ops=[])

      moving_mean = contrib_utils.get_variables_by_name('moving_mean')[0]
      moving_variance = contrib_utils.get_variables_by_name(
          'moving_variance')[0]

      with self.cached_session() as sess:
        # Initialize all variables
        sess.run(tf.compat.v1.global_variables_initializer())
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization, moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          sess.run(train_op)

        mean = sess.run(moving_mean)
        variance = sess.run(moving_variance)

        # Since we passed update_ops=[], the moving averages are not updated.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)
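
By default create_train_op runs everything in the UPDATE_OPS collection before each gradient step; passing update_ops=[] opts out, which is why the moving statistics above keep their initial values. The default wiring looks roughly like this (a sketch, not part of the test):

# Equivalent to the default: run the collected batch-norm updates
# alongside every gradient step.
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
train_op = contrib_utils.create_train_op(loss, optimizer, update_ops=update_ops)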
Example #4
  def testTrainOpInCollection(self):
    with tf.Graph().as_default():
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      self.assertNotEmpty(
          tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
      loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
      optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
      train_op = contrib_utils.create_train_op(loss, optimizer)

      # Make sure the training op was recorded in the proper collection
      self.assertIn(
          train_op,
          tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAIN_OP))
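
Since the op lands in GraphKeys.TRAIN_OP, later code in the same graph can recover it without holding a Python reference, e.g. (sketch):

# Recover the train op from the collection instead of a Python variable.
recovered_op = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.TRAIN_OP)[0]
assert recovered_op is train_op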
Example #5
  def dis_train_op(gan_model, gan_loss):
    """Get the discriminator train op for a single training substep.

    Args:
      gan_model: The GANModel tuple.
      gan_loss: The GANLoss tuple.

    Returns:
      An Op that performs a single discriminator training substep.
    """
    with tf.compat.v1.name_scope('discriminator_train'):
      return contrib.create_train_op(
          total_loss=gan_loss.discriminator_loss,
          optimizer=optimizers.dopt,
          variables_to_train=gan_model.discriminator_variables,
          global_step=None,
          update_ops=update_ops(gan_model)[1])
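
Here update_ops(gan_model) is a helper from the surrounding module that splits the graph's update ops between the generator and the discriminator; indexing [1] selects the discriminator's share, and optimizers.dopt is presumably the discriminator optimizer captured from the enclosing scope. A hypothetical version of such a helper, assuming a scope-based split:

def update_ops(gan_model):
  # Partition GraphKeys.UPDATE_OPS by variable scope so each sub-network
  # only runs the batch-norm updates it owns. (Hypothetical helper.)
  gen_ops = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.UPDATE_OPS, scope=gan_model.generator_scope.name)
  dis_ops = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.UPDATE_OPS, scope=gan_model.discriminator_scope.name)
  return list(gen_ops), list(dis_ops)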