def __call__(self, inputs, is_training):
    """Runs a depthwise convolution followed by a pointwise convolution.

    After each convolution a `BatchNorm` is applied when `self._use_bn` is
    truthy, and then a ReLU.

    Args:
      inputs: Value fed to the depthwise convolution.
      is_training: Forwarded to each `BatchNorm` call.

    Returns:
      The output of the ReLU that follows the pointwise convolution.
    """
    depthwise = depthwise_conv.DepthwiseConv2D(
        1,
        3,
        stride=self._stride,
        padding=((1, 1), (1, 1)),
        with_bias=self._with_bias,
        name="depthwise_conv",
    )
    pointwise = conv.Conv2D(
        self._channels,
        (1, 1),
        stride=1,
        padding="VALID",
        with_bias=self._with_bias,
        name="pointwise_conv",
    )

    out = inputs
    # Both stages share the same conv -> (optional batch norm) -> ReLU shape.
    for layer in (depthwise, pointwise):
        out = layer(out)
        if self._use_bn:
            normalizer = batch_norm.BatchNorm(
                create_scale=True, create_offset=True)
            out = normalizer(out, is_training)
        out = jax.nn.relu(out)
    return out
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[Text, float],
             name: Optional[Text] = None):
    """Creates the convolutions and batch norms used by this block.

    Args:
      channels: Output channels of the shortcut convolution and of `conv_2`;
        `conv_0` and `conv_1` use `channels // 4`.
      stride: Stride of the shortcut convolution and of `conv_1`.
      use_projection: Whether the `shortcut_conv` convolution is created.
      bn_config: Keyword arguments forwarded to every `BatchNorm`; its
        entries take precedence over the `create_scale=True` and
        `create_offset=True` defaults.
      name: Passed to the parent constructor.
    """
    super(BottleNeckBlockV2, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._bn_config = bn_config

    # Defaults first, so anything in `bn_config` overrides them.
    bn_kwargs = {"create_scale": True, "create_offset": True, **bn_config}
    # Every convolution in this block shares these settings.
    conv_kwargs = dict(with_bias=False, padding="SAME")
    reduced_channels = channels // 4

    if self._use_projection:
        self._proj_conv = conv.Conv2D(
            output_channels=channels,
            kernel_shape=1,
            stride=stride,
            name="shortcut_conv",
            **conv_kwargs)

    self._conv_0 = conv.Conv2D(
        output_channels=reduced_channels,
        kernel_shape=1,
        stride=1,
        name="conv_0",
        **conv_kwargs)
    self._bn_0 = batch_norm.BatchNorm(name="batchnorm_0", **bn_kwargs)

    self._conv_1 = conv.Conv2D(
        output_channels=reduced_channels,
        kernel_shape=3,
        stride=stride,
        name="conv_1",
        **conv_kwargs)
    self._bn_1 = batch_norm.BatchNorm(name="batchnorm_1", **bn_kwargs)

    self._conv_2 = conv.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        name="conv_2",
        **conv_kwargs)
    # NOTE: Some ResNet50 v2 implementations suggest initializing the
    # gamma/scale of this batch norm to zeros; no `scale_init` is set here.
    self._bn_2 = batch_norm.BatchNorm(name="batchnorm_2", **bn_kwargs)
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[str, float],
             name: Optional[str] = None):
    """Creates the (conv, batch norm) pairs and the optional shortcut layers.

    Args:
      channels: Output channels of the shortcut convolution and of `conv_2`;
        `conv_0` and `conv_1` use `channels // 4`.
      stride: Stride of the shortcut convolution and of `conv_1`.
      use_projection: Whether `shortcut_conv` and `shortcut_batchnorm` are
        created.
      bn_config: Keyword arguments forwarded to every `BatchNorm`. Missing
        `create_scale`, `create_offset` and `decay_rate` entries default to
        `True`, `True` and `0.999`.
      name: Passed to the parent constructor.
    """
    super().__init__(name=name)
    self._use_projection = use_projection

    # Work on a copy so the caller's mapping is never modified.
    bn_config = dict(bn_config)
    bn_defaults = (
        ("create_scale", True),
        ("create_offset", True),
        ("decay_rate", 0.999),
    )
    for key, default in bn_defaults:
        bn_config.setdefault(key, default)

    if self._use_projection:
        self._proj_conv = conv.Conv2D(
            output_channels=channels,
            kernel_shape=1,
            stride=stride,
            with_bias=False,
            padding="SAME",
            name="shortcut_conv")
        self._proj_batchnorm = batch_norm.BatchNorm(
            name="shortcut_batchnorm", **bn_config)

    reduced_channels = channels // 4
    # (conv name, bn name, output channels, kernel shape, stride, extra bn
    # kwargs). Only the last batch norm gets a zeros scale initializer.
    layer_specs = (
        ("conv_0", "batchnorm_0", reduced_channels, 1, 1, {}),
        ("conv_1", "batchnorm_1", reduced_channels, 3, stride, {}),
        ("conv_2", "batchnorm_2", channels, 1, 1, {"scale_init": jnp.zeros}),
    )

    layers = []
    for conv_name, bn_name, out_channels, kernel, conv_stride, bn_extra in (
            layer_specs):
        conv_layer = conv.Conv2D(
            output_channels=out_channels,
            kernel_shape=kernel,
            stride=conv_stride,
            with_bias=False,
            padding="SAME",
            name=conv_name)
        bn_layer = batch_norm.BatchNorm(name=bn_name, **bn_extra, **bn_config)
        layers.append((conv_layer, bn_layer))
    self._layers = tuple(layers)
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[str, float],
             name: Optional[str] = None):
    """Creates the (conv, batch norm) pairs and the optional shortcut conv.

    Args:
      channels: Output channels of the shortcut convolution and of `conv_2`;
        `conv_0` and `conv_1` use `channels // 4`.
      stride: Stride of the shortcut convolution and of `conv_1`.
      use_projection: Whether the `shortcut_conv` convolution is created.
      bn_config: Keyword arguments forwarded to every `BatchNorm`. Missing
        `create_scale` and `create_offset` entries default to `True`.
      name: Passed to the parent constructor.
    """
    super().__init__(name=name)
    self._use_projection = use_projection

    # Work on a copy so the caller's mapping is never modified.
    bn_config = dict(bn_config)
    for key in ("create_scale", "create_offset"):
        bn_config.setdefault(key, True)

    def make_conv(out_channels, kernel, conv_stride, conv_name):
        # All convolutions here are bias-free with "SAME" padding.
        return conv.Conv2D(
            output_channels=out_channels,
            kernel_shape=kernel,
            stride=conv_stride,
            with_bias=False,
            padding="SAME",
            name=conv_name)

    if self._use_projection:
        self._proj_conv = make_conv(channels, 1, stride, "shortcut_conv")

    reduced_channels = channels // 4
    first = (
        make_conv(reduced_channels, 1, 1, "conv_0"),
        batch_norm.BatchNorm(name="batchnorm_0", **bn_config),
    )
    second = (
        make_conv(reduced_channels, 3, stride, "conv_1"),
        batch_norm.BatchNorm(name="batchnorm_1", **bn_config),
    )
    # NOTE: Some ResNet50 v2 implementations suggest initializing the
    # gamma/scale of the last batch norm to zeros; no `scale_init` is set
    # here.
    third = (
        make_conv(channels, 1, 1, "conv_2"),
        batch_norm.BatchNorm(name="batchnorm_2", **bn_config),
    )
    self._layers = (first, second, third)
