def test_resnext_block_can_be_called_channels_first(self):
     """Builds a stride-2 projecting ResNextBlock on channels_first input.

     Feeds a [2, 32, 16, 16] float32 tensor through the block in training
     mode and checks that:
       * trainable variables were created in the graph,
       * a gradient of the outputs w.r.t. the inputs exists,
       * ops were added to the UPDATE_OPS collection,
       * the output shape is [2, 256, 8, 8].
     """
     inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
     block = blocks.ResNextBlock(filters=64,
                                 strides=2,
                                 use_projection=True,
                                 data_format='channels_first')
     outputs = block(inputs, training=True)
     grads = tf.gradients(outputs, inputs)
     self.assertTrue(tf.compat.v1.trainable_variables())
     # tf.gradients returns a one-element list that holds None when `outputs`
     # is not connected to `inputs`; `[None]` is truthy, so check the tensor.
     self.assertEqual(1, len(grads))
     self.assertIsNotNone(grads[0])
     self.assertTrue(
         tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
     self.assertListEqual([2, 256, 8, 8], outputs.shape.as_list())
 def test_resnext_block_can_be_called_float16(self):
     """Builds a stride-2 projecting ResNextBlock on float16 channels_last input.

     The block is constructed inside a 'float16' variable scope whose
     variables are obtained through `custom_float16_getter`. A
     [2, 16, 16, 32] float16 tensor is fed through the block in training
     mode, and the test checks that:
       * trainable variables were created in the graph,
       * a gradient of the outputs w.r.t. the inputs exists,
       * ops were added to the UPDATE_OPS collection,
       * the output shape is [2, 8, 8, 256].
     """
     inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
     with tf.compat.v1.variable_scope('float16',
                                      custom_getter=custom_float16_getter):
         block = blocks.ResNextBlock(filters=64,
                                     strides=2,
                                     use_projection=True,
                                     data_format='channels_last')
         outputs = block(inputs, training=True)
         grads = tf.gradients(outputs, inputs)
     self.assertTrue(tf.compat.v1.trainable_variables())
     # tf.gradients returns a one-element list that holds None when `outputs`
     # is not connected to `inputs`; `[None]` is truthy, so check the tensor.
     self.assertEqual(1, len(grads))
     self.assertIsNotNone(grads[0])
     self.assertTrue(
         tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
     self.assertListEqual([2, 8, 8, 256], outputs.shape.as_list())