def test_residual_block_can_be_called_channels_first(self):
     """Builds and calls a projected, strided ResidualBlock on NCHW input.

     Checks that calling the block with training=True creates trainable
     variables, adds ops to the UPDATE_OPS collection, yields an output that
     is differentiable with respect to the input, and produces the expected
     channels-first output shape.
     """
     inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
     block = blocks.ResidualBlock(filters=3,
                                  strides=2,
                                  use_projection=True,
                                  data_format='channels_first')
     outputs = block(inputs, training=True)
     grads = tf.gradients(outputs, inputs)
     self.assertTrue(tf.compat.v1.trainable_variables())
     # tf.gradients returns one entry per input tensor, and that entry is None
     # when the input is not connected to `outputs`; a truthiness check on the
     # list itself ([None] is truthy) could therefore never fail.
     self.assertEqual(1, len(grads))
     self.assertIsNotNone(grads[0])
     # NOTE(review): presumably populated by batch-norm moving-average updates
     # inside the block when training=True -- confirm against ResidualBlock.
     self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
     # strides=2 halves the 16x16 spatial dims to 8x8; filters=3 sets the
     # channel dim, which sits on axis 1 for channels_first.
     self.assertListEqual([2, 3, 8, 8], outputs.shape.as_list())
 def test_residual_block_can_be_called_float16(self):
     """Builds and calls a projected, strided ResidualBlock on float16 NHWC input.

     The block is created inside a variable scope that uses
     `custom_float16_getter` (defined outside this view). The test checks that
     calling the block with training=True creates trainable variables, adds
     ops to the UPDATE_OPS collection, yields an output that is
     differentiable with respect to the input, and produces the expected
     channels-last output shape.
     """
     inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
     # NOTE(review): the custom getter presumably controls how variables are
     # created/cast for float16 computation -- confirm against its definition.
     with tf.variable_scope('float16', custom_getter=custom_float16_getter):
         block = blocks.ResidualBlock(filters=3,
                                      strides=2,
                                      use_projection=True,
                                      data_format='channels_last')
         outputs = block(inputs, training=True)
         grads = tf.gradients(outputs, inputs)
     self.assertTrue(tf.compat.v1.trainable_variables())
     # tf.gradients returns one entry per input tensor, and that entry is None
     # when the input is not connected to `outputs`; a truthiness check on the
     # list itself ([None] is truthy) could therefore never fail.
     self.assertEqual(1, len(grads))
     self.assertIsNotNone(grads[0])
     # NOTE(review): presumably populated by batch-norm moving-average updates
     # inside the block when training=True -- confirm against ResidualBlock.
     self.assertTrue(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
     # strides=2 halves the 16x16 spatial dims to 8x8; filters=3 sets the
     # channel dim, which is the last axis for channels_last.
     self.assertListEqual([2, 8, 8, 3], outputs.shape.as_list())