def test_no_offset_beta_init_provided(self):
    """`offset_init` together with `create_offset=False` raises ValueError."""
    expected_message = "Cannot set `offset_init` if `create_offset=False`"
    with self.assertRaisesRegex(ValueError, expected_message):
        batch_norm.BatchNorm(
            create_scale=True,
            create_offset=False,
            offset_init=jnp.zeros,
        )
def test_no_scale_and_offset(self):
    """An all-ones input yields all zeros without scale or offset."""
    all_ones = jnp.ones([2, 5, 3, 3, 3])
    bn = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        decay_rate=0.9,
    )

    normalized = bn(all_ones, True)

    expected = np.zeros_like(all_ones)
    np.testing.assert_equal(normalized, expected)
def f(x, is_training=True):
    """Applies a scale-free, offset-free BatchNorm using axis "i" to `x`."""
    bn = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        decay_rate=0.9,
        cross_replica_axis="i",
    )
    return bn(x, is_training=is_training)
def test_no_scale_and_init_provided(self):
    """`scale_init` together with `create_scale=False` raises ValueError."""
    expected_message = "Cannot set `scale_init` if `create_scale=False`"
    with self.assertRaisesRegex(ValueError, expected_message):
        batch_norm.BatchNorm(
            create_scale=False,
            create_offset=True,
            decay_rate=0.9,
            scale_init=jnp.ones,
        )
