def test_batch_norm(self):
    """Check VirtualBatchNorm matches nn.BatchNorm when virtual size == batch size.

    Both layers are trained for two steps on the same batch; the outputs of
    the second step and the resulting running statistics must agree.
    """
    rng = jax.random.PRNGKey(0)
    batch_size = 10
    feature_size = 7
    input_shape = (batch_size, 3, 3, feature_size)
    half_input_shape = (batch_size // 2, 3, 3, feature_size)
    x = jnp.concatenate((2.0 * jnp.ones(half_input_shape),
                         9.0 * jnp.ones(half_input_shape)))

    def train_twice(module, params, batch_stats):
        """Apply `module` twice in training mode, threading `batch_stats`."""
        y = None
        for _ in range(2):
            y, updated = module.apply(
                {'params': params, 'batch_stats': batch_stats},
                x,
                mutable=['batch_stats'],
                use_running_average=False)
            batch_stats = updated['batch_stats']
        return y, batch_stats

    bn_flax_module = nn.BatchNorm(momentum=0.9)
    bn_params, bn_state = _init(bn_flax_module, rng, input_shape)

    vbn_flax_module = normalization.VirtualBatchNorm(
        momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
    vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

    bn_y, bn_state = train_twice(bn_flax_module, bn_params, bn_state)
    vbn_y, vbn_state = train_twice(vbn_flax_module, vbn_params, vbn_state)

    # The forward outputs of the second step must coincide.
    np.testing.assert_allclose(bn_y, vbn_y, atol=1e-4)

    # The running mean and variance must coincide; the virtual-BN stats are
    # squeezed to drop any singleton axes before comparison.
    stat_pairs = (('mean', 'batch_norm_running_mean'),
                  ('var', 'batch_norm_running_var'))
    for bn_key, vbn_key in stat_pairs:
        np.testing.assert_allclose(
            bn_state[bn_key], np.squeeze(vbn_state[vbn_key]), atol=1e-4)
Exemple #2
0
 def apply(
     self,
     x,
     num_classes,
     num_filters=64,
     num_layers=50,
     train=True,
     axis_name=None,
     axis_index_groups=None,
     dtype=jnp.float32,
     batch_norm_momentum=0.9,
     batch_norm_epsilon=1e-5,
     bn_output_scale=0.0,
     virtual_batch_size=None,
     data_format=None):
   """Builds the ResNet forward pass (legacy ``flax.nn`` module API).

   Args:
     x: input batch.
     num_classes: width of the final dense layer.
     num_filters: filters in the stem convolution; stage ``i`` uses
       ``num_filters * 2**i`` filters.
     num_layers: looked up in ``_block_size_options`` to get the number of
       residual blocks per stage.
     train: when False, batch norm is run with ``use_running_average=True``.
     axis_name: forwarded to the stem batch norm and every ResidualBlock.
     axis_index_groups: forwarded to the stem batch norm and every
       ResidualBlock.
     dtype: computation dtype forwarded to each layer.
     batch_norm_momentum: batch-norm momentum.
     batch_norm_epsilon: batch-norm epsilon.
     bn_output_scale: forwarded to every ResidualBlock.
     virtual_batch_size: forwarded to every batch-norm layer.
     data_format: forwarded to every batch-norm layer.

   Returns:
     Output of the final dense layer (last dimension ``num_classes``).

   Raises:
     ValueError: if ``num_layers`` is not in ``_block_size_options``.
   """
   if num_layers not in _block_size_options:
     raise ValueError('Please provide a valid number of layers')
   stage_sizes = _block_size_options[num_layers]

   # Stem: 7x7 stride-2 convolution -> batch norm -> relu -> 3x3/2 max pool.
   stem_conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
   x = stem_conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2),
                 bias=False, dtype=dtype, name='conv0')
   x = normalization.VirtualBatchNorm(
       x,
       use_running_average=not train,
       momentum=batch_norm_momentum,
       epsilon=batch_norm_epsilon,
       name='init_bn',
       axis_name=axis_name,
       axis_index_groups=axis_index_groups,
       dtype=dtype,
       virtual_batch_size=virtual_batch_size,
       data_format=data_format)
   x = nn.relu(x)  # MLPerf-required
   x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

   # Residual stages: the first block of every stage after the first one
   # uses stride 2; all other blocks use stride 1.
   for stage, num_blocks in enumerate(stage_sizes):
     stage_filters = num_filters * 2 ** stage
     for block_idx in range(num_blocks):
       downsample = stage > 0 and block_idx == 0
       x = ResidualBlock(
           x,
           stage_filters,
           strides=(2, 2) if downsample else (1, 1),
           train=train,
           axis_name=axis_name,
           axis_index_groups=axis_index_groups,
           dtype=dtype,
           batch_norm_momentum=batch_norm_momentum,
           batch_norm_epsilon=batch_norm_epsilon,
           bn_output_scale=bn_output_scale,
           virtual_batch_size=virtual_batch_size,
           data_format=data_format)

   # Average over axes 1 and 2, then the classification head.
   pooled = jnp.mean(x, axis=(1, 2))
   return nn.Dense(pooled, num_classes,
                   kernel_init=nn.initializers.normal(), dtype=dtype)
