Example #1
  def apply(self, x, num_classes, *,
            stage_sizes,
            block_cls,
            num_filters=64,
            dtype=jnp.float32,
            act=nn.relu,
            train=True):
    conv = nn.Conv.partial(bias=False, dtype=dtype)
    norm = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=0.9, epsilon=1e-5,
        dtype=dtype)

    x = conv(x, num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')
    x = norm(x, name='bn_init')
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(stage_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = block_cls(x, num_filters * 2 ** i,
                      strides=strides,
                      conv=conv,
                      norm=norm,
                      act=act)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes, dtype=dtype)
    return x
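For reference, a minimal `block_cls` compatible with the call above (`block_cls(x, filters, strides=..., conv=..., norm=..., act=...)`), sketched in the same deprecated pre-Linen `flax.nn` style these snippets use. The class name and the projection-shortcut details are assumptions for illustration, not taken from the original project:

from flax import nn  # deprecated pre-Linen API, as in the snippets above


class BasicBlock(nn.Module):
  """Two 3x3 convs with a projection shortcut when the shape changes."""

  def apply(self, x, filters, *, strides=(1, 1), conv, norm, act):
    residual = x
    y = conv(x, filters, (3, 3), strides)
    y = norm(y)
    y = act(y)
    y = conv(y, filters, (3, 3))
    y = norm(y, scale_init=nn.initializers.zeros)
    if residual.shape != y.shape:
      # Project the shortcut so the elementwise addition is well defined.
      residual = conv(residual, filters, (1, 1), strides, name='proj_conv')
      residual = norm(residual, name='proj_bn')
    return act(residual + y)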
Example #2
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           dtype=jnp.float32):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 padding=[(3, 3), (3, 3)],
                 bias=False,
                 dtype=dtype,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      dtype=dtype,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_classes)
     x = nn.log_softmax(x)
     return x
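Several of these examples look up `block_sizes` in a `_block_size_options` dict that is not part of the excerpt. A typical definition, assuming the standard ResNet depth-to-stage mapping:

# Assumed mapping from ResNet depth to per-stage block counts; the snippets
# reference _block_size_options but do not show its definition.
_block_size_options = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
}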
Example #3
 def apply(
         self,
         x,
         num_outputs,
         num_filters=64,
         block_sizes=[3, 4, 6, 3],  # pylint: disable=dangerous-default-value
         train=True):
     x = nn.Conv(x,
                 num_filters, (7, 7), (2, 2),
                 bias=False,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      epsilon=1e-5,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = BottleneckBlock(x,
                                 num_filters * 2**i,
                                 strides=strides,
                                 train=train)
     x = jnp.mean(x, axis=(1, 2))
     x_clf = nn.Dense(x, num_outputs, name='clf')
     # We return both the outputs from the dense layer *and* the features
     # that go into it.
     return x_clf, x
Example #4
    def apply(self,
              x,
              num_filters=64,
              block_sizes=(3, 4, 6, 3),
              train=True,
              block=BottleneckBlock,
              small_inputs=False):
        if small_inputs:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        bias=False,
                        name="init_conv")
        else:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(7, 7),
                        strides=(2, 2),
                        bias=False,
                        name="init_conv")
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         epsilon=1e-5,
                         name="init_bn")
        if not small_inputs:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = block(x, num_filters * 2**i, strides=strides, train=train)

        return x
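A quick sanity check on what the `small_inputs` switch does to spatial resolution, assuming Flax's default 'SAME' padding for `nn.Conv`; the input sizes below are illustrative, not taken from the source:

import math

def stem_output_size(side, small_inputs):
    # With 'SAME' padding, a stride-s layer maps a side of n to ceil(n / s).
    if small_inputs:
        return math.ceil(side / 1)       # 3x3 conv, stride 1, no max pool
    side = math.ceil(side / 2)           # 7x7 conv, stride 2
    return math.ceil(side / 2)           # 3x3 max pool, stride 2

print(stem_output_size(32, True))    # 32: CIFAR-sized inputs keep full resolution
print(stem_output_size(224, False))  # 56: ImageNet inputs enter stage 1 at 56x56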
Example #5
            def apply(self, x, n_bins, stage_sizes, block_cls, num_filters=64,
                      dtype=jnp.float32, act=nn.leaky_relu, train=True):
                b = x.shape[0]
                conv = nn.Conv.partial(bias=False, dtype=dtype)
                norm = nn.BatchNorm.partial(
                    use_running_average=not train,
                    dtype=dtype
                )

                x = conv(x, num_filters, kernel_size=(7,), strides=(2,), padding=[(3, 3)])
                x = norm(x)
                x = nn.leaky_relu(x)
                x = nn.max_pool(x, window_shape=(3,), strides=(2,), padding='SAME')

                for i, block_size in enumerate(stage_sizes):
                    for j in range(block_size):
                        strides = (2,) if i > 0 and j == 0 else (1,)
                        x = block_cls(x, num_filters * 2 ** i,
                                      strides=strides,
                                      conv=conv,
                                      norm=norm,
                                      act=act)
                x = x.reshape(b, -1)
                x = nn.Dense(x, n_bins, dtype=dtype)
                x = nn.softmax(x)
                return x
Example #6
 def apply(self, x, features, n_stages, act=nn.relu):
     x = act(x)
     path = x
     for _ in range(n_stages):
         path = nn.max_pool(path,
                            window_shape=(5, 5),
                            strides=(1, 1),
                            padding='SAME')
         path = ncsn_conv3x3(path, features, stride=1, bias=False)
         x = path + x
     return x
Example #7
 def apply(
     self,
     x,
     num_classes,
     num_filters=64,
     num_layers=50,
     train=True,
     axis_name=None,
     axis_index_groups=None,
     dtype=jnp.float32,
     batch_norm_momentum=0.9,
     batch_norm_epsilon=1e-5,
     bn_output_scale=0.0,
     virtual_batch_size=None,
     data_format=None):
   if num_layers not in _block_size_options:
     raise ValueError('Please provide a valid number of layers')
   block_sizes = _block_size_options[num_layers]
   conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
   x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
            dtype=dtype, name='conv0')
   x = normalization.VirtualBatchNorm(
       x,
       use_running_average=not train,
       momentum=batch_norm_momentum,
       epsilon=batch_norm_epsilon,
       name='init_bn',
       axis_name=axis_name,
       axis_index_groups=axis_index_groups,
       dtype=dtype,
       virtual_batch_size=virtual_batch_size,
       data_format=data_format)
   x = nn.relu(x)  # MLPerf-required
   x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
   for i, block_size in enumerate(block_sizes):
     for j in range(block_size):
       strides = (2, 2) if i > 0 and j == 0 else (1, 1)
       x = ResidualBlock(
           x,
           num_filters * 2 ** i,
           strides=strides,
           train=train,
           axis_name=axis_name,
           axis_index_groups=axis_index_groups,
           dtype=dtype,
           batch_norm_momentum=batch_norm_momentum,
           batch_norm_epsilon=batch_norm_epsilon,
           bn_output_scale=bn_output_scale,
           virtual_batch_size=virtual_batch_size,
           data_format=data_format)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                dtype=dtype)
   return x
Example #8
    def apply(self, x, width):
        x = fixed_padding(x, 7)
        x = StdConv(x,
                    width, (7, 7), (2, 2),
                    padding="VALID",
                    bias=False,
                    name="conv_root")

        x = fixed_padding(x, 3)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="VALID")

        return x
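`fixed_padding` and `StdConv` come from the surrounding project and are not shown here. A typical `fixed_padding`, which adds explicit, input-size-independent padding so the following 'VALID' stride-2 convolution behaves like the usual ResNet stem (an assumed sketch, not copied from that project):

import jax.numpy as jnp

def fixed_padding(x, kernel_size):
    # Pad height and width of an NHWC tensor by (kernel_size - 1) in total,
    # split as evenly as possible; batch and channel dims are left untouched.
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    return jnp.pad(x, [(0, 0), (pad_beg, pad_end), (pad_beg, pad_end), (0, 0)])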
Example #9
 def apply(self,
           x,
           num_classes,
           num_filters=64,
           num_layers=50,
           train=True,
           axis_name=None,
           axis_index_groups=None,
           dtype=jnp.float32,
           conv0_space_to_depth=False):
     if num_layers not in _block_size_options:
         raise ValueError('Please provide a valid number of layers')
     block_sizes = _block_size_options[num_layers]
     if conv0_space_to_depth:
         conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                         padding=[(2, 1), (2, 1)])
     else:
         conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
     x = conv(x,
              num_filters,
              kernel_size=(7, 7),
              strides=(2, 2),
              bias=False,
              dtype=dtype,
              name='conv0')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn',
                      axis_name=axis_name,
                      axis_index_groups=axis_index_groups,
                      dtype=dtype)
     x = nn.relu(x)  # MLPerf-required
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(block_sizes):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = ResidualBlock(x,
                               num_filters * 2**i,
                               strides=strides,
                               train=train,
                               axis_name=axis_name,
                               axis_index_groups=axis_index_groups,
                               dtype=dtype)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x,
                  num_classes,
                  kernel_init=nn.initializers.normal(),
                  dtype=dtype)
     x = nn.log_softmax(x)
     return x
Example #10
  def apply(self, x, num_classes, parameters, num_filters=64,
            train=True, axis_name=None, num_layers='34'):
    block_sizes = [3, 4, 6]
    data_format = 'channels_last'
    if ('conv0_space_to_depth' in parameters and
        parameters['conv0_space_to_depth']):
      # conv0 uses space-to-depth transform for TPU performance.
      x = func_conv0_space_to_depth(inputs=x, data_format=data_format,
                                    dtype=parameters['dtype'])
    else:
      x = conv2d_fixed_padding(
          inputs=x,
          filters=num_filters,
          kernel_size=7,
          strides=2,
          data_format=data_format,
          name='init_conv')

    replica_groups = _make_replica_groups(parameters)
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        name='init_bn',
        axis_name=axis_name,
        dtype=parameters['dtype'],
        axis_index_groups=replica_groups)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        strides = 2 if i == 1 and j == 0 else 1
        use_projection = False if i == 0 or j > 0 else True
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            parameters,
            strides=strides,
            train=train,
            axis_name=axis_name,
            use_projection=use_projection,
            data_format=data_format)
    if num_layers == '34':
      x = jnp.mean(x, axis=(1, 2))
      x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                   dtype=jnp.float32)  # TODO(deveci): dtype=dtype
      x = nn.log_softmax(x)
    return x
Example #11
 def apply(self, x, use_squeeze_excite=False):
     x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     if use_squeeze_excite:
         x = SqueezeExciteLayer(x)
     x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
     x = nn.relu(x)
     if use_squeeze_excite:
         x = SqueezeExciteLayer(x)
     x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
     scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis,
                                                                  0]
     return scores
Example #12
 def apply(
     self,
     x,
     num_outputs,
     num_filters=64,
     num_layers=50,
     train=True,
     batch_stats=None,
     dtype=jnp.float32,
     batch_norm_momentum=0.9,
     batch_norm_epsilon=1e-5,
     virtual_batch_size=None,
     data_format=None):
   if num_layers not in _block_size_options:
     raise ValueError('Please provide a valid number of layers')
   block_sizes = _block_size_options[num_layers]
   x = nn.Conv(x, num_filters, (7, 7), (2, 2),
               bias=False,
               dtype=dtype,
               name='init_conv')
   x = normalization.VirtualBatchNorm(
       x,
       batch_stats=batch_stats,
       use_running_average=not train,
       momentum=batch_norm_momentum,
       epsilon=batch_norm_epsilon,
       dtype=dtype,
       name='init_bn',
       virtual_batch_size=virtual_batch_size,
       data_format=data_format)
   x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
   for i, block_size in enumerate(block_sizes):
     for j in range(block_size):
       strides = (2, 2) if i > 0 and j == 0 else (1, 1)
       x = ResidualBlock(
           x,
           num_filters * 2 ** i,
           strides=strides,
           train=train,
           batch_stats=batch_stats,
           dtype=dtype,
           batch_norm_momentum=batch_norm_momentum,
           batch_norm_epsilon=batch_norm_epsilon,
           virtual_batch_size=virtual_batch_size,
           data_format=data_format)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_outputs, dtype=dtype)
   return x
Example #13
File: nn_test.py Project: zhang-yd15/flax
 def test_max_pool(self):
   x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
   pool = lambda x: nn.max_pool(x, (2, 2))
   expected_y = jnp.array([
       [4., 5.],
       [7., 8.],
   ]).reshape((1, 2, 2, 1))
   y = pool(x)
   onp.testing.assert_allclose(y, expected_y)
   y_grad = jax.grad(lambda x: pool(x).sum())(x)
   expected_grad = jnp.array([
       [0., 0., 0.],
       [0., 1., 1.],
       [0., 1., 1.],
   ]).reshape((1, 3, 3, 1))
   onp.testing.assert_allclose(y_grad, expected_grad)
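The expected values in this test follow from `nn.max_pool`'s defaults (2x2 window, stride 1, 'VALID' padding), which the test itself confirms. A standalone sketch of the same reduction written directly with `jax.lax.reduce_window`, the primitive the pooling helper wraps:

import jax
import jax.numpy as jnp

x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
# Window and strides include the batch and channel dimensions (kept at 1);
# -inf is the identity element for the max reduction.
y = jax.lax.reduce_window(x, -jnp.inf, jax.lax.max,
                          (1, 2, 2, 1), (1, 1, 1, 1), 'VALID')
# y[0, :, :, 0] == [[4., 5.], [7., 8.]], matching expected_y above.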
Example #14
def features(x, num_layers, normalizer, dtype, train):
    """Implements the feature extraction portion of the network."""

    layers = _layer_size_options[num_layers]
    conv = nn.Conv.partial(bias=False, dtype=dtype)
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    for l in layers:
        if l == 'M':
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        else:
            x = conv(x,
                     features=l,
                     kernel_size=(3, 3),
                     padding=((1, 1), (1, 1)))
            x = maybe_normalize(x)
            x = nn.relu(x)
    return x
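The `_layer_size_options` table referenced above is not shown. Going by the loop's convention (integers are conv widths, 'M' inserts a 2x2 max pool), a VGG-16-style entry would look like this (assumed for illustration, not taken from the project):

_layer_size_options = {
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M'],
}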
Example #15
 def apply(self, x, num_outputs, train=True):
   x = nn.Conv(x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False,
               name='init_conv')
   x = nn.BatchNorm(x,
                    use_running_average=not train,
                    momentum=0.9, epsilon=1e-5,
                    name='init_bn')
   x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
   for i, block_size in enumerate(self.BLOCK_SIZES):
     for j in range(block_size):
       strides = (2, 2) if i > 0 and j == 0 else (1, 1)
       x = BottleneckBlock(x, self.NUM_FILTERS * 2 ** i,
                           strides=strides, groups=self.GROUPS,
                           base_width=self.WIDTH_PER_GROUP, train=train)
   x = jnp.mean(x, axis=(1, 2))
   x = nn.Dense(x, num_outputs, name='clf')
   return x
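The class constants read through `self` are not part of this excerpt. For a ResNeXt-50 (32x4d) configuration they would typically be the following (hypothetical values, inferred from the standard architecture rather than from the source):

# Hypothetical class-level constants for a ResNeXt-50 (32x4d) variant.
NUM_FILTERS = 64
BLOCK_SIZES = [3, 4, 6, 3]
GROUPS = 32
WIDTH_PER_GROUP = 4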
Example #16
File: nn_test.py Project: wdevazelhes/flax
 def test_max_pool_explicit_pads(self):
     x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
     pool = lambda x: nn.max_pool(x, (2, 2), padding=((1, 1), (1, 1)))
     expected_y = jnp.array([
         [0., 1., 2., 2.],
         [3., 4., 5., 5.],
         [6., 7., 8., 8.],
         [6., 7., 8., 8.],
     ]).reshape((1, 4, 4, 1))
     y = pool(x)
     onp.testing.assert_allclose(y, expected_y)
     y_grad = jax.grad(lambda x: pool(x).sum())(x)
     expected_grad = jnp.array([
         [1., 1., 2.],
         [1., 1., 2.],
         [2., 2., 4.],
     ]).reshape((1, 3, 3, 1))
     onp.testing.assert_allclose(y_grad, expected_grad)
Example #17
  def apply(self,
            x,
            output_shape,
            encoder,
            decoder,
            train=True):
    if not len(encoder['filter_sizes']) == len(encoder['kernel_sizes']) == len(
        encoder['kernel_paddings']) == len(encoder['window_sizes']) == len(
            encoder['window_paddings']) == len(encoder['strides']) == len(
                encoder['activations']):
      raise ValueError(
          'The elements in encoder dict do not have the same length.')

    if not len(decoder['filter_sizes']) == len(decoder['kernel_sizes']) == len(
        decoder['window_sizes']) == len(decoder['paddings']) == len(
            decoder['activations']):
      raise ValueError(
          'The elements in decoder dict do not have the same length.')

    # encoder
    for i in range(len(encoder['filter_sizes'])):
      x = nn.Conv(
          x,
          encoder['filter_sizes'][i],
          encoder['kernel_sizes'][i],
          padding=encoder['kernel_paddings'][i])
      x = model_utils.ACTIVATIONS[encoder['activations'][i]](x)
      x = nn.max_pool(
          x, encoder['window_sizes'][i],
          strides=encoder['strides'][i],
          padding=encoder['window_paddings'][i])

    # decoder
    for i in range(len(decoder['filter_sizes'])):
      x = nn.ConvTranspose(
          x,
          decoder['filter_sizes'][i],
          decoder['kernel_sizes'][i],
          decoder['window_sizes'][i],
          padding=decoder['paddings'][i])
      x = model_utils.ACTIVATIONS[decoder['activations'][i]](x)
    return x
Example #18
  def apply(self,
            x,
            num_outputs,
            num_filters,
            kernel_sizes,
            kernel_paddings,
            window_sizes,
            window_paddings,
            strides,
            num_dense_units,
            activation_fn,
            normalizer='none',
            kernel_init=initializers.lecun_normal(),
            bias_init=initializers.zeros,
            train=True):

    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in zip(
        num_filters, kernel_sizes, kernel_paddings, window_sizes,
        window_paddings, strides):
      x = nn.Conv(
          x,
          num_filters, (kernel_size, kernel_size), (1, 1),
          padding=kernel_padding,
          kernel_init=kernel_init,
          bias_init=bias_init)
      x = model_utils.ACTIVATIONS[activation_fn](x)
      x = maybe_normalize(x)
      x = nn.max_pool(
          x,
          window_shape=(window_size, window_size),
          strides=(stride, stride),
          padding=window_padding)
    x = jnp.reshape(x, (x.shape[0], -1))
    for num_units in num_dense_units:
      x = nn.Dense(
          x, num_units, kernel_init=kernel_init, bias_init=bias_init)
      x = model_utils.ACTIVATIONS[activation_fn](x)
      x = maybe_normalize(x)
    x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
    return x
Example #19
    def apply(self,
              x,
              num_classes=1000,
              train=False,
              width_factor=1,
              num_layers=50):
        del train
        blocks, bottleneck = get_block_desc(num_layers)
        width = int(64 * width_factor)

        # Root block
        x = StdConv(x, width, (7, 7), (2, 2), bias=False, name="conv_root")
        x = nn.GroupNorm(x, name="gn_root")
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

        # Stages
        x = ResNetStage(x,
                        blocks[0],
                        width,
                        first_stride=(1, 1),
                        bottleneck=bottleneck,
                        name="block1")
        for i, block_size in enumerate(blocks[1:], 1):
            x = ResNetStage(x,
                            block_size,
                            width * 2**i,
                            first_stride=(2, 2),
                            bottleneck=bottleneck,
                            name=f"block{i + 1}")

        # Head
        x = jnp.mean(x, axis=(1, 2))
        x = IdentityLayer(x, name="pre_logits")
        x = nn.Dense(x,
                     num_classes,
                     kernel_init=nn.initializers.zeros,
                     name="head")
        return x
Example #20
    def apply(self,
              x,
              *,
              train,
              num_classes,
              block_class=BottleneckResNetImageNetBlock,
              stage_sizes,
              width_factor=1,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True,
              softplus_scale=None,
              no_head=False,
              zero_inits=True):
        """Construct ResNet V1 with `num_classes` outputs."""
        self._stage_sizes = stage_sizes
        if std_penalty_mult > 0:
            raise NotImplementedError(
                'std_penalty_mult not supported for ResNetImageNet')

        width = 64 * width_factor

        # Root block.
        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)
        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        x = conv(x,
                 width,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 name='init_conv')
        x = norm(x, name='init_bn')

        if compensate_padding:
            # NOTE: this leads to lower performance.
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
        else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Stages.
        for i, stage_size in enumerate(stage_sizes):
            x = ResNetStage(
                x,
                stage_size,
                filters=width * 2**i,
                block_class=block_class,
                first_block_strides=(1, 1) if i == 0 else (2, 2),
                train=train,
                name=f'stage{i + 1}',
                conv=conv,
                norm=norm,
                activation_f=activation_f,
                use_residual=use_residual,
                zero_inits=zero_inits,
            )

        if not no_head:
            # Head.
            x = jnp.mean(x, axis=(1, 2))
            x = nn.Dense(x,
                         num_classes,
                         kernel_init=nn.initializers.zeros
                         if zero_inits else nn.initializers.lecun_normal(),
                         name='head')
        return x, 0, {}
Example #21
    def apply(self,
              x,
              depth,
              num_outputs,
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              filters=16,
              no_head=False,
              report_metrics=False,
              benchmark='cifar10',
              compensate_padding=True,
              softplus_scale=None):

        bn_index = iter(range(1000))
        conv_index = iter(range(1000))
        summaries = {}
        summary_ind = [0]

        def add_summary(name, val):
            """Summarize statistics of tensor."""
            if report_metrics:
                assert val.ndim == 4, (
                    'Assuming 4D inputs with channels last, got %s' %
                    str(val.shape))
                assert val.shape[1] == val.shape[
                    2], 'Assuming 4D inputs with channels last'
                summaries['%s_%d_mean_abs' %
                          (name, summary_ind[0] // 2)] = jnp.mean(
                              jnp.abs(jnp.mean(val, axis=(0, 1, 2))))
                summaries['%s_%d_mean_std' %
                          (name, summary_ind[0] // 2)] = jnp.mean(
                              jnp.std(val, axis=(0, 1, 2)))
                summary_ind[0] += 1

        penalty = 0

        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)

        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)

        def resnet_layer(
            inputs,
            penalty,
            filters,
            kernel_size=3,
            strides=1,
            activation=None,
        ):
            """2D Convolution-Batch Normalization-Activation stack builder."""
            x = inputs
            x = conv(x,
                     filters, (kernel_size, kernel_size),
                     strides=(strides, strides),
                     padding='SAME',
                     name='conv%d' % next(conv_index))
            x = norm(x, name='norm%d' % next(bn_index))
            add_summary('postnorm', x)
            if std_penalty_mult > 0:
                penalty += std_penalty(x)

            if activation:
                x = activation_f(x, features=x.shape[-1])
            add_summary('postact', x)
            return x, penalty

        # Main network code.
        num_res_blocks = (depth - 2) // 6

        if (depth - 2) % 6 != 0:
            raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')

        inputs = x
        add_summary('input', x)
        add_summary('inputb', x)
        if benchmark in ['cifar10', 'cifar100']:
            x, penalty = resnet_layer(inputs,
                                      penalty,
                                      filters=filters,
                                      activation=True)
            head_kernel_init = nn.initializers.lecun_normal()
        elif benchmark in ['imagenet']:
            head_kernel_init = nn.initializers.zeros
            x, penalty = resnet_layer(inputs,
                                      penalty,
                                      filters=filters,
                                      activation=False,
                                      kernel_size=7,
                                      strides=2)
            # TODO(basv): evaluate max pool v/s avg_pool in an experiment?
            # if compensate_padding:
            #   x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding="VALID")
            # else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        else:
            raise ValueError('Model def not prepared for benchmark %s' %
                             benchmark)

        for stack in range(3):
            for res_block in range(num_res_blocks):
                strides = 1
                if stack > 0 and res_block == 0:  # First layer but not first stack.
                    strides = 2  # Downsample.
                y, penalty = resnet_layer(
                    x,
                    penalty,
                    filters=filters,
                    strides=strides,
                    activation=True,
                )
                y, penalty = resnet_layer(
                    y,
                    penalty,
                    filters=filters,
                    activation=False,
                )
                if stack > 0 and res_block == 0:  # First layer but not first stack.
                    # Linear projection residual shortcut to match changed dims.
                    x, penalty = resnet_layer(
                        x,
                        penalty,
                        filters=filters,
                        kernel_size=1,
                        strides=strides,
                        activation=False,
                    )

                if use_residual == 1:
                    # Apply an up projection in case of channel mismatch
                    x = x + y
                elif use_residual == 2:
                    x = (x + y) / jnp.sqrt(
                        1**2 + 1**2)  # Sum of independent normals.
                else:
                    x = y

                add_summary('postres', x)
                x = activation_f(x, features=x.shape[-1])
                add_summary('postresact', x)
            filters *= 2

        # V1 does not use BN after last shortcut connection-ReLU.
        if not no_head:
            x = jnp.mean(x, axis=(1, 2))
            add_summary('postpool', x)
            x = x.reshape((x.shape[0], -1))

            x = nn.Dense(x, num_outputs, kernel_init=head_kernel_init)
        return x, penalty, summaries
Example #22
    def apply(self,
              x,
              num_classes=1000,
              train=False,
              resnet=None,
              patches=None,
              hidden_size=None,
              transformer=None,
              representation_size=None,
              classifier='gap'):

        # (Possibly partial) ResNet root.
        if resnet is not None:
            width = int(64 * resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(x,
                                      width, (7, 7), (2, 2),
                                      bias=False,
                                      name='conv_root')
            x = nn.GroupNorm(x, name='gn_root')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            # ResNet stages.
            x = models_resnet.ResNetStage(x,
                                          resnet.num_layers[0],
                                          width,
                                          first_stride=(1, 1),
                                          name='block1')
            for i, block_size in enumerate(resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(x,
                                              block_size,
                                              width * 2**i,
                                              first_stride=(2, 2),
                                              name=f'block{i + 1}')

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(x,
                    hidden_size,
                    patches.size,
                    strides=patches.size,
                    padding='VALID',
                    name='embedding')

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if classifier == 'token':
                cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = Encoder(x, train=train, name='Transformer', **transformer)

        if classifier == 'token':
            x = x[:, 0]
        elif classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

        if representation_size is not None:
            x = nn.Dense(x, representation_size, name='pre_logits')
            x = nn.tanh(x)
        else:
            x = IdentityLayer(x, name='pre_logits')

        x = nn.Dense(x,
                     num_classes,
                     name='head',
                     kernel_init=nn.initializers.zeros)
        return x
Example #23
  def apply(self,
            x,
            num_classes=1,
            train=False,
            hidden_size=None,
            transformer=None,
            resnet_emb=None,
            representation_size=None):
    """Apply model on inputs.

    Args:
      x: the processed input patches and position annotations.
      num_classes: the number of output classes. 1 for single model.
      train: train or eval.
      hidden_size: the hidden dimension for patch embedding tokens.
      transformer: the model config for Transformer backbone.
      resnet_emb: the config for patch embedding w/ small resnet.
      representation_size: size of the last FC before prediction.

    Returns:
      Model prediction output.
    """
    assert transformer is not None
    # Either 3: (batch size, seq len, channel) or
    # 4: (batch size, crops, seq len, channel)
    assert len(x.shape) in [3, 4]

    multi_crops_input = False
    if len(x.shape) == 4:
      multi_crops_input = True
      batch_size, num_crops, l, channel = x.shape
      x = jnp.reshape(x, [batch_size * num_crops, l, channel])

    # We concatenate (x, spatial_positions, scale_positions, input_masks)
    # during preprocessing.
    inputs_spatial_positions = x[:, :, -3]
    inputs_spatial_positions = inputs_spatial_positions.astype(jnp.int32)
    inputs_scale_positions = x[:, :, -2]
    inputs_scale_positions = inputs_scale_positions.astype(jnp.int32)
    inputs_masks = x[:, :, -1]
    inputs_masks = inputs_masks.astype(jnp.bool_)
    x = x[:, :, :-3]
    n, l, channel = x.shape
    if hidden_size:
      if resnet_emb:
        # channel = patch_size * patch_size * 3
        patch_size = int(np.sqrt(channel // 3))
        x = jnp.reshape(x, [-1, patch_size, patch_size, 3])
        x = resnet.StdConv(
            x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False, name="conv_root")
        x = nn.GroupNorm(x, name="gn_root")
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

        if resnet_emb.num_layers > 0:
          blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers)
          if blocks:
            x = resnet.ResNetStage(
                x,
                blocks[0],
                RESNET_TOKEN_DIM,
                first_stride=(1, 1),
                bottleneck=bottleneck,
                name="block1")
            for i, block_size in enumerate(blocks[1:], 1):
              x = resnet.ResNetStage(
                  x,
                  block_size,
                  RESNET_TOKEN_DIM * 2**i,
                  first_stride=(2, 2),
                  bottleneck=bottleneck,
                  name=f"block{i + 1}")
        x = jnp.reshape(x, [n, l, -1])

      x = nn.Dense(x, hidden_size, name="embedding")

    # Here, x is a list of embeddings.
    x = utils.Encoder(
        x,
        inputs_spatial_positions,
        inputs_scale_positions,
        inputs_masks,
        train=train,
        name="Transformer",
        **transformer)

    x = x[:, 0]

    if representation_size:
      x = nn.Dense(x, representation_size, name="pre_logits")
      x = nn.tanh(x)
    else:
      x = resnet.IdentityLayer(x, name="pre_logits")

    x = nn.Dense(x, num_classes, name="head", kernel_init=nn.initializers.zeros)
    if multi_crops_input:
      _, channel = x.shape
      x = jnp.reshape(x, [batch_size, num_crops, channel])
    return x