Example #1 (Flax Linen API: configuration via module attributes, applied in __call__)
    @nn.compact
    def __call__(self, inputs, *, train):

        x = inputs
        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # Root block.
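            # StdConv is a weight-standardized convolution; paired with the
            # GroupNorm below, this follows the BiT-style ResNet root defined
            # in models_resnet.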
            x = models_resnet.StdConv(features=width,
                                      kernel_size=(7, 7),
                                      strides=(2, 2),
                                      use_bias=False,
                                      name='conv_root')(x)
            x = nn.GroupNorm(name='gn_root')(x)
            x = nn.relu(x)
            x = nn.max_pool(x,
                            window_shape=(3, 3),
                            strides=(2, 2),
                            padding='SAME')

            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0],
                    nout=width,
                    first_stride=(1, 1),
                    name='block1')(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(block_size=block_size,
                                                  nout=width * 2**i,
                                                  first_stride=(2, 2),
                                                  name=f'block{i + 1}')(x)

        # We can merge s2d+emb into a single conv; it's the same as cutting
        # the image into patches (space-to-depth) and linearly embedding each
        # patch. A numerical check of this equivalence follows Example #2.
        x = nn.Conv(features=self.hidden_size,
                    kernel_size=self.patches.size,
                    strides=self.patches.size,
                    padding='VALID',
                    name='embedding')(x)

        # Here, x is a grid of patch embeddings, shape (n, h/P, w/P, hidden_size)
        # for patch size P.

        # Transformer.
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])
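        # x: (n, h * w, c) -- a sequence of patch tokens.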

        # If we want to add a class token, add it here.
        if self.classifier == 'token':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)
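            # x: (n, 1 + h * w, c) with the class token prepended.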

        x = Encoder(name='Transformer', **self.transformer)(x, train=train)

        if self.classifier == 'token':
            x = x[:, 0]
        elif self.classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # token axis (1,) or grid axes (1, 2)
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size,
                         name='pre_logits')(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name='pre_logits')(x)

        if self.num_classes:
            x = nn.Dense(features=self.num_classes,
                         name='head',
                         kernel_init=nn.initializers.zeros)(x)
        return x
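
A minimal usage sketch for the Linen version above. It assumes the surrounding
nn.Module is named VisionTransformer, that patches only needs a .size
attribute, and that the transformer dict keys match whatever Encoder accepts;
all names and sizes here are illustrative, not the repository's defaults.

import types

import jax
import jax.numpy as jnp

model = VisionTransformer(                         # assumed name of the module
    num_classes=1000,
    patches=types.SimpleNamespace(size=(16, 16)),  # only .size is read above
    transformer=dict(num_layers=12, mlp_dim=3072, num_heads=12),  # Encoder kwargs (assumed)
    hidden_size=768,
    resnet=None,
    representation_size=None,
    classifier='token')

rng = jax.random.PRNGKey(0)
images = jnp.ones([2, 224, 224, 3])               # NHWC input batch
variables = model.init(rng, images, train=False)  # train=False: no dropout rngs needed
logits = model.apply(variables, images, train=False)
print(logits.shape)  # (2, 1000)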
Example #2 (older, now-deprecated flax.nn API: configuration passed to apply)
    def apply(self,
              x,
              num_classes=1000,
              train=False,
              resnet=None,
              patches=None,
              hidden_size=None,
              transformer=None,
              representation_size=None,
              classifier='gap'):

        # (Possibly partial) ResNet root.
        if resnet is not None:
            width = int(64 * resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(x,
                                      width, (7, 7), (2, 2),
                                      bias=False,
                                      name='conv_root')
            x = nn.GroupNorm(x, name='gn_root')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            # ResNet stages.
            x = models_resnet.ResNetStage(x,
                                          resnet.num_layers[0],
                                          width,
                                          first_stride=(1, 1),
                                          name='block1')
            for i, block_size in enumerate(resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(x,
                                              block_size,
                                              width * 2**i,
                                              first_stride=(2, 2),
                                              name=f'block{i + 1}')

        # We can merge s2d+emb into a single conv; it's the same (see the
        # equivalence sketch after this example).
        x = nn.Conv(x,
                    hidden_size,
                    patches.size,
                    strides=patches.size,
                    padding='VALID',
                    name='embedding')

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if classifier == 'token':
                cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = Encoder(x, train=train, name='Transformer', **transformer)

        if classifier == 'token':
            x = x[:, 0]
        elif classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # token axis (1,) or grid axes (1, 2)
        else:
            raise ValueError(f'Invalid classifier={classifier}')

        if representation_size is not None:
            x = nn.Dense(x, representation_size, name='pre_logits')
            x = nn.tanh(x)
        else:
            x = IdentityLayer(x, name='pre_logits')

        x = nn.Dense(x,
                     num_classes,
                     name='head',
                     kernel_init=nn.initializers.zeros)
        return x
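
Both examples assert that space-to-depth plus a linear embedding ("s2d+emb")
collapses into a single strided convolution. A self-contained numerical check
of that claim, using raw jax.lax rather than the Flax modules above; every
size here is made up for the demo:

import jax
import jax.numpy as jnp

n, h, w, c = 2, 8, 8, 3   # batch, height, width, channels
p, d = 4, 16              # patch size, embedding dimension

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (n, h, w, c))
kernel = jax.random.normal(k2, (p, p, c, d))  # HWIO conv kernel

# (a) One strided conv, as in nn.Conv(..., strides=patch_size, padding='VALID').
conv_out = jax.lax.conv_general_dilated(
    x, kernel, window_strides=(p, p), padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

# (b) Space-to-depth: cut into non-overlapping p x p patches, flatten each,
#     then apply the same weights as a per-patch linear embedding.
patches = x.reshape(n, h // p, p, w // p, p, c)
patches = patches.transpose(0, 1, 3, 2, 4, 5).reshape(n, h // p, w // p, p * p * c)
emb_out = patches @ kernel.reshape(p * p * c, d)

print(jnp.allclose(conv_out, emb_out, atol=1e-5))  # True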