Exemple #3
0
 def maybe_normalize(name):
     """Return a VirtualBatchNorm layer named `name`, or an identity callable.

     Reads `self` from the enclosing scope. When `self.use_bn` is false, the
     returned callable ignores any keyword arguments and returns its input
     unchanged.
     """
     if not self.use_bn:
         def identity(x, **kwargs):
             del kwargs  # ignored
             return x
         return identity
     return normalization.VirtualBatchNorm(
         momentum=self.batch_norm_momentum,
         epsilon=self.batch_norm_epsilon,
         dtype=self.dtype,
         batch_size=self.batch_size,
         virtual_batch_size=self.virtual_batch_size,
         total_batch_size=self.total_batch_size,
         data_format=self.data_format,
         name=name)
Exemple #4
0
    def __call__(self, x, train):
        """Runs the ResNet forward pass.

        Args:
          x: input batch.
          train: batch norm is called with ``use_running_average=not train``
            and dropout with ``deterministic=not train``.

        Returns:
          Output of the final dense layer (last dimension ``self.num_outputs``).

        Raises:
          ValueError: if ``self.num_layers`` is not in ``_block_size_options``,
            or ``self.block_type`` is neither 'post_activation' nor
            'pre_activation'.
        """
        if self.num_layers not in _block_size_options:
            raise ValueError('Please provide a valid number of layers')
        block_sizes = _block_size_options[self.num_layers]

        # Resolve the residual block class before building any layers so an
        # invalid configuration fails fast.
        if self.block_type == 'post_activation':
            residual_block = ResidualBlock
        elif self.block_type == 'pre_activation':
            residual_block = PreActResidualBlock
        else:
            raise ValueError('Invalid Block Type: {}'.format(self.block_type))

        x = nn.Conv(self.num_filters, (7, 7), (2, 2),
                    use_bias=False,
                    dtype=self.dtype,
                    name='init_conv')(x)
        if self.use_bn:
            # NOTE(review): no activation is applied between init_bn and the
            # max pool below; confirm this is intended for both block types.
            x = normalization.VirtualBatchNorm(
                momentum=self.batch_norm_momentum,
                epsilon=self.batch_norm_epsilon,
                dtype=self.dtype,
                name='init_bn',
                batch_size=self.batch_size,
                virtual_batch_size=self.virtual_batch_size,
                total_batch_size=self.total_batch_size,
                data_format=self.data_format)(x, use_running_average=not train)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Residual stages: stage i has num_filters * 2**i filters; the first
        # block of every stage after the first uses stride 2.
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = residual_block(
                    self.num_filters * 2**i,
                    strides=strides,
                    dtype=self.dtype,
                    batch_norm_momentum=self.batch_norm_momentum,
                    batch_norm_epsilon=self.batch_norm_epsilon,
                    batch_size=self.batch_size,
                    virtual_batch_size=self.virtual_batch_size,
                    total_batch_size=self.total_batch_size,
                    data_format=self.data_format,
                    bn_relu_conv=self.bn_relu_conv,
                    use_bn=self.use_bn,
                    activation_function=self.activation_function)(x,
                                                                  train=train)
        x = jnp.mean(x, axis=(1, 2))
        if self.dropout_rate > 0.0:
            x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(self.num_outputs, dtype=self.dtype)(x)
        return x
