def apply(self, x, nout, strides=(1, 1), bottleneck=True):
  features = nout
  nout = nout * 4 if bottleneck else nout
  needs_projection = x.shape[-1] != nout or strides != (1, 1)

  residual = x
  if needs_projection:
    residual = StdConv(residual, nout, (1, 1), strides, bias=False,
                       name="conv_proj")
    residual = nn.GroupNorm(residual, epsilon=1e-4, name="gn_proj")

  if bottleneck:
    x = StdConv(x, features, (1, 1), bias=False, name="conv1")
    x = nn.GroupNorm(x, epsilon=1e-4, name="gn1")
    x = nn.relu(x)

  x = StdConv(x, features, (3, 3), strides, bias=False, name="conv2")
  x = nn.GroupNorm(x, epsilon=1e-4, name="gn2")
  x = nn.relu(x)

  last_kernel = (1, 1) if bottleneck else (3, 3)
  x = StdConv(x, nout, last_kernel, bias=False, name="conv3")
  x = nn.GroupNorm(x, epsilon=1e-4, name="gn3",
                   scale_init=nn.initializers.zeros)
  x = nn.relu(residual + x)
  return x
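# `StdConv` above is not defined in this excerpt; in the BiT/ViT codebases it
# is a convolution with Weight Standardization, i.e. the kernel is normalized
# per output filter before the convolution, which combines well with
# GroupNorm. A minimal sketch of that kernel transform, assuming an HWIO
# kernel layout (an illustration, not the source's exact implementation):

import jax.numpy as jnp


def standardize_kernel(w, eps=1e-10):
  # Zero mean and unit variance over (height, width, in-channels) for each
  # output filter of a kernel with shape (kh, kw, cin, cout).
  mean = jnp.mean(w, axis=(0, 1, 2), keepdims=True)
  var = jnp.var(w, axis=(0, 1, 2), keepdims=True)
  return (w - mean) / jnp.sqrt(var + eps)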
def apply(self, x, nout, strides=(1, 1)):
  needs_projection = x.shape[-1] != nout * 4 or strides != (1, 1)

  residual = x
  if needs_projection:
    residual = StdConv(residual, nout * 4, (1, 1), strides, bias=False,
                       name='conv_proj')
    residual = nn.GroupNorm(residual, name='gn_proj')

  y = StdConv(x, nout, (1, 1), bias=False, name='conv1')
  y = nn.GroupNorm(y, name='gn1')
  y = nn.relu(y)
  y = StdConv(y, nout, (3, 3), strides, bias=False, name='conv2')
  y = nn.GroupNorm(y, name='gn2')
  y = nn.relu(y)
  y = StdConv(y, nout * 4, (1, 1), bias=False, name='conv3')
  y = nn.GroupNorm(y, name='gn3', scale_init=nn.initializers.zeros)
  y = nn.relu(residual + y)
  return y
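# Both residual units zero-initialize the scale of their final GroupNorm
# (`gn3`, via `scale_init=nn.initializers.zeros`), so at initialization the
# main path contributes nothing and each block reduces to relu(residual).
# A tiny numeric illustration of that effect (shapes are hypothetical):

import jax
import jax.numpy as jnp

residual = jnp.ones((1, 8, 8, 256))
main_path = jnp.zeros_like(residual)  # what gn3 emits while its scale is 0
out = jax.nn.relu(residual + main_path)
assert bool((out == jax.nn.relu(residual)).all())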
def apply(self, x, num_classes=1000, train=False, width_factor=1,
          num_layers=50):
  del train
  blocks, bottleneck = get_block_desc(num_layers)
  width = int(64 * width_factor)

  # Root block.
  x = StdConv(x, width, (7, 7), (2, 2), bias=False, name="conv_root")
  x = nn.GroupNorm(x, name="gn_root")
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

  # Stages.
  x = ResNetStage(x, blocks[0], width, first_stride=(1, 1),
                  bottleneck=bottleneck, name="block1")
  for i, block_size in enumerate(blocks[1:], 1):
    x = ResNetStage(x, block_size, width * 2**i, first_stride=(2, 2),
                    bottleneck=bottleneck, name=f"block{i + 1}")

  # Head.
  x = jnp.mean(x, axis=(1, 2))
  x = IdentityLayer(x, name="pre_logits")
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.zeros,
               name="head")
  return x
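# `get_block_desc` maps a ResNet depth to per-stage block counts plus a flag
# for bottleneck vs. basic units. A plausible sketch following the standard
# ResNet configurations (the helper in the source may support additional
# depths):

def get_block_desc(num_layers):
  # depth -> (blocks per stage, use bottleneck units)
  return {
      18: ([2, 2, 2, 2], False),
      34: ([3, 4, 6, 3], False),
      50: ([3, 4, 6, 3], True),
      101: ([3, 4, 23, 3], True),
      152: ([3, 8, 36, 3], True),
  }[num_layers]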
def apply(self, x, num_classes=1000, train=False, resnet=None, patches=None,
          hidden_size=None, transformer=None, representation_size=None,
          classifier='gap'):
  # (Possibly partial) ResNet root.
  if resnet is not None:
    width = int(64 * resnet.width_factor)

    # Root block.
    x = models_resnet.StdConv(x, width, (7, 7), (2, 2), bias=False,
                              name='conv_root')
    x = nn.GroupNorm(x, name='gn_root')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    # ResNet stages.
    x = models_resnet.ResNetStage(x, resnet.num_layers[0], width,
                                  first_stride=(1, 1), name='block1')
    for i, block_size in enumerate(resnet.num_layers[1:], 1):
      x = models_resnet.ResNetStage(x, block_size, width * 2**i,
                                    first_stride=(2, 2),
                                    name=f'block{i + 1}')

  n, h, w, c = x.shape

  # We can merge s2d+emb into a single conv; it's the same.
  x = nn.Conv(x, hidden_size, patches.size, strides=patches.size,
              padding='VALID', name='embedding')

  # Here, x is a grid of embeddings.

  # (Possibly partial) Transformer.
  if transformer is not None:
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(x, train=train, name='Transformer', **transformer)

  if classifier == 'token':
    x = x[:, 0]
  elif classifier == 'gap':
    x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1, 2)

  if representation_size is not None:
    x = nn.Dense(x, representation_size, name='pre_logits')
    x = nn.tanh(x)
  else:
    x = IdentityLayer(x, name='pre_logits')

  x = nn.Dense(x, num_classes, name='head',
               kernel_init=nn.initializers.zeros)
  return x
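# Sequence-length bookkeeping for the hybrid model above (illustrative
# arithmetic, not code from the source): the ResNet root downsamples by 4
# (stride-2 conv plus stride-2 max-pool), each stage after the first by
# another 2 (first_stride=(2, 2)), and the embedding conv divides the grid
# by the patch size. `num_tokens` is a hypothetical helper:

def num_tokens(image_size, num_stages, patch_size, use_cls_token):
  total_stride = 4 * 2 ** max(num_stages - 1, 0) * patch_size
  grid = image_size // total_stride
  return grid * grid + int(use_cls_token)

# E.g. a 224x224 input through two ResNet stages with 1x1 patches and a
# class token yields a 28x28 grid, i.e. 28 * 28 + 1 == 785 tokens.
assert num_tokens(224, 2, 1, True) == 785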
def apply(self, x, num_classes=1, train=False, hidden_size=None,
          transformer=None, resnet_emb=None, representation_size=None):
  """Applies the model to the inputs.

  Args:
    x: the processed input patches and position annotations.
    num_classes: the number of output classes. 1 for a single-output model.
    train: whether in train or eval mode.
    hidden_size: the hidden dimension for patch embedding tokens.
    transformer: the model config for the Transformer backbone.
    resnet_emb: the config for patch embedding w/ a small ResNet.
    representation_size: size of the last FC before prediction.

  Returns:
    Model prediction output.
  """
  assert transformer is not None
  # Either rank 3: (batch size, seq len, channel) or
  # rank 4: (batch size, crops, seq len, channel).
  assert len(x.shape) in [3, 4]
  multi_crops_input = False
  if len(x.shape) == 4:
    multi_crops_input = True
    batch_size, num_crops, l, channel = x.shape
    x = jnp.reshape(x, [batch_size * num_crops, l, channel])

  # Preprocessing concatenates (x, spatial_positions, scale_positions,
  # input_masks) along the channel axis, so the last three channels carry
  # the annotations.
  inputs_spatial_positions = x[:, :, -3].astype(jnp.int32)
  inputs_scale_positions = x[:, :, -2].astype(jnp.int32)
  inputs_masks = x[:, :, -1].astype(jnp.bool_)
  x = x[:, :, :-3]
  n, l, channel = x.shape

  if hidden_size:
    if resnet_emb:
      # channel == patch_size * patch_size * 3
      patch_size = int(np.sqrt(channel // 3))
      x = jnp.reshape(x, [-1, patch_size, patch_size, 3])
      x = resnet.StdConv(x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False,
                         name="conv_root")
      x = nn.GroupNorm(x, name="gn_root")
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

      if resnet_emb.num_layers > 0:
        blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers)
        if blocks:
          x = resnet.ResNetStage(x, blocks[0], RESNET_TOKEN_DIM,
                                 first_stride=(1, 1), bottleneck=bottleneck,
                                 name="block1")
          for i, block_size in enumerate(blocks[1:], 1):
            x = resnet.ResNetStage(x, block_size, RESNET_TOKEN_DIM * 2**i,
                                   first_stride=(2, 2),
                                   bottleneck=bottleneck,
                                   name=f"block{i + 1}")
      x = jnp.reshape(x, [n, l, -1])

    x = nn.Dense(x, hidden_size, name="embedding")

  # Here, x is a sequence of embeddings.
  x = utils.Encoder(x, inputs_spatial_positions, inputs_scale_positions,
                    inputs_masks, train=train, name="Transformer",
                    **transformer)

  x = x[:, 0]

  if representation_size:
    x = nn.Dense(x, representation_size, name="pre_logits")
    x = nn.tanh(x)
  else:
    x = resnet.IdentityLayer(x, name="pre_logits")

  x = nn.Dense(x, num_classes, name="head",
               kernel_init=nn.initializers.zeros)

  if multi_crops_input:
    _, channel = x.shape
    x = jnp.reshape(x, [batch_size, num_crops, channel])
  return x
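# The preprocessing contract assumed by apply() above: patch pixels plus
# three annotation channels (spatial position, scale position, validity
# mask) are concatenated along the last axis. An illustrative construction
# with hypothetical sizes (2 images, 196 tokens, 32x32x3 patches):

import numpy as np

batch, seq_len, patch_dim = 2, 196, 32 * 32 * 3
patches = np.zeros((batch, seq_len, patch_dim), np.float32)
spatial = np.tile(np.arange(seq_len, dtype=np.float32),
                  (batch, 1))[..., None]
scale = np.zeros((batch, seq_len, 1), np.float32)
mask = np.ones((batch, seq_len, 1), np.float32)
packed = np.concatenate([patches, spatial, scale, mask], axis=-1)
assert packed.shape == (batch, seq_len, patch_dim + 3)  # apply() strips -3: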