    def testEmptyUpdateOps(self):
        with tf.Graph().as_default():
            tf_inputs = tf.constant(self._inputs, dtype=tf.dtypes.float32)
            tf_labels = tf.constant(self._labels, dtype=tf.dtypes.float32)

            tf_predictions = self.batchnorm_classifier(tf_inputs)
            loss = tf.losses.log_loss(tf_labels, tf_predictions)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
            train_op = utils.create_train_op(loss, optimizer, update_ops=[])

            moving_mean = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            '.*moving_mean:')[0]
            moving_variance = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                '.*moving_variance:')[0]

            with self.cached_session() as session:
                # Initialize all variables
                session.run(tf.global_variables_initializer())
                mean, variance = session.run([moving_mean, moving_variance])
                # After initialization moving_mean == 0 and moving_variance == 1.
                self.assertAllClose(mean, [0] * 4)
                self.assertAllClose(variance, [1] * 4)

                for _ in range(10):
                    session.run(train_op)

                mean = moving_mean.eval()
                variance = moving_variance.eval()

                # Since the update_ops were skipped, the moving averages are not updated.
                self.assertAllClose(mean, [0] * 4)
                self.assertAllClose(variance, [1] * 4)
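
The batchnorm_classifier helper used throughout these tests is not shown on this page. A minimal sketch of the fixture the assertions imply, assuming tf.compat.v1-style layers (the layer shape and momentum here are assumptions, not the actual test code):

def batchnorm_classifier(self, inputs):
    # Hypothetical fixture: batch-normalize the 4-feature input.
    # training=True is what registers the moving-average update ops in
    # tf.GraphKeys.UPDATE_OPS, which create_train_op picks up by default.
    inputs = tf.layers.batch_normalization(inputs, momentum=0.1,
                                           training=True)
    # A single sigmoid unit so tf.losses.log_loss applies directly.
    return tf.layers.dense(inputs, 1, activation=tf.nn.sigmoid)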
Example #2
def stage_2_train_op():
    return utils.create_train_op(
        stage_2_loss,
        stage_2_optimizer,
        update_ops=(None if
                    self.hparams.stage_2.training.update_encoder_batch_norm
                    else []))
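
The update_ops argument distinguishes three cases: None (the default) falls back to the tf.GraphKeys.UPDATE_OPS collection, an explicit list runs exactly those ops, and [] skips the updates entirely, as the first test above verifies. A sketch of how such wiring is commonly implemented with control dependencies (an illustration, not the actual tf_slim source):

def make_train_op(loss, optimizer, update_ops=None):
    # Hypothetical helper showing create_train_op-style update handling.
    if update_ops is None:
        # Default: run whatever update ops the layers registered,
        # e.g. batch-norm moving-average updates.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Every session.run of the returned op now executes the update ops
    # first; an empty list makes this a no-op, which is how
    # update_ops=[] disables the moving-average updates.
    with tf.control_dependencies(list(update_ops)):
        return optimizer.minimize(loss)
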
    def testTrainOpInCollection(self):
        with tf.Graph().as_default():
            tf_inputs = tf.constant(self._inputs, dtype=tf.dtypes.float32)
            tf_labels = tf.constant(self._labels, dtype=tf.dtypes.float32)

            tf_predictions = self.batchnorm_classifier(tf_inputs)
            loss = tf.losses.log_loss(tf_labels, tf_predictions)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
            train_op = utils.create_train_op(loss, optimizer)

            # Make sure the training op was recorded in the proper collection
            self.assertIn(train_op, tf.get_collection(tf.GraphKeys.TRAIN_OP))
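
Recording the op in tf.GraphKeys.TRAIN_OP is what allows it to be recovered later without a Python reference, for example after re-importing a MetaGraph. A brief usage sketch (the checkpoint path is illustrative):

# In another process, restore the graph and look the op up by collection.
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
recovered_train_op = tf.get_collection(tf.GraphKeys.TRAIN_OP)[0]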
    def testGlobalStepNotIncrementedWhenSetToNone(self):
        with tf.Graph().as_default():
            tf_inputs = tf.constant(self._inputs, dtype=tf.dtypes.float32)
            tf_labels = tf.constant(self._labels, dtype=tf.dtypes.float32)

            tf_predictions = self.batchnorm_classifier(tf_inputs)
            loss = tf.losses.log_loss(tf_labels, tf_predictions)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
            train_op = utils.create_train_op(loss, optimizer, global_step=None)

            global_step = tf.train.get_or_create_global_step()

            with self.cached_session() as session:
                # Initialize all variables
                session.run(tf.global_variables_initializer())

                for _ in range(10):
                    session.run(train_op)

                # Since train_op was built with global_step=None, the counter is never incremented.
                self.assertAllClose(global_step.eval(), 0)
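
For contrast, with global_step left at its default, create_train_op increments the counter on every step; a sketch of the counterpart check under the same setup (assuming tf_slim's default behavior):

train_op = utils.create_train_op(loss, optimizer)  # default global_step
global_step = tf.train.get_or_create_global_step()

with self.cached_session() as session:
    session.run(tf.global_variables_initializer())
    for _ in range(10):
        session.run(train_op)
    # The default wiring bumps global_step once per run.
    self.assertAllClose(global_step.eval(), 10)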
    def testUseUpdateOps(self):
        with tf.Graph().as_default():
            tf_inputs = tf.constant(self._inputs, dtype=tf.dtypes.float32)
            tf_labels = tf.constant(self._labels, dtype=tf.dtypes.float32)

            expected_mean = np.mean(self._inputs, axis=0)
            expected_var = np.var(self._inputs, axis=0)

            tf_predictions = self.batchnorm_classifier(tf_inputs)
            loss = tf.losses.log_loss(tf_labels, tf_predictions)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

            train_op = utils.create_train_op(loss, optimizer)

            moving_mean = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            '.*moving_mean:')[0]
            moving_variance = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                '.*moving_variance:')[0]

            with self.cached_session() as session:
                # Initialize all variables
                session.run(tf.global_variables_initializer())
                mean, variance = session.run([moving_mean, moving_variance])
                # After initialization moving_mean == 0 and moving_variance == 1.
                self.assertAllClose(mean, [0] * 4)
                self.assertAllClose(variance, [1] * 4)

                for _ in range(200):
                    session.run(train_op)

                mean = moving_mean.eval()
                variance = moving_variance.eval()
                # After 200 updates with decay 0.1, the initial values have
                # decayed away, so moving_mean == expected_mean and
                # moving_variance == expected_var.
                self.assertAllClose(mean, expected_mean)
                self.assertAllClose(variance, expected_var)
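
For reference, the convergence asserted above is just the exponential moving average applied by the batch-norm update ops. A quick numeric sketch, assuming decay (momentum) 0.1 as the comments state:

import numpy as np

decay = 0.1                   # assumed batch-norm momentum
moving_mean = np.zeros(4)     # initial value, as asserted above
batch_mean = np.full(4, 2.0)  # stand-in for np.mean(inputs, axis=0)
for _ in range(200):
    # Per-step update applied by the registered update op.
    moving_mean = decay * moving_mean + (1 - decay) * batch_mean
# The initial value's contribution shrinks by a factor of 10 per step,
# so after 200 steps moving_mean equals batch_mean to machine precision.
print(np.allclose(moving_mean, batch_mean))  # True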