def test_resnext_block_can_be_called_channels_first(self):
    """Build a strided, projecting ResNextBlock on channels_first input.

    Checks that calling the block:
      * creates trainable variables,
      * yields a real (non-None) gradient w.r.t. its input,
      * registers ops in the UPDATE_OPS collection when training=True,
      * produces the expected spatially-downsampled output shape.
    """
    # channels_first layout: [batch=2, channels=32, height=16, width=16].
    inputs = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
    block = blocks.ResNextBlock(
        filters=64,
        strides=2,
        use_projection=True,
        data_format='channels_first')
    outputs = block(inputs, training=True)
    grads = tf.compat.v1.gradients(outputs, inputs)

    self.assertTrue(tf.compat.v1.trainable_variables())
    # gradients() returns one entry per input tensor. The list itself is
    # always truthy, so assert on the entry: None would mean `outputs` has no
    # differentiable path back to `inputs`.
    self.assertEqual(1, len(grads))
    self.assertIsNotNone(grads[0])
    # NOTE(review): presumably batch-norm moving-average updates registered
    # because training=True -- confirm against blocks.ResNextBlock.
    self.assertTrue(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
    # strides=2 halves the 16x16 spatial dims to 8x8; 256 output channels
    # are expected for filters=64.
    self.assertListEqual([2, 256, 8, 8], outputs.shape.as_list())
def test_resnext_block_can_be_called_float16(self):
    """Build a strided, projecting ResNextBlock on float16 channels_last input.

    Variables are created through `custom_float16_getter` inside a dedicated
    variable scope. The test then checks that calling the block:
      * creates trainable variables,
      * yields a real (non-None) gradient w.r.t. its input,
      * registers ops in the UPDATE_OPS collection when training=True,
      * produces the expected spatially-downsampled output shape.
    """
    # channels_last layout: [batch=2, height=16, width=16, channels=32].
    inputs = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
    # NOTE(review): custom_float16_getter is defined elsewhere in this file;
    # presumably it creates/casts variables for half-precision use -- confirm.
    with tf.compat.v1.variable_scope(
        'float16', custom_getter=custom_float16_getter):
        block = blocks.ResNextBlock(
            filters=64,
            strides=2,
            use_projection=True,
            data_format='channels_last')
        outputs = block(inputs, training=True)
        grads = tf.compat.v1.gradients(outputs, inputs)

        self.assertTrue(tf.compat.v1.trainable_variables())
        # gradients() returns one entry per input tensor. The list itself is
        # always truthy, so assert on the entry: None would mean `outputs`
        # has no differentiable path back to `inputs`.
        self.assertEqual(1, len(grads))
        self.assertIsNotNone(grads[0])
        # NOTE(review): presumably batch-norm moving-average updates
        # registered because training=True -- confirm against the block.
        self.assertTrue(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS))
        # strides=2 halves the 16x16 spatial dims to 8x8; 256 output channels
        # are expected for filters=64.
        self.assertListEqual([2, 8, 8, 256], outputs.shape.as_list())