def test_batch_norm(self):
  """Test virtual BN recovers BN when the virtual size equals batch size."""
  rng = jax.random.PRNGKey(0)
  batch_size = 10
  feature_size = 7
  input_shape = (batch_size, 3, 3, feature_size)
  half_input_shape = (batch_size // 2, 3, 3, feature_size)
  twos = 2.0 * jnp.ones(half_input_shape)
  nines = 9.0 * jnp.ones(half_input_shape)
  x = jnp.concatenate((twos, nines))

  bn_flax_module = nn.BatchNorm(momentum=0.9)
  bn_params, bn_state = _init(bn_flax_module, rng, input_shape)

  vbn_flax_module = normalization.VirtualBatchNorm(
      momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
  vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

  _, bn_state = bn_flax_module.apply(
      {'params': bn_params, 'batch_stats': bn_state},
      x,
      mutable=['batch_stats'],
      use_running_average=False)
  bn_state = bn_state['batch_stats']
  bn_y, bn_state = bn_flax_module.apply(
      {'params': bn_params, 'batch_stats': bn_state},
      x,
      mutable=['batch_stats'],
      use_running_average=False)
  bn_state = bn_state['batch_stats']

  _, vbn_state = vbn_flax_module.apply(
      {'params': vbn_params, 'batch_stats': vbn_state},
      x,
      mutable=['batch_stats'],
      use_running_average=False)
  vbn_state = vbn_state['batch_stats']
  vbn_y, vbn_state = vbn_flax_module.apply(
      {'params': vbn_params, 'batch_stats': vbn_state},
      x,
      mutable=['batch_stats'],
      use_running_average=False)
  vbn_state = vbn_state['batch_stats']

  # Test that the layer forward passes are the same.
  np.testing.assert_allclose(bn_y, vbn_y, atol=1e-4)

  # Test that virtual and regular BN produce the same EMAs.
  np.testing.assert_allclose(
      bn_state['mean'],
      np.squeeze(vbn_state['batch_norm_running_mean']),
      atol=1e-4)
  np.testing.assert_allclose(
      bn_state['var'],
      np.squeeze(vbn_state['batch_norm_running_var']),
      atol=1e-4)
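# The tests in this section call an `_init` helper that is defined elsewhere
# in the test file. As a rough, hedged sketch (the exact helper is not shown
# here, so the body below is an assumption, not the source implementation):
# it initializes the Flax module on dummy data and splits the resulting
# variables into the 'params' and 'batch_stats' collections.
def _init(flax_module, rng, input_shape):
  """Hypothetical sketch of the init helper used by the tests here."""
  variables = flax_module.init(
      rng, jnp.zeros(input_shape), use_running_average=False)
  return variables['params'], variables['batch_stats']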
def apply(self,
          x,
          num_classes,
          num_filters=64,
          num_layers=50,
          train=True,
          axis_name=None,
          axis_index_groups=None,
          dtype=jnp.float32,
          batch_norm_momentum=0.9,
          batch_norm_epsilon=1e-5,
          bn_output_scale=0.0,
          virtual_batch_size=None,
          data_format=None):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]

  conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
           dtype=dtype, name='conv0')
  x = normalization.VirtualBatchNorm(
      x,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      name='init_bn',
      axis_name=axis_name,
      axis_index_groups=axis_index_groups,
      dtype=dtype,
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(
          x,
          num_filters * 2**i,
          strides=strides,
          train=train,
          axis_name=axis_name,
          axis_index_groups=axis_index_groups,
          dtype=dtype,
          batch_norm_momentum=batch_norm_momentum,
          batch_norm_epsilon=batch_norm_epsilon,
          bn_output_scale=bn_output_scale,
          virtual_batch_size=virtual_batch_size,
          data_format=data_format)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  return x
def maybe_normalize(name):
  if self.use_bn:
    return normalization.VirtualBatchNorm(
        momentum=self.batch_norm_momentum,
        epsilon=self.batch_norm_epsilon,
        dtype=self.dtype,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size,
        data_format=self.data_format,
        name=name)
  else:
    return lambda x, **kwargs: x
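# The identity branch above returns `lambda x, **kwargs: x` so that call sites
# can pass normalization-only keyword arguments (e.g. use_running_average)
# unconditionally, whether or not batch norm is enabled. A minimal standalone
# sketch of the same pattern with toy names (not the classes from this repo):
def make_maybe_normalize(use_bn, momentum=0.9):
  """Returns a normalization layer, or a kwargs-swallowing identity."""
  if use_bn:
    return nn.BatchNorm(momentum=momentum)  # stand-in for VirtualBatchNorm
  return lambda x, **kwargs: x

# Either branch is called the same way:
#   y = make_maybe_normalize(use_bn)(y, use_running_average=not train)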
def __call__(self, x, train):
  if self.num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[self.num_layers]

  x = nn.Conv(self.num_filters, (7, 7), (2, 2),
              use_bias=False,
              dtype=self.dtype,
              name='init_conv')(x)
  if self.use_bn:
    x = normalization.VirtualBatchNorm(
        momentum=self.batch_norm_momentum,
        epsilon=self.batch_norm_epsilon,
        dtype=self.dtype,
        name='init_bn',
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size,
        data_format=self.data_format)(x, use_running_average=not train)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  if self.block_type == 'post_activation':
    residual_block = ResidualBlock
  elif self.block_type == 'pre_activation':
    residual_block = PreActResidualBlock
  else:
    raise ValueError('Invalid Block Type: {}'.format(self.block_type))

  index = 0
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      index += 1
      x = residual_block(
          self.num_filters * 2**i,
          strides=strides,
          dtype=self.dtype,
          batch_norm_momentum=self.batch_norm_momentum,
          batch_norm_epsilon=self.batch_norm_epsilon,
          batch_size=self.batch_size,
          virtual_batch_size=self.virtual_batch_size,
          total_batch_size=self.total_batch_size,
          data_format=self.data_format,
          bn_relu_conv=self.bn_relu_conv,
          use_bn=self.use_bn,
          activation_function=self.activation_function)(x, train=train)
  x = jnp.mean(x, axis=(1, 2))
  if self.dropout_rate > 0.0:
    x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
  x = nn.Dense(self.num_outputs, dtype=self.dtype)(x)
  return x
def apply(self,
          x,
          num_outputs,
          num_filters=64,
          num_layers=50,
          train=True,
          batch_stats=None,
          dtype=jnp.float32,
          batch_norm_momentum=0.9,
          batch_norm_epsilon=1e-5,
          virtual_batch_size=None,
          data_format=None):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]

  x = nn.Conv(x, num_filters, (3, 3), (1, 1), 'SAME', bias=False,
              dtype=dtype, name='init_conv')
  x = normalization.VirtualBatchNorm(
      x,
      batch_stats=batch_stats,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      dtype=dtype,
      name='init_bn',
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  x = nn.relu(x)

  residual_block = block_type_options[num_layers]
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = residual_block(
          x,
          num_filters * 2**i,
          strides=strides,
          train=train,
          batch_stats=batch_stats,
          dtype=dtype,
          batch_norm_momentum=batch_norm_momentum,
          batch_norm_epsilon=batch_norm_epsilon,
          virtual_batch_size=virtual_batch_size,
          data_format=data_format)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, dtype=dtype)
  return x
def __call__(self, x, train):
  if self.num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[self.num_layers]

  conv = functools.partial(nn.Conv, padding=[(3, 3), (3, 3)])
  x = conv(self.num_filters, kernel_size=(7, 7), strides=(2, 2),
           use_bias=False, dtype=self.dtype, name='conv0')(x)
  x = normalization.VirtualBatchNorm(
      momentum=self.batch_norm_momentum,
      epsilon=self.batch_norm_epsilon,
      name='init_bn',
      axis_name=self.axis_name,
      axis_index_groups=self.axis_index_groups,
      dtype=self.dtype,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size,
      data_format=self.data_format)(x, use_running_average=not train)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(
          self.num_filters * 2**i,
          strides=strides,
          axis_name=self.axis_name,
          axis_index_groups=self.axis_index_groups,
          dtype=self.dtype,
          batch_norm_momentum=self.batch_norm_momentum,
          batch_norm_epsilon=self.batch_norm_epsilon,
          bn_output_scale=self.bn_output_scale,
          batch_size=self.batch_size,
          virtual_batch_size=self.virtual_batch_size,
          total_batch_size=self.total_batch_size,
          data_format=self.data_format)(x, train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(self.num_classes, kernel_init=nn.initializers.normal(),
               dtype=self.dtype)(x)
  return x
def test_different_eval_batch_size(self):
  """Test virtual BN can use a different batch size for evals."""
  rng = jax.random.PRNGKey(0)
  batch_size = 10
  feature_size = 7
  input_shape = (batch_size, 3, 3, feature_size)
  x = 2.0 * jnp.ones(input_shape)

  vbn_flax_module = normalization.VirtualBatchNorm(
      momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
  vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

  _, vbn_state = vbn_flax_module.apply(
      {'params': vbn_params, 'batch_stats': vbn_state},
      x,
      mutable=['batch_stats'],
      use_running_average=False)
  vbn_state = vbn_state['batch_stats']

  vbn_flax_module.apply(
      {'params': vbn_params, 'batch_stats': vbn_state},
      jnp.ones((13, 3, 3, feature_size)),
      use_running_average=True)
def __call__(self, x, train):
  if self.num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[self.num_layers]

  x = nn.Conv(
      self.num_filters, (3, 3), (1, 1), 'SAME',
      use_bias=False,
      dtype=self.dtype,
      name='init_conv')(x)
  x = normalization.VirtualBatchNorm(
      momentum=self.batch_norm_momentum,
      epsilon=self.batch_norm_epsilon,
      dtype=self.dtype,
      name='init_bn',
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size,
      data_format=self.data_format)(x, use_running_average=not train)
  x = nn.relu(x)

  residual_block = block_type_options[self.num_layers]
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = residual_block(
          self.num_filters * 2**i,
          strides=strides,
          dtype=self.dtype,
          batch_norm_momentum=self.batch_norm_momentum,
          batch_norm_epsilon=self.batch_norm_epsilon,
          batch_size=self.batch_size,
          virtual_batch_size=self.virtual_batch_size,
          total_batch_size=self.total_batch_size,
          data_format=self.data_format)(x, train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(self.num_outputs, dtype=self.dtype)(x)
  return x
def test_forward_pass(self):
  """Test that two calls are the same as one with twice the batch size."""
  rng = jax.random.PRNGKey(0)
  batch_size = 10
  feature_size = 7
  input_shape = (batch_size, 3, 3, feature_size)
  half_input_shape = (batch_size // 2, 3, 3, feature_size)
  twos = 2.0 * jnp.ones(half_input_shape)
  threes = 3.0 * jnp.ones(half_input_shape)
  fives = 5.0 * jnp.ones(half_input_shape)
  nines = 9.0 * jnp.ones(half_input_shape)
  # mean(x1) = 2.5 and stddev(x1) = 0.5, so we expect
  # `(x1 - mean(x1)) / stddev(x1)` to be half -1.0 and half 1.0.
  x1 = jnp.concatenate((twos, threes))
  # mean(x2) = 7.0 and stddev(x2) = 2.0, so we expect
  # `(x2 - mean(x2)) / stddev(x2)` to be half -1.0 and half 1.0.
  x2 = jnp.concatenate((fives, nines))
  x_both = jnp.concatenate((x1, x2))
  expected_bn_y = jnp.concatenate(
      (jnp.ones(half_input_shape) * -1.0, jnp.ones(half_input_shape)))

  bn_flax_module = nn.BatchNorm(momentum=0.9)
  bn_params, bn_state = _init(bn_flax_module, rng, input_shape)

  vbn_flax_module = normalization.VirtualBatchNorm(
      momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
  vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

  bn_y1, _ = bn_flax_module.apply(
      {'params': bn_params, 'batch_stats': bn_state},
      x1,
      mutable=['batch_stats'],
      use_running_average=False)
  bn_y2, _ = bn_flax_module.apply(
      {'params': bn_params, 'batch_stats': bn_state},
      x2,
      mutable=['batch_stats'],
      use_running_average=False)
  bn_y_both, _ = bn_flax_module.apply(
      {'params': bn_params, 'batch_stats': bn_state},
      x_both,
      mutable=['batch_stats'],
      use_running_average=False)
  vbn_y_both, vbn_state = vbn_flax_module.apply(
      {'params': vbn_params, 'batch_stats': vbn_state},
      x_both,
      mutable=['batch_stats'],
      use_running_average=False)
  vbn_state = vbn_state['batch_stats']

  # Test that the layer forward passes behave as expected.
  np.testing.assert_allclose(bn_y1, expected_bn_y, atol=1e-4)
  np.testing.assert_allclose(bn_y2, expected_bn_y, atol=1e-4)
  np.testing.assert_allclose(
      vbn_y_both, jnp.concatenate((bn_y1, bn_y2)), atol=1e-4)
  # Test that the virtual batch norm and nn.BatchNorm layers do not perform
  # the same calculation on the concatenated batch. There is no negation of
  # `np.testing.assert_allclose`, so instead assert that the absolute
  # difference is strictly greater than zero.
  np.testing.assert_array_less(
      -jnp.abs(vbn_y_both - bn_y_both), jnp.zeros_like(vbn_y_both))

  _, vbn_state = vbn_flax_module.apply(
      {'params': vbn_params, 'batch_stats': vbn_state},
      x_both,
      mutable=['batch_stats'],
      use_running_average=False)
  vbn_state = vbn_state['batch_stats']

  # The running mean EMA starts at 0.0 and the running variance EMA starts at
  # 1.0, so after two applications of the same batch we expect
  #   mean_ema = 0.9 * (0.9 * 0.0 + 0.1 * mean) + 0.1 * mean = 0.19 * mean
  #   var_ema = 0.9 * (0.9 * 1.0 + 0.1 * var) + 0.1 * var = 0.19 * var + 0.81
  expected_mean_ema_x1 = 0.19 * jnp.mean(x1) * jnp.ones((feature_size,))
  expected_mean_ema_x2 = 0.19 * jnp.mean(x2) * jnp.ones((feature_size,))
  expected_mean_ema_both = (expected_mean_ema_x1 + expected_mean_ema_x2) / 2.0
  expected_var_ema_both = (
      (0.19 * jnp.std(jnp.concatenate((x1, x2))) ** 2.0 + 0.81) *
      jnp.ones((feature_size,)))
  np.testing.assert_allclose(
      np.squeeze(vbn_state['batch_norm_running_mean']),
      expected_mean_ema_both,
      atol=1e-4)
  np.testing.assert_allclose(
      np.squeeze(vbn_state['batch_norm_running_var']),
      expected_var_ema_both,
      atol=1e-4)
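# A standalone numeric check of the EMA arithmetic assumed in the comments
# above (illustration only, not part of the test suite): with momentum
# m = 0.9 and stats initialized at mean 0.0 / variance 1.0, two updates with
# the same batch statistics give 0.19 * mean and 0.81 + 0.19 * var.
def _ema_after_two_updates(init, stat, m=0.9):
  """Applies ema = m * ema + (1 - m) * stat twice, starting from `init`."""
  ema = m * init + (1 - m) * stat
  ema = m * ema + (1 - m) * stat
  return ema

# Using mean(x1) = 2.5 and var(x1) = 0.25 from the test above:
assert abs(_ema_after_two_updates(0.0, 2.5) - 0.19 * 2.5) < 1e-12
assert abs(_ema_after_two_updates(1.0, 0.25) - (0.81 + 0.19 * 0.25)) < 1e-12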