def test_simple_training(self):
    """With call-time scale/offset, an all-ones input maps to the offset."""
    bn = batch_norm.BatchNorm(create_scale=False, create_offset=False)
    all_ones = np.ones([2, 3, 3, 5])
    # Scale and offset are supplied at call time, one value per last-axis
    # entry.
    call_scale = np.full((5, ), 0.5)
    call_offset = np.full((5, ), 2.0)

    output = bn(all_ones, True, scale=call_scale, offset=call_offset)

    expected = np.full(all_ones.shape, 2.0)
    np.testing.assert_equal(output, expected)
def test_basic(self):
    """Checks per-channel equality in training and agreement in test mode."""
    flat = jnp.arange(2 * 3 * 4, dtype=jnp.float32)
    data = flat.reshape([2, 3, 4])
    norm = batch_norm.BatchNorm(True, True, 0.9)

    train_out = norm(data, is_training=True)

    # Every slice along the last axis is expected to equal the first one.
    first_slice = train_out[:, :, :1]
    first_slice_tiled = jnp.broadcast_to(first_slice, train_out.shape)
    np.testing.assert_allclose(train_out, first_slice_tiled)

    # A second pass with is_training=False should reproduce the training
    # output within the given relative tolerance.
    eval_out = norm(data, is_training=False)
    np.testing.assert_allclose(eval_out, train_out, rtol=2e-2)
def test_simple_training_nchw(self):
    """Same as the simple training check, with `data_format="NCHW"`."""
    bn = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        data_format="NCHW",
    )
    all_ones = np.ones([2, 5, 3, 3])
    # Shaped (5, 1, 1) so they line up with axis 1 of the input.
    call_scale = np.full((5, 1, 1), 0.5)
    call_offset = np.full((5, 1, 1), 2.0)

    output = bn(all_ones, True, scale=call_scale, offset=call_offset)

    expected = np.full(all_ones.shape, 2.0)
    np.testing.assert_equal(output, expected)
def test_simple_training_normalized_axes(self):
    """Normalizing over axes 0, 2 and 3 zeroes an input varying on axis 1."""
    # Axis 1 is deliberately left out of the normalized axes.
    bn = batch_norm.BatchNorm(
        create_scale=False,
        create_offset=False,
        axis=[0, 2, 3],
    )

    # The two slices stacked along axis 1 hold 2.0 and 1.0 respectively, so
    # the input only varies along that axis.
    ones_block = np.ones([5, 3, 3])
    inputs = np.stack([2.0 * ones_block, ones_block], 1)

    output = bn(inputs, True)

    # Although the input values are not all identical, the output is
    # expected to be zero everywhere.
    np.testing.assert_equal(output, np.zeros(inputs.shape))
def __call__(self, inputs, is_training):
    """Runs the network on `inputs` and returns the `logits` layer output.

    The pipeline is: initial convolution, optional `BatchNorm` (when
    `self._use_bn` is truthy), ReLU, one `MobileNetV1Block` per entry of
    `self._strides`, a mean over axes 1 and 2, `Flatten`, and a final
    `Linear` layer.

    Args:
      inputs: Value fed to the initial convolution.
      is_training: Forwarded to the `BatchNorm` and to every block.

    Returns:
      The output of the `Linear` layer named "logits".
    """
    stem = conv.Conv2D(
        32,
        (3, 3),
        stride=2,
        padding="VALID",
        with_bias=self._with_bias,
    )
    out = stem(inputs)
    if self._use_bn:
        stem_norm = batch_norm.BatchNorm(create_scale=True, create_offset=True)
        out = stem_norm(out, is_training)
    out = jax.nn.relu(out)

    # Channels are looked up by position, one entry per stride.
    for idx, block_stride in enumerate(self._strides):
        block = MobileNetV1Block(
            self._channels[idx], block_stride, self._use_bn)
        out = block(out, is_training)

    pooled = jnp.mean(out, axis=(1, 2))
    flattened = reshape.Flatten()(pooled)
    return basic.Linear(self._num_classes, name="logits")(flattened)