Exemple #5
0
 def apply(self,
           x,
           num_outputs,
           num_filters=64,
           num_layers=50,
           train=True,
           batch_stats=None,
           dtype=jnp.float32,
           batch_norm_momentum=0.9,
           batch_norm_epsilon=1e-5,
           virtual_batch_size=None,
           data_format=None):
     """Builds a ResNet forward pass with a 3x3 stem (legacy ``flax.nn`` API).

     Args:
       x: input batch.
       num_outputs: width of the final dense layer.
       num_filters: filters in the stem convolution; stage ``i`` uses
         ``num_filters * 2**i`` filters.
       num_layers: selects the per-stage block counts (``_block_size_options``)
         and the residual block class (``block_type_options``).
       train: when False, batch norm is run with ``use_running_average=True``.
       batch_stats: forwarded to the stem batch norm and every residual block.
       dtype: computation dtype forwarded to each layer.
       batch_norm_momentum: batch-norm momentum.
       batch_norm_epsilon: batch-norm epsilon.
       virtual_batch_size: forwarded to every batch-norm layer.
       data_format: forwarded to every batch-norm layer.

     Returns:
       Output of the final dense layer (last dimension ``num_outputs``).

     Raises:
       ValueError: if ``num_layers`` is not in ``_block_size_options``.
     """
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     stage_sizes = _block_size_options[num_layers]

     # Stem: 3x3 'SAME' convolution with unit stride -> batch norm -> relu.
     x = nn.Conv(x,
                 num_filters, (3, 3), (1, 1),
                 'SAME',
                 bias=False,
                 dtype=dtype,
                 name='init_conv')
     x = normalization.VirtualBatchNorm(
         x,
         batch_stats=batch_stats,
         use_running_average=not train,
         momentum=batch_norm_momentum,
         epsilon=batch_norm_epsilon,
         dtype=dtype,
         name='init_bn',
         virtual_batch_size=virtual_batch_size,
         data_format=data_format)
     x = nn.relu(x)

     block_cls = block_type_options[num_layers]
     for stage, num_blocks in enumerate(stage_sizes):
         stage_filters = num_filters * 2**stage
         for block_idx in range(num_blocks):
             # First block of every stage after the first one downsamples.
             downsample = stage > 0 and block_idx == 0
             x = block_cls(x,
                           stage_filters,
                           strides=(2, 2) if downsample else (1, 1),
                           train=train,
                           batch_stats=batch_stats,
                           dtype=dtype,
                           batch_norm_momentum=batch_norm_momentum,
                           batch_norm_epsilon=batch_norm_epsilon,
                           virtual_batch_size=virtual_batch_size,
                           data_format=data_format)

     # Average over axes 1 and 2, then the output head.
     pooled = jnp.mean(x, axis=(1, 2))
     return nn.Dense(pooled, num_outputs, dtype=dtype)
