Example #1
    def apply(self, x, nout, strides=(1, 1), bottleneck=True):
        features = nout
        nout = nout * 4 if bottleneck else nout
        # Project the shortcut when the channel count or spatial resolution changes.
        needs_projection = x.shape[-1] != nout or strides != (1, 1)
        residual = x
        if needs_projection:
            residual = StdConv(residual,
                               nout, (1, 1),
                               strides,
                               bias=False,
                               name="conv_proj")
            residual = nn.GroupNorm(residual, epsilon=1e-4, name="gn_proj")

        if bottleneck:
            x = StdConv(x, features, (1, 1), bias=False, name="conv1")
            x = nn.GroupNorm(x, epsilon=1e-4, name="gn1")
            x = nn.relu(x)

        x = StdConv(x, features, (3, 3), strides, bias=False, name="conv2")
        x = nn.GroupNorm(x, epsilon=1e-4, name="gn2")
        x = nn.relu(x)

        last_kernel = (1, 1) if bottleneck else (3, 3)
        x = StdConv(x, nout, last_kernel, bias=False, name="conv3")
        # Zero-init the last norm's scale so the residual branch starts at zero.
        x = nn.GroupNorm(x,
                         epsilon=1e-4,
                         name="gn3",
                         scale_init=nn.initializers.zeros)
        x = nn.relu(residual + x)

        return x
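These snippets use the pre-Linen flax.nn "apply" style. StdConv is referenced throughout but not defined on this page; in BiT-style ResNets it is a weight-standardized convolution. A minimal sketch in the same style (the standardization axes and epsilon are assumptions, not taken from the snippets above):

    import jax.numpy as jnp
    from flax import nn  # pre-Linen flax.nn, as used in these examples


    class StdConv(nn.Conv):
        """Convolution whose kernel is standardized before being applied."""

        def param(self, name, shape, initializer):
            param = super().param(name, shape, initializer)
            if name == "kernel":
                # Standardize over the spatial and input-channel axes (HWIO layout).
                mean = jnp.mean(param, axis=(0, 1, 2), keepdims=True)
                var = jnp.var(param, axis=(0, 1, 2), keepdims=True)
                param = (param - mean) / jnp.sqrt(var + 1e-10)
            return param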
Example #2
    def apply(self, x, nout, strides=(1, 1)):
        needs_projection = x.shape[-1] != nout * 4 or strides != (1, 1)

        residual = x
        if needs_projection:
            residual = StdConv(residual,
                               nout * 4, (1, 1),
                               strides,
                               bias=False,
                               name='conv_proj')
            residual = nn.GroupNorm(residual, name='gn_proj')

        y = StdConv(x, nout, (1, 1), bias=False, name='conv1')
        y = nn.GroupNorm(y, name='gn1')
        y = nn.relu(y)
        y = StdConv(y, nout, (3, 3), strides, bias=False, name='conv2')
        y = nn.GroupNorm(y, name='gn2')
        y = nn.relu(y)
        y = StdConv(y, nout * 4, (1, 1), bias=False, name='conv3')

        y = nn.GroupNorm(y, name='gn3', scale_init=nn.initializers.zeros)
        y = nn.relu(residual + y)
        return y
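Example #3 below chains residual units like the two above into stages via ResNetStage, which is not shown on this page. A minimal sketch, assuming the unit above is exposed as ResidualUnit and keeping the same pre-Linen style (the unit naming is an assumption):

    class ResNetStage(nn.Module):
        """A stage of block_size residual units; only the first one may downsample."""

        def apply(self, x, block_size, nout, first_stride, bottleneck=True):
            x = ResidualUnit(x,
                             nout,
                             strides=first_stride,
                             bottleneck=bottleneck,
                             name="unit1")
            for i in range(1, block_size):
                x = ResidualUnit(x,
                                 nout,
                                 strides=(1, 1),
                                 bottleneck=bottleneck,
                                 name=f"unit{i + 1}")
            return x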
Example #3
    def apply(self,
              x,
              num_classes=1000,
              train=False,
              width_factor=1,
              num_layers=50):
        del train
        blocks, bottleneck = get_block_desc(num_layers)
        width = int(64 * width_factor)

        # Root block
        x = StdConv(x, width, (7, 7), (2, 2), bias=False, name="conv_root")
        x = nn.GroupNorm(x, name="gn_root")
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

        # Stages
        x = ResNetStage(x,
                        blocks[0],
                        width,
                        first_stride=(1, 1),
                        bottleneck=bottleneck,
                        name="block1")
        for i, block_size in enumerate(blocks[1:], 1):
            x = ResNetStage(x,
                            block_size,
                            width * 2**i,
                            first_stride=(2, 2),
                            bottleneck=bottleneck,
                            name=f"block{i + 1}")

        # Head
        x = jnp.mean(x, axis=(1, 2))
        x = IdentityLayer(x, name="pre_logits")
        x = nn.Dense(x,
                     num_classes,
                     kernel_init=nn.initializers.zeros,
                     name="head")
        return x
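get_block_desc maps the requested depth to a per-stage block layout plus a bottleneck flag. It is not shown here; a plausible sketch based on the standard ResNet family (the exact table in the original code may differ or support additional depths):

    def get_block_desc(num_layers):
        """Return (blocks_per_stage, use_bottleneck) for a given ResNet depth."""
        return {
            18: ([2, 2, 2, 2], False),
            34: ([3, 4, 6, 3], False),
            50: ([3, 4, 6, 3], True),
            101: ([3, 4, 23, 3], True),
            152: ([3, 8, 36, 3], True),
        }[num_layers]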
Example #4
    def apply(self,
              x,
              num_classes=1000,
              train=False,
              resnet=None,
              patches=None,
              hidden_size=None,
              transformer=None,
              representation_size=None,
              classifier='gap'):

        # (Possibly partial) ResNet root.
        if resnet is not None:
            width = int(64 * resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(x,
                                      width, (7, 7), (2, 2),
                                      bias=False,
                                      name='conv_root')
            x = nn.GroupNorm(x, name='gn_root')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            # ResNet stages.
            x = models_resnet.ResNetStage(x,
                                          resnet.num_layers[0],
                                          width,
                                          first_stride=(1, 1),
                                          name='block1')
            for i, block_size in enumerate(resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(x,
                                              block_size,
                                              width * 2**i,
                                              first_stride=(2, 2),
                                              name=f'block{i + 1}')

        n, h, w, c = x.shape

        # We can merge the space-to-depth (patch extraction) and the patch
        # embedding into a single strided conv; the result is identical.
        x = nn.Conv(x,
                    hidden_size,
                    patches.size,
                    strides=patches.size,
                    padding='VALID',
                    name='embedding')

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if classifier == 'token':
                cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = Encoder(x, train=train, name='Transformer', **transformer)

        if classifier == 'token':
            x = x[:, 0]
        elif classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

        if representation_size is not None:
            x = nn.Dense(x, representation_size, name='pre_logits')
            x = nn.tanh(x)
        else:
            x = IdentityLayer(x, name='pre_logits')

        x = nn.Dense(x,
                     num_classes,
                     name='head',
                     kernel_init=nn.initializers.zeros)
        return x
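The claim that a strided VALID conv equals space-to-depth followed by a dense embedding can be checked directly in plain JAX; all sizes below are illustrative:

    import jax
    import jax.numpy as jnp

    key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
    n, h, w, c, p, d = 2, 8, 8, 3, 4, 5              # batch, image, patch, embed dims
    x = jax.random.normal(key_x, (n, h, w, c))        # NHWC input
    kernel = jax.random.normal(key_k, (p, p, c, d))   # HWIO conv kernel

    # Path 1: strided convolution, as in the model's "embedding" Conv.
    conv_out = jax.lax.conv_general_dilated(
        x, kernel, window_strides=(p, p), padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

    # Path 2: explicit space-to-depth (non-overlapping patches) + matmul.
    patches = x.reshape(n, h // p, p, w // p, p, c).transpose(0, 1, 3, 2, 4, 5)
    patches = patches.reshape(n, h // p, w // p, p * p * c)
    emb_out = patches @ kernel.reshape(p * p * c, d)

    print(jnp.allclose(conv_out, emb_out, atol=1e-4))  # True (up to float error)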
Example #5
  def apply(self,
            x,
            num_classes=1,
            train=False,
            hidden_size=None,
            transformer=None,
            resnet_emb=None,
            representation_size=None):
    """Apply model on inputs.

    Args:
      x: the processed input patches and position annotations.
      num_classes: the number of output classes. 1 for single model.
      train: train or eval.
      hidden_size: the hidden dimension for patch embedding tokens.
      transformer: the model config for Transformer backbone.
      resnet_emb: the config for patch embedding w/ small resnet.
      representation_size: size of the last FC before prediction.

    Returns:
      Model prediction output.
    """
    assert transformer is not None
    # x is rank 3 (batch size, seq len, channel) or
    # rank 4 (batch size, crops, seq len, channel).
    assert len(x.shape) in [3, 4]

    multi_crops_input = False
    if len(x.shape) == 4:
      multi_crops_input = True
      batch_size, num_crops, l, channel = x.shape
      x = jnp.reshape(x, [batch_size * num_crops, l, channel])

    # We concat (x, spatial_positions, scale_positions, input_masks)
    # when preprocessing.
    inputs_spatial_positions = x[:, :, -3]
    inputs_spatial_positions = inputs_spatial_positions.astype(jnp.int32)
    inputs_scale_positions = x[:, :, -2]
    inputs_scale_positions = inputs_scale_positions.astype(jnp.int32)
    inputs_masks = x[:, :, -1]
    inputs_masks = inputs_masks.astype(jnp.bool_)
    x = x[:, :, :-3]
    n, l, channel = x.shape
    if hidden_size:
      if resnet_emb:
        # channel = patch_size * patch_size * 3
        patch_size = int(np.sqrt(channel // 3))
        x = jnp.reshape(x, [-1, patch_size, patch_size, 3])
        x = resnet.StdConv(
            x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False, name="conv_root")
        x = nn.GroupNorm(x, name="gn_root")
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

        if resnet_emb.num_layers > 0:
          blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers)
          if blocks:
            x = resnet.ResNetStage(
                x,
                blocks[0],
                RESNET_TOKEN_DIM,
                first_stride=(1, 1),
                bottleneck=bottleneck,
                name="block1")
            for i, block_size in enumerate(blocks[1:], 1):
              x = resnet.ResNetStage(
                  x,
                  block_size,
                  RESNET_TOKEN_DIM * 2**i,
                  first_stride=(2, 2),
                  bottleneck=bottleneck,
                  name=f"block{i + 1}")
        x = jnp.reshape(x, [n, l, -1])

      x = nn.Dense(x, hidden_size, name="embedding")

    # Here, x is a list of embeddings.
    x = utils.Encoder(
        x,
        inputs_spatial_positions,
        inputs_scale_positions,
        inputs_masks,
        train=train,
        name="Transformer",
        **transformer)

    # Use the first token's output as the sequence representation.
    x = x[:, 0]

    if representation_size:
      x = nn.Dense(x, representation_size, name="pre_logits")
      x = nn.tanh(x)
    else:
      x = resnet.IdentityLayer(x, name="pre_logits")

    x = nn.Dense(x, num_classes, name="head", kernel_init=nn.initializers.zeros)
    if multi_crops_input:
      _, channel = x.shape
      x = jnp.reshape(x, [batch_size, num_crops, channel])
    return x
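The packing convention assumed by the slicing above (tokens concatenated with spatial positions, scale positions, and a validity mask along the last axis) can be illustrated with a small self-contained snippet; all sizes and values are made up:

    import numpy as np

    batch, seq_len, patch_dim = 2, 10, 32 * 32 * 3        # illustrative sizes
    tokens = np.random.rand(batch, seq_len, patch_dim).astype(np.float32)
    spatial_pos = np.tile(np.arange(seq_len), (batch, 1))  # per-token grid index
    scale_pos = np.zeros((batch, seq_len))                 # single scale here
    mask = np.ones((batch, seq_len))                       # all tokens are valid

    # Matches the slicing in apply(): x[:, :, -3], x[:, :, -2], x[:, :, -1].
    x = np.concatenate(
        [tokens, spatial_pos[..., None], scale_pos[..., None], mask[..., None]],
        axis=-1).astype(np.float32)
    assert x.shape == (batch, seq_len, patch_dim + 3)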