def _train_model(self, checkpoint_dir, num_steps):
  """Trains a simple classification model.

  Note that the data has been configured such that after around 300 steps,
  the model has memorized the dataset (i.e. we can expect 100% accuracy).

  Args:
    checkpoint_dir: The directory where the checkpoint is written to.
    num_steps: The number of steps to train for.
  """
  with tf.Graph().as_default():
    tf.compat.v1.random.set_random_seed(0)
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    # log_loss takes labels first, then predictions.
    loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
    train_op = contrib.create_train_op(loss, optimizer)

    with tf.compat.v1.train.MonitoredTrainingSession(
        hooks=[tf.estimator.StopAtStepHook(num_steps)],
        checkpoint_dir=checkpoint_dir) as sess:
      loss = None
      while not sess.should_stop():
        loss = sess.run(train_op)
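# A minimal usage sketch (not part of the original file): driving
# _train_model from a test method. It assumes the enclosing class extends
# tf.test.TestCase (for get_temp_dir) and that the self._inputs/self._labels
# fixtures are set up elsewhere; the method name is hypothetical.
def testTrainingMemorizesData(self):
  checkpoint_dir = self.get_temp_dir()
  # Train past the ~300 steps after which the model memorizes the dataset.
  self._train_model(checkpoint_dir, num_steps=400)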
def testGlobalStepNotIncrementedWhenSetToNone(self):
  with tf.Graph().as_default():
    tf.compat.v1.random.set_random_seed(0)
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    self.assertNotEmpty(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
    loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
    train_op = contrib_utils.create_train_op(loss, optimizer, global_step=None)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    with self.cached_session() as sess:
      # Initialize all variables.
      sess.run(tf.compat.v1.global_variables_initializer())

      for _ in range(10):
        sess.run(train_op)

      # Since train_op doesn't use global_step, it shouldn't change.
      self.assertAllClose(sess.run(global_step), 0)
def testEmptyUpdateOps(self):
  with tf.Graph().as_default():
    tf.compat.v1.random.set_random_seed(0)
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    self.assertNotEmpty(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
    loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
    train_op = contrib_utils.create_train_op(loss, optimizer, update_ops=[])

    moving_mean = contrib_utils.get_variables_by_name('moving_mean')[0]
    moving_variance = contrib_utils.get_variables_by_name('moving_variance')[0]

    with self.cached_session() as sess:
      # Initialize all variables.
      sess.run(tf.compat.v1.global_variables_initializer())
      mean, variance = sess.run([moving_mean, moving_variance])

      # After initialization, moving_mean == 0 and moving_variance == 1.
      self.assertAllClose(mean, [0] * 4)
      self.assertAllClose(variance, [1] * 4)

      for _ in range(10):
        sess.run(train_op)

      mean = sess.run(moving_mean)
      variance = sess.run(moving_variance)

      # Since we skip update_ops, the moving_vars are not updated.
      self.assertAllClose(mean, [0] * 4)
      self.assertAllClose(variance, [1] * 4)
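# Hedged counterpart sketch (not part of the original file): when the
# update_ops argument is omitted, create_train_op wires the UPDATE_OPS
# collection into the train op, so the batch-norm moving statistics should
# drift away from their initial values. The test name below is hypothetical;
# the fixtures are the same ones testEmptyUpdateOps relies on.
def testDefaultUpdateOpsRunUpdates(self):
  with tf.Graph().as_default():
    tf.compat.v1.random.set_random_seed(0)
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
    # No update_ops argument: the UPDATE_OPS collection is used by default.
    train_op = contrib_utils.create_train_op(loss, optimizer)

    moving_mean = contrib_utils.get_variables_by_name('moving_mean')[0]

    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      for _ in range(10):
        sess.run(train_op)
      # The moving mean should have moved off its initial zeros.
      self.assertNotAllClose(sess.run(moving_mean), [0] * 4)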
def testTrainOpInCollection(self):
  with tf.Graph().as_default():
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    self.assertNotEmpty(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
    loss = tf.compat.v1.losses.log_loss(tf_labels, tf_predictions)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
    train_op = contrib_utils.create_train_op(loss, optimizer)

    # Make sure the training op was recorded in the proper collection.
    self.assertIn(
        train_op,
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAIN_OP))
def dis_train_op(gan_model, gan_loss):
  """Get the discriminator train op for a single training substep.

  Args:
    gan_model: The GANModel tuple.
    gan_loss: The GANLoss tuple.

  Returns:
    An Op that performs a single discriminator training substep.
  """
  with tf.compat.v1.name_scope('discriminator_train'):
    return contrib.create_train_op(
        total_loss=gan_loss.discriminator_loss,
        optimizer=optimizers.dopt,
        variables_to_train=gan_model.discriminator_variables,
        global_step=None,
        update_ops=update_ops(gan_model)[1])
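# Hedged usage sketch (not part of the original file): a generator
# counterpart built the same way. The names optimizers.gopt,
# gan_model.generator_variables, and the [0] index into update_ops(gan_model)
# (i.e. the generator's update ops) mirror the discriminator version above
# and are assumptions about the surrounding module, not confirmed API.
def gen_train_op(gan_model, gan_loss):
  """Get the generator train op for a single training substep."""
  with tf.compat.v1.name_scope('generator_train'):
    return contrib.create_train_op(
        total_loss=gan_loss.generator_loss,
        optimizer=optimizers.gopt,
        variables_to_train=gan_model.generator_variables,
        global_step=None,
        update_ops=update_ops(gan_model)[0])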