Exemple #6
0
 def __call__(self, x, train):
     """Runs the ResNet forward pass.

     Args:
       x: input batch.
       train: batch norm layers are called with
         ``use_running_average=not train``.

     Returns:
       Output of the final dense layer (last dimension ``self.num_classes``).

     Raises:
       ValueError: if ``self.num_layers`` is not in ``_block_size_options``.
     """
     if self.num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     stage_sizes = _block_size_options[self.num_layers]

     # Stem: 7x7 stride-2 convolution with explicit 3-pixel padding
     # -> batch norm -> relu -> 3x3/2 max pool.
     x = nn.Conv(self.num_filters,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 padding=[(3, 3), (3, 3)],
                 use_bias=False,
                 dtype=self.dtype,
                 name='conv0')(x)
     x = normalization.VirtualBatchNorm(
         momentum=self.batch_norm_momentum,
         epsilon=self.batch_norm_epsilon,
         name='init_bn',
         axis_name=self.axis_name,
         axis_index_groups=self.axis_index_groups,
         dtype=self.dtype,
         batch_size=self.batch_size,
         virtual_batch_size=self.virtual_batch_size,
         total_batch_size=self.total_batch_size,
         data_format=self.data_format)(x, use_running_average=not train)
     x = nn.relu(x)  # MLPerf-required
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

     for stage, num_blocks in enumerate(stage_sizes):
         stage_filters = self.num_filters * 2**stage
         for block_idx in range(num_blocks):
             # First block of every stage after the first one downsamples.
             downsample = stage > 0 and block_idx == 0
             block = ResidualBlock(stage_filters,
                                   strides=(2, 2) if downsample else (1, 1),
                                   axis_name=self.axis_name,
                                   axis_index_groups=self.axis_index_groups,
                                   dtype=self.dtype,
                                   batch_norm_momentum=self.batch_norm_momentum,
                                   batch_norm_epsilon=self.batch_norm_epsilon,
                                   bn_output_scale=self.bn_output_scale,
                                   batch_size=self.batch_size,
                                   virtual_batch_size=self.virtual_batch_size,
                                   total_batch_size=self.total_batch_size,
                                   data_format=self.data_format)
             x = block(x, train=train)

     # Average over axes 1 and 2, then the classification head.
     x = jnp.mean(x, axis=(1, 2))
     head = nn.Dense(self.num_classes,
                     kernel_init=nn.initializers.normal(),
                     dtype=self.dtype)
     return head(x)
  def test_different_eval_batch_size(self):
    """Test virtual BN can use a different batch size for evals.

    The layer takes one training step on a batch of 10 (equal to the virtual
    batch size). It is then evaluated with running averages on a batch of 13,
    which is not a multiple of the virtual batch size.
    """
    rng = jax.random.PRNGKey(0)
    batch_size = 10
    feature_size = 7
    input_shape = (batch_size, 3, 3, feature_size)
    x = 2.0 * jnp.ones(input_shape)

    vbn_flax_module = normalization.VirtualBatchNorm(
        momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
    vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

    # One training-mode step to produce updated batch_stats.
    _, vbn_state = vbn_flax_module.apply(
        {'params': vbn_params, 'batch_stats': vbn_state},
        x,
        mutable=['batch_stats'],
        use_running_average=False)
    vbn_state = vbn_state['batch_stats']

    # Evaluate with running averages on a batch size that differs from both
    # the training batch size and the virtual batch size.
    eval_x = jnp.ones((13, 3, 3, feature_size))
    eval_y = vbn_flax_module.apply(
        {'params': vbn_params, 'batch_stats': vbn_state},
        eval_x,
        use_running_average=True)

    # The output must keep the eval batch's shape and contain no NaN/inf.
    self.assertEqual(eval_y.shape, eval_x.shape)
    self.assertTrue(np.all(np.isfinite(eval_y)))
Exemple #8
0
 def __call__(self, x, train):
   """Runs the ResNet forward pass with a 3x3 stem.

   Args:
     x: input batch.
     train: batch norm layers are called with
       ``use_running_average=not train``.

   Returns:
     Output of the final dense layer (last dimension ``self.num_outputs``).

   Raises:
     ValueError: if ``self.num_layers`` is not in ``_block_size_options``.
   """
   if self.num_layers not in _block_size_options:
     raise ValueError('Please provide a valid number of layers')
   stage_sizes = _block_size_options[self.num_layers]

   # Stem: 3x3 'SAME' convolution with unit stride -> batch norm -> relu.
   x = nn.Conv(
       self.num_filters, (3, 3), (1, 1),
       'SAME',
       use_bias=False,
       dtype=self.dtype,
       name='init_conv')(x)
   x = normalization.VirtualBatchNorm(
       momentum=self.batch_norm_momentum,
       epsilon=self.batch_norm_epsilon,
       dtype=self.dtype,
       name='init_bn',
       batch_size=self.batch_size,
       virtual_batch_size=self.virtual_batch_size,
       total_batch_size=self.total_batch_size,
       data_format=self.data_format)(x, use_running_average=not train)
   x = nn.relu(x)

   block_cls = block_type_options[self.num_layers]
   for stage, num_blocks in enumerate(stage_sizes):
     stage_filters = self.num_filters * 2**stage
     for block_idx in range(num_blocks):
       # First block of every stage after the first one downsamples.
       downsample = stage > 0 and block_idx == 0
       x = block_cls(
           stage_filters,
           strides=(2, 2) if downsample else (1, 1),
           dtype=self.dtype,
           batch_norm_momentum=self.batch_norm_momentum,
           batch_norm_epsilon=self.batch_norm_epsilon,
           batch_size=self.batch_size,
           virtual_batch_size=self.virtual_batch_size,
           total_batch_size=self.total_batch_size,
           data_format=self.data_format)(x, train=train)

   # Average over axes 1 and 2, then the output head.
   pooled = jnp.mean(x, axis=(1, 2))
   return nn.Dense(self.num_outputs, dtype=self.dtype)(pooled)
  def test_forward_pass(self):
    """Test that one VBN call on a doubled batch equals two BN calls.

    With `virtual_batch_size` equal to half of the concatenated batch,
    VirtualBatchNorm must normalize each half independently. Its output
    therefore matches two separate nn.BatchNorm calls and differs from one
    nn.BatchNorm call over the whole concatenated batch.
    """
    rng = jax.random.PRNGKey(0)
    batch_size = 10
    feature_size = 7
    input_shape = (batch_size, 3, 3, feature_size)
    half_input_shape = (batch_size // 2, 3, 3, feature_size)
    twos = 2.0 * jnp.ones(half_input_shape)
    threes = 3.0 * jnp.ones(half_input_shape)
    fives = 5.0 * jnp.ones(half_input_shape)
    nines = 9.0 * jnp.ones(half_input_shape)
    # The mean(x1) = 2.5, stddev(x1) = 0.5 so we expect
    # `(x1 - mean(x1)) / stddev(x1)` to be half -1.0, then half 1.0.
    x1 = jnp.concatenate((twos, threes))
    # The mean(x2) = 7.0, stddev(x2) = 2.0 so we expect
    # `(x2 - mean(x2)) / stddev(x2)` to be half -1.0, then half 1.0.
    x2 = jnp.concatenate((fives, nines))
    x_both = jnp.concatenate((x1, x2))

    expected_bn_y = jnp.concatenate(
        (jnp.ones(half_input_shape) * -1.0, jnp.ones(half_input_shape)))

    bn_flax_module = nn.BatchNorm(momentum=0.9)
    bn_params, bn_state = _init(bn_flax_module, rng, input_shape)

    vbn_flax_module = normalization.VirtualBatchNorm(
        momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
    vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

    # Each BN call below starts from the same initial state; only the
    # forward outputs are used.
    bn_y1, _ = bn_flax_module.apply(
        {'params': bn_params, 'batch_stats': bn_state},
        x1,
        mutable=['batch_stats'],
        use_running_average=False)

    bn_y2, _ = bn_flax_module.apply(
        {'params': bn_params, 'batch_stats': bn_state},
        x2,
        mutable=['batch_stats'],
        use_running_average=False)

    bn_y_both, _ = bn_flax_module.apply(
        {'params': bn_params, 'batch_stats': bn_state},
        x_both,
        mutable=['batch_stats'],
        use_running_average=False)

    vbn_y_both, vbn_state = vbn_flax_module.apply(
        {'params': vbn_params, 'batch_stats': vbn_state},
        x_both,
        mutable=['batch_stats'],
        use_running_average=False)
    vbn_state = vbn_state['batch_stats']

    # Test that the layer forward passes behave as expected.
    np.testing.assert_allclose(bn_y1, expected_bn_y, atol=1e-4)
    np.testing.assert_allclose(bn_y2, expected_bn_y, atol=1e-4)
    np.testing.assert_allclose(
        vbn_y_both, jnp.concatenate((bn_y1, bn_y2)), atol=1e-4)
    # Test that the virtual batch norm and nn.BatchNorm layers do not perform
    # the same calculation on the concatenated batch. NumPy has no negated
    # `assert_allclose`, so instead assert that every elementwise |diff| is
    # strictly greater than zero.
    np.testing.assert_array_less(
        -jnp.abs(vbn_y_both - bn_y_both), jnp.zeros_like(vbn_y_both))

    # Second training step on the same batch.
    _, vbn_state = vbn_flax_module.apply(
        {'params': vbn_params, 'batch_stats': vbn_state},
        x_both,
        mutable=['batch_stats'],
        use_running_average=False)
    vbn_state = vbn_state['batch_stats']

    # The running mean starts at 0.0 and the running variance at 1.0. So after
    # two updates with the same batch (momentum 0.9) we expect
    #   mean_ema = 0.9 * (0.9 * 0.0 + 0.1 * mean) + 0.1 * mean = 0.19 * mean
    #   var_ema = 0.9 * (0.9 * 1.0 + 0.1 * var) + 0.1 * var = 0.19 * var + 0.81
    # Here `mean` is the average of the two virtual-batch means. `var` is the
    # variance of the full concatenated batch `x_both`, not the average of the
    # per-virtual-batch variances.
    expected_mean_ema_x1 = 0.19 * jnp.mean(x1) * jnp.ones((feature_size,))
    expected_mean_ema_x2 = 0.19 * jnp.mean(x2) * jnp.ones((feature_size,))
    expected_mean_ema_both = (expected_mean_ema_x1 + expected_mean_ema_x2) / 2.0
    expected_var_ema_both = (
        (0.19 * jnp.std(x_both) ** 2.0 + 0.81) *
        jnp.ones((feature_size,)))
    np.testing.assert_allclose(
        np.squeeze(vbn_state['batch_norm_running_mean']),
        expected_mean_ema_both,
        atol=1e-4)
    np.testing.assert_allclose(
        np.squeeze(vbn_state['batch_norm_running_var']),
        expected_var_ema_both,
        atol=1e-4)