def f(x):
    """Applies a fresh BatchNorm to `x`.

    `is_training` and `test_local_stats` are taken from the enclosing scope.
    """
    return batch_norm.BatchNorm(True, True, 0.9)(
        x, is_training, test_local_stats)
def __init__(self,
             blocks_per_group_list: Sequence[int],
             num_classes: int,
             bn_config: Optional[Mapping[Text, float]] = None,
             resnet_v2: bool = False,
             channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
             name: Optional[Text] = None):
    """Constructs a ResNet model.

    Args:
      blocks_per_group_list: Number of blocks to create in each of the four
        groups; must have length 4.
      num_classes: Output size of the final `logits` linear layer.
      bn_config: Keyword arguments passed on to the `BatchNorm` layers and
        block groups. When `None`, `decay_rate=0.9` and `eps=1e-5` are used.
      resnet_v2: Selects the v2 variant when true (final batch norm instead
        of initial batch norm). Defaults to False.
      channels_per_group_list: Number of channels for each of the four
        groups; must have length 4.
      name: Name of the module.

    Raises:
      ValueError: If `blocks_per_group_list` or `channels_per_group_list`
        does not have exactly four entries.
    """
    super(ResNet, self).__init__(name=name)
    if bn_config is None:
        bn_config = {"decay_rate": 0.9, "eps": 1e-5}
    self._bn_config = bn_config
    self._resnet_v2 = resnet_v2

    # Validate the per-group block counts.
    num_block_entries = len(blocks_per_group_list)
    if num_block_entries != 4:
        raise ValueError(
            "`blocks_per_group_list` must be of length 4 not {}".format(
                num_block_entries))
    self._blocks_per_group_list = blocks_per_group_list

    # Validate the per-group channel counts.
    num_channel_entries = len(channels_per_group_list)
    if num_channel_entries != 4:
        raise ValueError(
            "`channels_per_group_list` must be of length 4 not {}".format(
                num_channel_entries))
    self._channels_per_group_list = channels_per_group_list

    self._initial_conv = conv.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")

    # Only the v1 variant has a batch norm right after the initial conv.
    if not self._resnet_v2:
        self._initial_batchnorm = batch_norm.BatchNorm(
            create_scale=True,
            create_offset=True,
            name="initial_batchnorm",
            **bn_config)

    group_strides = [1, 2, 2, 2]
    self._block_groups = [
        BlockGroup(
            channels=self._channels_per_group_list[i],
            num_blocks=self._blocks_per_group_list[i],
            stride=group_stride,
            bn_config=bn_config,
            resnet_v2=resnet_v2,
            name="block_group_%d" % (i))
        for i, group_stride in enumerate(group_strides)
    ]

    # Only the v2 variant has a batch norm after the last block group.
    if self._resnet_v2:
        self._final_batchnorm = batch_norm.BatchNorm(
            create_scale=True,
            create_offset=True,
            name="final_batchnorm",
            **bn_config)

    self._logits = basic.Linear(
        output_size=num_classes, w_init=jnp.zeros, name="logits")
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[Text, float],
             name: Optional[Text] = None):
    """Creates the convolutions and batch norms used by this block.

    Args:
      channels: Output channels of the shortcut convolution and of `conv_2`;
        `conv_0` and `conv_1` use `channels // 4`.
      stride: Stride of the shortcut convolution and of `conv_1`.
      use_projection: Whether `shortcut_conv` and `shortcut_batchnorm` are
        created.
      bn_config: Keyword arguments forwarded to every `BatchNorm`; its
        entries take precedence over the defaults `create_scale=True`,
        `create_offset=True` and `decay_rate=0.999`.
      name: Passed to the parent constructor.
    """
    super(BottleNeckBlockV1, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._bn_config = bn_config

    # Defaults first, so anything in `bn_config` overrides them.
    bn_kwargs = {
        "create_scale": True,
        "create_offset": True,
        "decay_rate": 0.999,
        **bn_config,
    }
    # Every convolution in this block shares these settings.
    conv_kwargs = dict(with_bias=False, padding="SAME")
    reduced_channels = channels // 4

    if self._use_projection:
        self._proj_conv = conv.Conv2D(
            output_channels=channels,
            kernel_shape=1,
            stride=stride,
            name="shortcut_conv",
            **conv_kwargs)
        self._proj_batchnorm = batch_norm.BatchNorm(
            name="shortcut_batchnorm", **bn_kwargs)

    first_conv = conv.Conv2D(
        output_channels=reduced_channels,
        kernel_shape=1,
        stride=1,
        name="conv_0",
        **conv_kwargs)
    first_bn = batch_norm.BatchNorm(name="batchnorm_0", **bn_kwargs)

    second_conv = conv.Conv2D(
        output_channels=reduced_channels,
        kernel_shape=3,
        stride=stride,
        name="conv_1",
        **conv_kwargs)
    second_bn = batch_norm.BatchNorm(name="batchnorm_1", **bn_kwargs)

    third_conv = conv.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        name="conv_2",
        **conv_kwargs)
    # The last batch norm is the only one given a zeros scale initializer.
    third_bn = batch_norm.BatchNorm(
        name="batchnorm_2", scale_init=jnp.zeros, **bn_kwargs)

    self._layers = [
        [first_conv, first_bn],
        [second_conv, second_bn],
        [third_conv, third_bn],
    ]
