def __call__(self, inputs, *, train): x = inputs # (Possibly partial) ResNet root. if self.resnet is not None: width = int(64 * self.resnet.width_factor) # Root block. x = models_resnet.StdConv(features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name='conv_root')(x) x = nn.GroupNorm(name='gn_root')(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME') # ResNet stages. if self.resnet.num_layers: x = models_resnet.ResNetStage( block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name='block1')(x) for i, block_size in enumerate(self.resnet.num_layers[1:], 1): x = models_resnet.ResNetStage(block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f'block{i + 1}')(x) n, h, w, c = x.shape # We can merge s2d+emb into a single conv; it's the same. x = nn.Conv(features=self.hidden_size, kernel_size=self.patches.size, strides=self.patches.size, padding='VALID', name='embedding')(x) # Here, x is a grid of embeddings. # Transformer. n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) # If we want to add a class token, add it here. if self.classifier == 'token': cls = self.param('cls', nn.initializers.zeros, (1, 1, c)) cls = jnp.tile(cls, [n, 1, 1]) x = jnp.concatenate([cls, x], axis=1) x = Encoder(name='Transformer', **self.transformer)(x, train=train) if self.classifier == 'token': x = x[:, 0] elif self.classifier == 'gap': x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) else: raise ValueError(f'Invalid classifier={self.classifier}') if self.representation_size is not None: x = nn.Dense(features=self.representation_size, name='pre_logits')(x) x = nn.tanh(x) else: x = IdentityLayer(name='pre_logits')(x) if self.num_classes: x = nn.Dense(features=self.num_classes, name='head', kernel_init=nn.initializers.zeros)(x) return x
def apply(self,
          x,
          num_classes=1000,
          train=False,
          resnet=None,
          patches=None,
          hidden_size=None,
          transformer=None,
          representation_size=None,
          classifier='gap'):
    """Runs the Vision Transformer forward pass (legacy `flax.nn` API).

    Mirrors the linen `__call__`: optional (possibly partial) ResNet stem
    -> patch-embedding conv -> optional Transformer encoder with optional
    class token -> pooling -> optional pre-logits projection -> head.

    Args:
      x: input image batch; indexed as (n, h, w, c) below, so NHWC layout
        is assumed — TODO confirm against callers.
      num_classes: output dimension of the final 'head' Dense layer.
      train: forwarded to `Encoder` (presumably toggles dropout — verify).
      resnet: optional config with `width_factor` and `num_layers`; when
        None the ResNet stem is skipped entirely.
      patches: config whose `size` gives the patch kernel/stride.
      hidden_size: embedding dimension of the patch-embedding conv.
      transformer: optional dict of `Encoder` kwargs; when None the
        Transformer (and pooling) stage is skipped.
      representation_size: optional pre-logits Dense width.
      classifier: 'token' or 'gap'.

    Returns:
      Logits of shape (n, num_classes) when a Transformer is configured;
      otherwise the head is applied to the unflattened embedding grid
      (legacy behavior, preserved).

    Raises:
      ValueError: if a Transformer is configured and `classifier` is
        neither 'token' nor 'gap'.
    """
    # (Possibly partial) ResNet root.
    if resnet is not None:
        width = int(64 * resnet.width_factor)

        # Root block: strided 7x7 std-conv, group norm, relu, 3x3 max pool.
        x = models_resnet.StdConv(
            x, width, (7, 7), (2, 2), bias=False, name='conv_root')
        x = nn.GroupNorm(x, name='gn_root')
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # ResNet stages. Guard against an empty `num_layers` (previously an
        # IndexError) to match the linen `__call__` implementation.
        if resnet.num_layers:
            x = models_resnet.ResNetStage(
                x, resnet.num_layers[0], width,
                first_stride=(1, 1), name='block1')
            for i, block_size in enumerate(resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(
                    x, block_size, width * 2**i,
                    first_stride=(2, 2), name=f'block{i + 1}')

    # NOTE(review): this unpacking is shadowed below when a Transformer is
    # configured; it only asserts that x is rank-4 here.
    n, h, w, c = x.shape

    # We can merge s2d+emb into a single conv; it's the same.
    # (Non-overlapping patch extraction + linear embedding == strided conv.)
    x = nn.Conv(
        x, hidden_size, patches.size, strides=patches.size,
        padding='VALID', name='embedding')

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if transformer is not None:
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # If we want to add a class token, add it here.
        if classifier == 'token':
            cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)

        x = Encoder(x, train=train, name='Transformer', **transformer)

        if classifier == 'token':
            # Classify from the prepended class token only.
            x = x[:, 0]
        elif classifier == 'gap':
            # Global average pool over all non-batch, non-feature axes.
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        else:
            # Previously an unknown classifier silently skipped pooling and
            # fed the full token sequence to the head; fail fast instead,
            # consistent with the linen `__call__`.
            raise ValueError(f'Invalid classifier={classifier}')

    if representation_size is not None:
        x = nn.Dense(x, representation_size, name='pre_logits')
        x = nn.tanh(x)
    else:
        # Named identity keeps the parameter tree layout stable regardless
        # of whether a pre-logits layer is configured.
        x = IdentityLayer(x, name='pre_logits')

    # Zero-initialized head: the model starts out predicting uniform logits.
    x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
    return x