Example #1
 def test_output_shapes(self, groups, input_channel, output_channel):
     l = custom_layers.GroupConv2D(output_channel,
                                   3,
                                   groups=groups,
                                   use_bias=False,
                                   padding='same')
     outputs = l(_get_random_inputs(input_shape=[2, 32, 32, input_channel]))
     self.assertListEqual(outputs.get_shape().as_list(),
                          [2, 32, 32, output_channel])
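The helper _get_random_inputs is used throughout these examples but not shown. A minimal sketch of a plausible implementation, assuming it simply returns a random float tensor of the requested shape (the real helper in the test module may differ):

import tensorflow as tf

def _get_random_inputs(input_shape):
    # Hypothetical stand-in for the test helper: a uniform random tensor.
    return tf.random.uniform(shape=input_shape)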
Example #2
 def test_kernel_shapes(self, groups, input_channel, output_channel):
     l = custom_layers.GroupConv2D(output_channel,
                                   3,
                                   groups=groups,
                                   use_bias=False)
     _ = l(_get_random_inputs(input_shape=(1, 32, 32, input_channel)))
     expected_kernel_shapes = [(3, 3, input_channel // groups,
                                output_channel // groups)
                               for _ in range(groups)]
     kernel_shapes = [w.get_shape() for w in l.trainable_weights]
     self.assertListEqual(kernel_shapes, expected_kernel_shapes)
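The expected shapes show that GroupConv2D holds one (3, 3, input_channel / groups, output_channel / groups) kernel per group. A minimal sketch of the underlying split-convolve-concatenate idea using stock Keras ops (an illustration of grouped convolution in general, not of GroupConv2D's actual internals):

import tensorflow as tf

def grouped_conv2d(inputs, filters, kernel_size, groups):
    # Split the channels into `groups` equal slices, convolve each slice
    # with its own small kernel, then concatenate the results.
    splits = tf.split(inputs, num_or_size_splits=groups, axis=-1)
    outputs = [
        tf.keras.layers.Conv2D(filters // groups, kernel_size,
                               padding='same', use_bias=False)(s)
        for s in splits
    ]
    return tf.concat(outputs, axis=-1)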
Example #3
 def test_construction(self, groups, input_channel, output_channel,
                       use_batch_norm):
     batch_norm_layer = BATCH_NORM_LAYER if use_batch_norm else None
     l = custom_layers.GroupConv2D(output_channel,
                                   3,
                                   groups=groups,
                                   use_bias=True,
                                   batch_norm_layer=batch_norm_layer)
     inputs = _get_random_inputs(input_shape=(1, 4, 4, output_channel))
     _ = l(inputs)
     # kernel and bias for each group. When using batch norm, 2 additional
     # trainable weights per group for batchnorm layers: gamma and beta.
     expected_num_trainable_weights = groups * (2 + 2 * use_batch_norm)
     self.assertLen(l.trainable_weights, expected_num_trainable_weights)
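For instance, with groups=4 and use_batch_norm=True the test expects 4 * (2 + 2) = 16 trainable weights, and 4 * 2 = 8 without batch norm.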
Example #4
    def test_equivalence(self, groups, input_channel, output_channel,
                         use_batch_norm, activation):
        batch_norm_layer = BATCH_NORM_LAYER if use_batch_norm else None
        kwargs = dict(filters=output_channel,
                      groups=groups,
                      kernel_size=1,
                      use_bias=False,
                      batch_norm_layer=batch_norm_layer,
                      activation=activation)
        gc_layer = tf.keras.Sequential([custom_layers.GroupConv2D(**kwargs)])
        gc_model = custom_layers.GroupConv2DKerasModel(**kwargs)
        gc_layer.build(input_shape=(None, 3, 3, input_channel))
        gc_model.build(input_shape=(None, 3, 3, input_channel))

        inputs = _get_random_inputs((2, 3, 3, input_channel))
        gc_layer.set_weights(gc_model.get_weights())

        self.assertAllEqual(gc_layer(inputs), gc_model(inputs))
Example #5
def groupconv2d_block(conv_filters: Optional[int],
                      config: ModelConfig,
                      kernel_size: Any = (1, 1),
                      strides: Any = (1, 1),
                      group_size: Optional[int] = None,
                      use_batch_norm: bool = True,
                      use_bias: bool = False,
                      activation: Any = None,
                      name: Optional[str] = None) -> tf.keras.layers.Layer:
    """2D group convolution with batchnorm and activation."""
    batch_norm = common_modules.get_batch_norm(config.batch_norm)
    bn_momentum = config.bn_momentum
    bn_epsilon = config.bn_epsilon
    data_format = tf.keras.backend.image_data_format()
    weight_decay = config.weight_decay
    if group_size is None:
        group_size = config.group_base_size

    name = name or ''
    # Compute the number of groups.
    if conv_filters % group_size != 0:
        raise ValueError(
            f'Number of filters ({conv_filters}) is not divisible by '
            f'the group size ({group_size})')
    groups = conv_filters // group_size
    # Collect args based on what kind of groupconv2d block is desired
    init_kwargs = {
        'kernel_size': kernel_size,
        'strides': strides,
        'use_bias': use_bias,
        'padding': 'same',
        'name': name + '_groupconv2d',
        'kernel_regularizer': tf.keras.regularizers.l2(weight_decay),
        'bias_regularizer': tf.keras.regularizers.l2(weight_decay),
        'filters': conv_filters,
        'groups': groups,
        'batch_norm_layer': batch_norm if use_batch_norm else None,
        'bn_epsilon': bn_epsilon,
        'bn_momentum': bn_momentum,
        'activation': activation,
        'data_format': data_format,
    }
    return custom_layers.GroupConv2D(**init_kwargs)
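A hypothetical call site for groupconv2d_block; the config argument is assumed to be a ModelConfig carrying the batch_norm, bn_momentum, bn_epsilon, weight_decay, and group_base_size fields read above:

block = groupconv2d_block(conv_filters=64,
                          config=config,  # assumed ModelConfig instance
                          kernel_size=(3, 3),
                          group_size=16,  # 64 filters / 16 per group -> 4 groups
                          activation='relu',
                          name='stage1')
outputs = block(inputs)  # inputs: a (batch, height, width, channels) tensor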
Example #6
 def test_serialization_deserialization(self, groups, use_batch_norm,
                                        activation):
     batch_norm_layer = BATCH_NORM_LAYER if use_batch_norm else None
     l = custom_layers.GroupConv2D(filters=8,
                                   kernel_size=1,
                                   groups=groups,
                                   use_bias=False,
                                   padding='same',
                                   batch_norm_layer=batch_norm_layer,
                                   activation=activation)
     config = l.get_config()
     # New layer from config
     new_l = custom_layers.GroupConv2D.from_config(config)
     # Copy the weights too.
     l.build(input_shape=(1, 1, 1, 4))
     new_l.build(input_shape=(1, 1, 1, 4))
     new_l.set_weights(l.get_weights())
     inputs = _get_random_inputs((1, 1, 1, 4))
     self.assertNotEqual(l, new_l)
     self.assertAllEqual(l(inputs), new_l(inputs))
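Beyond the get_config/from_config round trip above, a whole model containing the layer should also survive saving and loading once the custom class is registered; a sketch using standard Keras APIs:

model = tf.keras.Sequential(
    [custom_layers.GroupConv2D(filters=8, kernel_size=1, groups=2)])
model.build(input_shape=(None, 1, 1, 4))
model.save('/tmp/group_conv_model')
restored = tf.keras.models.load_model(
    '/tmp/group_conv_model',
    custom_objects={'GroupConv2D': custom_layers.GroupConv2D})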