def f(x, is_training):
    """Applies a fresh BatchNorm constructed with `eps=0.1` to `x`."""
    bn = batch_norm.BatchNorm(True, True, 0.9, eps=0.1)
    return bn(x, is_training)
def get_batch_norm():
    """Returns a new BatchNorm with scale, offset and `decay_rate=0.99`."""
    bn_kwargs = dict(create_scale=True, create_offset=True, decay_rate=0.99)
    return batch_norm.BatchNorm(**bn_kwargs)
def __init__(self,
             blocks_per_group: Sequence[int],
             num_classes: int,
             bn_config: Optional[Mapping[str, float]] = None,
             resnet_v2: bool = False,
             channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
             name: Optional[str] = None):
    """Constructs a ResNet model.

    Args:
      blocks_per_group: Number of blocks to create in each of the four
        groups; checked with `check_length` against a length of 4.
      num_classes: Output size of the final `logits` linear layer.
      bn_config: Keyword arguments passed on to the `BatchNorm` layers and
        block groups. Missing entries default to `decay_rate=0.9`,
        `eps=1e-5`, `create_scale=True` and `create_offset=True`.
      resnet_v2: Selects the v2 variant when true (final batch norm instead
        of initial batch norm). Defaults to False.
      channels_per_group: Number of channels for each of the four groups;
        checked with `check_length` against a length of 4.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._resnet_v2 = resnet_v2

    # Copy (or start empty) so the caller's mapping is never modified, then
    # fill in any setting the caller did not provide.
    bn_config = dict(bn_config or {})
    bn_defaults = (
        ("decay_rate", 0.9),
        ("eps", 1e-5),
        ("create_scale", True),
        ("create_offset", True),
    )
    for key, default in bn_defaults:
        bn_config.setdefault(key, default)

    # Both per-group sequences are checked against the four groups.
    check_length(4, blocks_per_group, "blocks_per_group")
    check_length(4, channels_per_group, "channels_per_group")

    self._initial_conv = conv.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")

    # Only the v1 variant has a batch norm right after the initial conv.
    if not self._resnet_v2:
        self._initial_batchnorm = batch_norm.BatchNorm(
            name="initial_batchnorm", **bn_config)

    group_strides = (1, 2, 2, 2)
    self._block_groups = [
        BlockGroup(
            channels=channels_per_group[i],
            num_blocks=blocks_per_group[i],
            stride=group_stride,
            bn_config=bn_config,
            resnet_v2=resnet_v2,
            name="block_group_%d" % (i))
        for i, group_stride in enumerate(group_strides)
    ]

    # Only the v2 variant has a batch norm after the last block group.
    if self._resnet_v2:
        self._final_batchnorm = batch_norm.BatchNorm(
            name="final_batchnorm", **bn_config)

    self._logits = basic.Linear(num_classes, w_init=jnp.zeros, name="logits")