Example #1
    def __call__(self, inputs, is_training):
        # 3x3 depthwise convolution (channel multiplier 1).
        dwc_layer = depthwise_conv.DepthwiseConv2D(1,
                                                   3,
                                                   stride=self._stride,
                                                   padding=((1, 1), (1, 1)),
                                                   with_bias=self._with_bias,
                                                   name="depthwise_conv")
        # 1x1 pointwise convolution to mix channels.
        pwc_layer = conv.Conv2D(self._channels, (1, 1),
                                stride=1,
                                padding="VALID",
                                with_bias=self._with_bias,
                                name="pointwise_conv")

        net = inputs
        net = dwc_layer(net)
        if self._use_bn:
            net = batch_norm.BatchNorm(create_scale=True,
                                       create_offset=True)(net, is_training)
        net = jax.nn.relu(net)
        net = pwc_layer(net)
        if self._use_bn:
            net = batch_norm.BatchNorm(create_scale=True,
                                       create_offset=True)(net, is_training)
        net = jax.nn.relu(net)
        return net
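
This `__call__` is only a fragment; the attributes it reads would be set in the block's constructor. A minimal sketch of such an `__init__`, with parameter names inferred from the body above rather than taken from the source:

    def __init__(self, channels, stride, use_bn=True, with_bias=False,
                 name=None):
        super().__init__(name=name)
        self._channels = channels    # output channels of the pointwise conv
        self._stride = stride        # stride of the depthwise conv
        self._use_bn = use_bn        # apply BatchNorm after each convolution
        self._with_bias = with_bias  # give the convolutions bias terms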
Example #2
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               bn_config: Mapping[Text, float],
               name: Optional[Text] = None):
    super(BottleNeckBlockV2, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._bn_config = bn_config

    batchnorm_args = {"create_scale": True, "create_offset": True}
    batchnorm_args.update(bn_config)

    if self._use_projection:
      self._proj_conv = conv.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding="SAME",
          name="shortcut_conv")

    self._conv_0 = conv.Conv2D(
        output_channels=channels // 4,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding="SAME",
        name="conv_0")

    self._bn_0 = batch_norm.BatchNorm(name="batchnorm_0", **batchnorm_args)

    self._conv_1 = conv.Conv2D(
        output_channels=channels // 4,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding="SAME",
        name="conv_1")

    self._bn_1 = batch_norm.BatchNorm(name="batchnorm_1", **batchnorm_args)

    self._conv_2 = conv.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding="SAME",
        name="conv_2")

    # NOTE: Some implementations of ResNet50 v2 suggest initializing gamma/scale
    # here to zeros.
    self._bn_2 = batch_norm.BatchNorm(name="batchnorm_2", **batchnorm_args)
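
To opt into that zero-gamma initialization, a `scale_init` can be passed explicitly. A sketch of that variant (Example #16 below does exactly this):

    self._bn_2 = batch_norm.BatchNorm(name="batchnorm_2",
                                      scale_init=jnp.zeros,  # gamma starts at 0
                                      **batchnorm_args)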
Example #3
    def __init__(self,
                 channels: int,
                 stride: Union[int, Sequence[int]],
                 use_projection: bool,
                 bn_config: Mapping[str, float],
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._use_projection = use_projection

        bn_config = dict(bn_config)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)
        bn_config.setdefault("decay_rate", 0.999)

        if self._use_projection:
            self._proj_conv = conv.Conv2D(output_channels=channels,
                                          kernel_shape=1,
                                          stride=stride,
                                          with_bias=False,
                                          padding="SAME",
                                          name="shortcut_conv")

            self._proj_batchnorm = batch_norm.BatchNorm(
                name="shortcut_batchnorm", **bn_config)

        conv_0 = conv.Conv2D(output_channels=channels // 4,
                             kernel_shape=1,
                             stride=1,
                             with_bias=False,
                             padding="SAME",
                             name="conv_0")
        bn_0 = batch_norm.BatchNorm(name="batchnorm_0", **bn_config)

        conv_1 = conv.Conv2D(output_channels=channels // 4,
                             kernel_shape=3,
                             stride=stride,
                             with_bias=False,
                             padding="SAME",
                             name="conv_1")

        bn_1 = batch_norm.BatchNorm(name="batchnorm_1", **bn_config)

        conv_2 = conv.Conv2D(output_channels=channels,
                             kernel_shape=1,
                             stride=1,
                             with_bias=False,
                             padding="SAME",
                             name="conv_2")

        bn_2 = batch_norm.BatchNorm(name="batchnorm_2",
                                    scale_init=jnp.zeros,
                                    **bn_config)

        self._layers = ((conv_0, bn_0), (conv_1, bn_1), (conv_2, bn_2))
Example #4
    def __init__(self,
                 channels: int,
                 stride: Union[int, Sequence[int]],
                 use_projection: bool,
                 bn_config: Mapping[str, float],
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._use_projection = use_projection

        bn_config = dict(bn_config)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)

        if self._use_projection:
            self._proj_conv = conv.Conv2D(output_channels=channels,
                                          kernel_shape=1,
                                          stride=stride,
                                          with_bias=False,
                                          padding="SAME",
                                          name="shortcut_conv")

        conv_0 = conv.Conv2D(output_channels=channels // 4,
                             kernel_shape=1,
                             stride=1,
                             with_bias=False,
                             padding="SAME",
                             name="conv_0")

        bn_0 = batch_norm.BatchNorm(name="batchnorm_0", **bn_config)

        conv_1 = conv.Conv2D(output_channels=channels // 4,
                             kernel_shape=3,
                             stride=stride,
                             with_bias=False,
                             padding="SAME",
                             name="conv_1")

        bn_1 = batch_norm.BatchNorm(name="batchnorm_1", **bn_config)

        conv_2 = conv.Conv2D(output_channels=channels,
                             kernel_shape=1,
                             stride=1,
                             with_bias=False,
                             padding="SAME",
                             name="conv_2")

        # NOTE: Some implementations of ResNet50 v2 suggest initializing gamma/scale
        # here to zeros.
        bn_2 = batch_norm.BatchNorm(name="batchnorm_2", **bn_config)

        self._layers = ((conv_0, bn_0), (conv_1, bn_1), (conv_2, bn_2))
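
The constructor stores `(conv, batchnorm)` pairs in `self._layers`, but the forward pass is not shown. A sketch of the pre-activation (v2) ordering such a block conventionally implements; this is the standard pattern, not code lifted from this source:

    def __call__(self, inputs, is_training, test_local_stats=False):
        net = shortcut = inputs
        for i, (conv_i, batchnorm_i) in enumerate(self._layers):
            net = batchnorm_i(net, is_training, test_local_stats)
            net = jax.nn.relu(net)
            if i == 0 and self._use_projection:
                # v2 computes the shortcut from the pre-activated input.
                shortcut = self._proj_conv(net)
            net = conv_i(net)
        return net + shortcut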
Example #5
 def test_no_offset_beta_init_provided(self):
     with self.assertRaisesRegex(
             ValueError,
             "Cannot set `offset_init` if `create_offset=False`"):
         batch_norm.BatchNorm(create_scale=True,
                              create_offset=False,
                              offset_init=jnp.zeros)
Example #6
  def test_no_scale_and_offset(self):
    layer = batch_norm.BatchNorm(
        create_scale=False, create_offset=False, decay_rate=0.9)

    inputs = jnp.ones([2, 5, 3, 3, 3])
    result = layer(inputs, True)
    np.testing.assert_equal(result, np.zeros_like(inputs))
Example #7
 def f(x, is_training=True):
   return batch_norm.BatchNorm(
       create_scale=False,
       create_offset=False,
       decay_rate=0.9,
       cross_replica_axis="i",
   )(x, is_training=is_training)
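
`cross_replica_axis="i"` only has meaning when `f` runs under a `jax.pmap` whose `axis_name` is also `"i"`: the batch statistics are then averaged across devices. A hedged usage sketch (the `hk` alias for Haiku's public API is an assumption about the surrounding imports):

import haiku as hk
import jax
import jax.numpy as jnp

net = hk.transform_with_state(f)

# axis_name must match cross_replica_axis so BatchNorm can average over devices.
n = jax.local_device_count()
xs = jnp.ones([n, 4, 3])
keys = jax.random.split(jax.random.PRNGKey(0), n)
params, state = jax.pmap(net.init, axis_name="i")(keys, xs)
out, state = jax.pmap(net.apply, axis_name="i")(params, state, None, xs)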
Example #8
 def test_no_scale_and_init_provided(self):
     with self.assertRaisesRegex(
             ValueError, "Cannot set `scale_init` if `create_scale=False`"):
         batch_norm.BatchNorm(create_scale=False,
                              create_offset=True,
                              decay_rate=0.9,
                              scale_init=jnp.ones)
Example #9
    def test_simple_training(self):
        layer = batch_norm.BatchNorm(create_scale=False, create_offset=False)

        inputs = np.ones([2, 3, 3, 5])
        scale = np.full((5, ), 0.5)
        offset = np.full((5, ), 2.0)

        result = layer(inputs, True, scale=scale, offset=offset)
        np.testing.assert_equal(result, np.full(inputs.shape, 2.0))
Example #10
  def test_basic(self):
    data = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape([2, 3, 4])

    norm = batch_norm.BatchNorm(True, True, 0.9)
    result = norm(data, is_training=True)
    result_0_replicated = jnp.broadcast_to(result[:, :, :1], result.shape)
    # Each channel of the input differs only by a constant shift, so all
    # channels normalize to identical values.
    np.testing.assert_allclose(result, result_0_replicated)
    # Running through again in test mode produces the same output.
    np.testing.assert_allclose(norm(data, is_training=False), result, rtol=2e-2)
Example #11
    def test_simple_training_nchw(self):
        layer = batch_norm.BatchNorm(create_scale=False,
                                     create_offset=False,
                                     data_format="NCHW")

        inputs = np.ones([2, 5, 3, 3])
        scale = np.full((5, 1, 1), 0.5)
        offset = np.full((5, 1, 1), 2.0)

        result = layer(inputs, True, scale=scale, offset=offset)
        np.testing.assert_equal(result, np.full(inputs.shape, 2.0))
Example #12
    def test_simple_training_normalized_axes(self):
        layer = batch_norm.BatchNorm(create_scale=False,
                                     create_offset=False,
                                     axis=[0, 2, 3])  # Stats over all axes except the second.

        # The input differs only along the second axis.
        inputs = np.stack([2.0 * np.ones([5, 3, 3]), np.ones([5, 3, 3])], 1)

        result = layer(inputs, True)

        # The values are not all identical, but each slice along the second
        # axis is constant, so per-slice normalization yields an all-zero array.
        np.testing.assert_equal(result, np.zeros(inputs.shape))
Example #13
 def __call__(self, inputs, is_training):
     initial_conv = conv.Conv2D(32, (3, 3),
                                stride=2,
                                padding="VALID",
                                with_bias=self._with_bias)
     net = initial_conv(inputs)
     if self._use_bn:
         net = batch_norm.BatchNorm(create_scale=True,
                                    create_offset=True)(net, is_training)
     net = jax.nn.relu(net)
     for i in range(len(self._strides)):
         net = MobileNetV1Block(self._channels[i], self._strides[i],
                                self._use_bn)(net, is_training)
     net = jnp.mean(net, axis=(1, 2))
     net = reshape.Flatten()(net)
     net = basic.Linear(self._num_classes, name="logits")(net)
     return net
Example #14
 def f(x, is_training=True, test_local_stats=False):
   net = batch_norm.BatchNorm(True, True, 0.9)
   return net(x, is_training, test_local_stats)
Example #15
  def __init__(self,
               blocks_per_group_list: Sequence[int],
               num_classes: int,
               bn_config: Optional[Mapping[Text, float]] = None,
               resnet_v2: bool = False,
               channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
               name: Optional[Text] = None):
    """Constructs a ResNet model.

    Args:
      blocks_per_group_list: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers. By default the `decay_rate` is
        `0.9` and `eps` is `1e-5`.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
        False.
      channels_per_group_list: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      name: Name of the module.
    """
    super(ResNet, self).__init__(name=name)
    if bn_config is None:
      bn_config = {"decay_rate": 0.9, "eps": 1e-5}
    self._bn_config = bn_config
    self._resnet_v2 = resnet_v2

    # Number of blocks in each group for ResNet.
    if len(blocks_per_group_list) != 4:
      raise ValueError(
          "`blocks_per_group_list` must be of length 4 not {}".format(
              len(blocks_per_group_list)))
    self._blocks_per_group_list = blocks_per_group_list

    # Number of channels in each group for ResNet.
    if len(channels_per_group_list) != 4:
      raise ValueError(
          "`channels_per_group_list` must be of length 4 not {}".format(
              len(channels_per_group_list)))
    self._channels_per_group_list = channels_per_group_list

    self._initial_conv = conv.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")
    if not self._resnet_v2:
      self._initial_batchnorm = batch_norm.BatchNorm(
          create_scale=True,
          create_offset=True,
          name="initial_batchnorm",
          **bn_config)

    self._block_groups = []
    strides = [1, 2, 2, 2]
    for i in range(4):
      self._block_groups.append(
          BlockGroup(
              channels=self._channels_per_group_list[i],
              num_blocks=self._blocks_per_group_list[i],
              stride=strides[i],
              bn_config=bn_config,
              resnet_v2=resnet_v2,
              name="block_group_%d" % (i)))

    if self._resnet_v2:
      self._final_batchnorm = batch_norm.BatchNorm(
          create_scale=True,
          create_offset=True,
          name="final_batchnorm",
          **bn_config)

    self._logits = basic.Linear(
        output_size=num_classes, w_init=jnp.zeros, name="logits")
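
A sketch of the forward pass these attributes imply. The ordering mirrors the constructor above; the stem max-pool and the `hk.max_pool` call follow the standard ResNet layout and are assumptions, not code from this source:

  def __call__(self, inputs, is_training):
    net = self._initial_conv(inputs)
    if not self._resnet_v2:
      net = self._initial_batchnorm(net, is_training)
      net = jax.nn.relu(net)

    # Stem max-pool per the usual ResNet recipe (assumed, see note above).
    net = hk.max_pool(net, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding="SAME")

    for block_group in self._block_groups:
      net = block_group(net, is_training)

    if self._resnet_v2:
      net = self._final_batchnorm(net, is_training)
      net = jax.nn.relu(net)

    net = jnp.mean(net, axis=(1, 2))  # global average pooling
    return self._logits(net)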
Example #16
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               bn_config: Mapping[Text, float],
               name: Optional[Text] = None):
    super(BottleNeckBlockV1, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._bn_config = bn_config

    batchnorm_args = {
        "create_scale": True,
        "create_offset": True,
        "decay_rate": 0.999,
    }
    batchnorm_args.update(bn_config)

    if self._use_projection:
      self._proj_conv = conv.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding="SAME",
          name="shortcut_conv")
      self._proj_batchnorm = batch_norm.BatchNorm(
          name="shortcut_batchnorm", **batchnorm_args)

    self._layers = []
    conv_0 = conv.Conv2D(
        output_channels=channels // 4,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding="SAME",
        name="conv_0")
    self._layers.append(
        [conv_0,
         batch_norm.BatchNorm(name="batchnorm_0", **batchnorm_args)])

    conv_1 = conv.Conv2D(
        output_channels=channels // 4,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding="SAME",
        name="conv_1")
    self._layers.append(
        [conv_1,
         batch_norm.BatchNorm(name="batchnorm_1", **batchnorm_args)])

    conv_2 = conv.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding="SAME",
        name="conv_2")
    batchnorm_2 = batch_norm.BatchNorm(
        name="batchnorm_2", scale_init=jnp.zeros, **batchnorm_args)
    self._layers.append([conv_2, batchnorm_2])
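
For completeness, a sketch of the forward pass this v1 block implies. The post-activation ordering is the conventional ResNet v1 pattern, not verbatim from this source:

  def __call__(self, inputs, is_training, test_local_stats=False):
    net = shortcut = inputs
    if self._use_projection:
      shortcut = self._proj_conv(shortcut)
      shortcut = self._proj_batchnorm(shortcut, is_training, test_local_stats)
    for i, (conv_i, batchnorm_i) in enumerate(self._layers):
      net = conv_i(net)
      net = batchnorm_i(net, is_training, test_local_stats)
      if i < 2:  # no ReLU right after the zero-initialized batchnorm_2
        net = jax.nn.relu(net)
    return jax.nn.relu(net + shortcut)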
Example #17
 def f(x, is_training):
   return batch_norm.BatchNorm(True, True, 0.9, eps=0.1)(x, is_training)
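
Because BatchNorm keeps running statistics as Haiku state, `f` must be wrapped with `hk.transform_with_state` rather than plain `hk.transform`. A minimal sketch of running it (the input shape is arbitrary):

import haiku as hk
import jax
import jax.numpy as jnp

net = hk.transform_with_state(f)
x = jnp.ones([2, 3, 4])
params, state = net.init(jax.random.PRNGKey(42), x, True)
# apply returns the output together with the updated moving averages.
out, state = net.apply(params, state, None, x, True)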
Example #18
 def get_batch_norm():
     return batch_norm.BatchNorm(create_scale=True,
                                 create_offset=True,
                                 decay_rate=0.99)
Example #19
    def __init__(self,
                 blocks_per_group: Sequence[int],
                 num_classes: int,
                 bn_config: Optional[Mapping[str, float]] = None,
                 resnet_v2: bool = False,
                 channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
                 name: Optional[str] = None):
        """Constructs a ResNet model.

        Args:
          blocks_per_group: A sequence of length 4 that indicates the number of
            blocks created in each group.
          num_classes: The number of classes to classify the inputs into.
          bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
            passed on to the `BatchNorm` layers. By default the `decay_rate` is
            `0.9` and `eps` is `1e-5`.
          resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
            to False.
          channels_per_group: A sequence of length 4 that indicates the number
            of channels used for each block in each group.
          name: Name of the module.
        """
        super().__init__(name=name)
        self._resnet_v2 = resnet_v2

        bn_config = dict(bn_config or {})
        bn_config.setdefault("decay_rate", 0.9)
        bn_config.setdefault("eps", 1e-5)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)

        # Number of blocks in each group for ResNet.
        check_length(4, blocks_per_group, "blocks_per_group")
        check_length(4, channels_per_group, "channels_per_group")

        self._initial_conv = conv.Conv2D(output_channels=64,
                                         kernel_shape=7,
                                         stride=2,
                                         with_bias=False,
                                         padding="SAME",
                                         name="initial_conv")

        if not self._resnet_v2:
            self._initial_batchnorm = batch_norm.BatchNorm(
                name="initial_batchnorm", **bn_config)

        self._block_groups = []
        strides = (1, 2, 2, 2)
        for i in range(4):
            self._block_groups.append(
                BlockGroup(channels=channels_per_group[i],
                           num_blocks=blocks_per_group[i],
                           stride=strides[i],
                           bn_config=bn_config,
                           resnet_v2=resnet_v2,
                           name="block_group_%d" % (i)))

        if self._resnet_v2:
            self._final_batchnorm = batch_norm.BatchNorm(
                name="final_batchnorm", **bn_config)

        self._logits = basic.Linear(num_classes,
                                    w_init=jnp.zeros,
                                    name="logits")