예제 #1
0
  def testEmptyUpdateOps(self):
    """Verifies that `create_train_op(update_ops=[])` skips moving-stat updates.

    The batch-norm classifier registers ops in the UPDATE_OPS collection, but
    the train op is built with an explicitly empty ``update_ops`` list. After
    ten training steps, ``moving_mean`` and ``moving_variance`` must therefore
    still hold their freshly initialized values.
    """
    with tf.Graph().as_default():
      tf.compat.v1.random.set_random_seed(0)
      inputs = tf.constant(self._inputs, dtype=tf.float32)
      labels = tf.constant(self._labels, dtype=tf.float32)

      predictions = batchnorm_classifier(inputs)
      # The classifier must register update ops; otherwise skipping them
      # below would prove nothing.
      registered_update_ops = tf.compat.v1.get_collection(
          tf.compat.v1.GraphKeys.UPDATE_OPS)
      self.assertNotEmpty(registered_update_ops)

      loss = tf.compat.v1.losses.log_loss(labels, predictions)
      sgd = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0)
      # Deliberately pass no update ops so the moving statistics are frozen.
      train_op = contrib_utils.create_train_op(loss, sgd, update_ops=[])

      mean_var = contrib_utils.get_variables_by_name('moving_mean')[0]
      variance_var = contrib_utils.get_variables_by_name('moving_variance')[0]

      def assert_stats_at_init(session):
        # Initial values: moving_mean == 0 and moving_variance == 1.
        mean, variance = session.run([mean_var, variance_var])
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

      with self.cached_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        assert_stats_at_init(sess)

        for _ in range(10):
          sess.run(train_op)

        # Update ops were skipped, so the statistics must be unchanged.
        assert_stats_at_init(sess)
예제 #2
0
 def testCreateVariables(self):
     """Checks that instance_norm with center and scale makes one beta and one gamma."""
     if tf.executing_eagerly():
         # Collection-based variable lookup does not work under eager execution.
         return
     h = w = 3
     inputs = tf.random.uniform((5, h, w, 3), seed=1)
     norm.instance_norm(inputs, center=True, scale=True)
     for var_name in ('beta', 'gamma'):
         self.assertLen(contrib_utils.get_variables_by_name(var_name), 1)
예제 #3
0
 def testReuseVariables(self):
     """Checks that re-entering scope 'IN' with reuse=True adds no new variables.

     Two instance_norm calls share the 'IN' scope, and the second one reuses
     it, so exactly one ``beta`` and one ``gamma`` variable should exist.
     """
     if tf.executing_eagerly():
         # Scope-based variable reuse is not supported under eager execution.
         return
     h = w = 3
     inputs = tf.random.uniform((5, h, w, 3), seed=1)
     # First call creates the variables; the second must pick them back up.
     norm.instance_norm(inputs, scale=True, scope='IN')
     norm.instance_norm(inputs, scale=True, scope='IN', reuse=True)
     for var_name in ('beta', 'gamma'):
         self.assertLen(contrib_utils.get_variables_by_name(var_name), 1)
예제 #4
0
 def testCreateOpNoScaleCenter(self):
   """Checks that instance_norm without center/scale keeps shape and has no variables."""
   if tf.executing_eagerly():
     # Collection-based variable lookup does not work under eager execution.
     return
   h, w = 3, 3
   input_shape = (5, h, w, 3)
   inputs = tf.random.uniform(input_shape, dtype=tf.float64, seed=1)
   output = norm.instance_norm(inputs, center=False, scale=False)
   # The op must preserve the input's static shape.
   self.assertListEqual(list(input_shape), output.shape.as_list())
   # With center=False and scale=False no beta/gamma variables are created.
   for var_name in ('beta', 'gamma'):
     self.assertEmpty(contrib_utils.get_variables_by_name(var_name))
예제 #5
0
 def testCreateVariables_NCHW(self):
     """Checks that group_norm on channels-first input creates one beta and one gamma.

     Channels sit on axis -3 and statistics are reduced over the two trailing
     axes (height and width of the input below). With ``center`` and
     ``scale`` enabled, exactly one ``beta`` and one ``gamma`` variable are
     expected.
     """
     if tf.executing_eagerly():
         # Collections don't work with eager.
         return
     height, width, groups = 3, 3, 4
     # 2 * groups channels, i.e. two channels per group.
     images = tf.random.uniform((5, 2 * groups, height, width), seed=1)
     # Pass the same `groups` used to size the channel axis so the two can
     # never drift apart (previously hard-coded as 4).
     norm.group_norm(images,
                     groups=groups,
                     channels_axis=-3,
                     reduction_axes=(-2, -1),
                     center=True,
                     scale=True)
     self.assertLen(contrib_utils.get_variables_by_name('beta'), 1)
     self.assertLen(contrib_utils.get_variables_by_name('gamma'), 1)