def test_padded_conv_can_be_called_channels_first(self):
  """Conv2DFixedPadding runs and is differentiable with NCHW data."""
  # NCHW input: batch=2, channels=32, spatial 16x16.
  images = tf.random.normal([2, 32, 16, 16], dtype=tf.float32)
  conv = blocks.Conv2DFixedPadding(
      filters=4, kernel_size=3, strides=2, data_format='channels_first')
  result = conv(images, training=True)
  gradients = tf.gradients(result, images)
  # The layer should have created trainable variables and be differentiable.
  self.assertTrue(tf.compat.v1.trainable_variables())
  self.assertTrue(gradients)
  # Stride 2 halves each spatial dim; `filters` sets the channel dim.
  self.assertListEqual([2, 4, 8, 8], result.shape.as_list())
def test_padded_conv_can_be_called_float16(self):
  """Conv2DFixedPadding runs and is differentiable on float16 inputs."""
  # NHWC float16 input: batch=2, spatial 16x16, channels=32.
  images = tf.random.normal([2, 16, 16, 32], dtype=tf.float16)
  # Build and call the layer under the float16 custom getter so its
  # variables are created through `custom_float16_getter`.
  with tf.variable_scope('float16', custom_getter=custom_float16_getter):
    conv = blocks.Conv2DFixedPadding(
        filters=4, kernel_size=3, strides=2, data_format='channels_last')
    result = conv(images, training=True)
  gradients = tf.gradients(result, images)
  # The layer should have created trainable variables and be differentiable.
  self.assertTrue(tf.compat.v1.trainable_variables())
  self.assertTrue(gradients)
  # Stride 2 halves each spatial dim; `filters` sets the channel dim.
  self.assertListEqual([2, 8, 8, 4], result.shape.as_list())
def __init__(self,
             block_fn,
             block_group_sizes,
             width=1,
             first_conv_kernel_size=7,
             first_conv_stride=2,
             use_initial_max_pool=True,
             data_format='channels_last',
             batch_norm_momentum=blocks_lib.BATCH_NORM_MOMENTUM,
             use_global_batch_norm=True,
             name='AbstractResidualNetwork',
             **kwargs):
  """Builds the stem and block groups of a residual network.

  Args:
    block_fn: Residual block constructor passed through to each block group.
    block_group_sizes: Sequence of ints; entry i is the number of blocks in
      group i (its length determines the number of groups).
    width: Width multiplier applied to the filter counts.
    first_conv_kernel_size: Kernel size of the stem convolution.
    first_conv_stride: Stride of the stem convolution.
    use_initial_max_pool: If True, adds a 3x3 stride-2 SAME max pool after
      the stem.
    data_format: 'channels_last' or 'channels_first'.
    batch_norm_momentum: Momentum forwarded to the block groups' batch norms.
    use_global_batch_norm: Forwarded to batch-norm construction.
    name: Layer name.
    **kwargs: Forwarded to the parent layer constructor.
  """
  super().__init__(name=name, **kwargs)
  self.data_format = data_format
  self.num_block_groups = len(block_group_sizes)

  # Stem: fixed-padding conv -> batch norm -> ReLU (-> optional max pool).
  self.initial_conv = blocks_lib.Conv2DFixedPadding(
      filters=int(64 * width),
      kernel_size=first_conv_kernel_size,
      strides=first_conv_stride,
      data_format=data_format)
  self.initial_batchnorm = blocks_lib.batch_norm(
      data_format=data_format, use_global_batch_norm=use_global_batch_norm)
  self.initial_activation = tf.keras.layers.Activation('relu')
  if use_initial_max_pool:
    self.initial_max_pool = tf.layers.MaxPooling2D(
        pool_size=3, strides=2, padding='SAME', data_format=data_format)
  else:
    self.initial_max_pool = None

  for group_index, group_size in enumerate(block_group_sizes):
    # Use setattr rather than appending to a list, since Keras only tracks
    # sublayers that are direct members of the parent layers.
    group = _BlockGroup(
        # Filter count doubles per group; keep the int() around the whole
        # product so fractional widths round the same way in every group.
        filters=int(64 * 2 ** group_index * width),
        block_fn=block_fn,
        num_blocks=group_size,
        # Only the first group keeps stride 1; later groups downsample.
        strides=1 if group_index == 0 else 2,
        data_format=data_format,
        batch_norm_momentum=batch_norm_momentum,
        use_global_batch_norm=use_global_batch_norm,
        name=f'BlockGroup{group_index}')
    setattr(self, f'block_group_{group_index}', group)