import torch
from einops import rearrange, repeat


def forward(self, x, cls_tokens):
    # Flatten the spatial feature map into a token sequence, prepend the class tokens,
    # run the transformer blocks, then split the tokens back out and restore b c h w.
    h, w = x.shape[2:4]
    x = rearrange(x, "b c h w -> b (h w) c")
    token_length = cls_tokens.shape[1]
    x = torch.cat((cls_tokens, x), dim=1)
    for blk in self.blocks:
        x = blk(x)
    cls_tokens = x[:, :token_length]
    x = x[:, token_length:]
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
    return x, cls_tokens
def forward_features(self, x):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications to add the dist_token
    patch_size = self.patch_embed.patch_size[0]
    H, W = x.shape[-2:]
    H, W = H // patch_size, W // patch_size
    B = x.shape[0]
    x = self.patch_embed(x)
    cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    dist_token = self.dist_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, dist_token, x), dim=1)
    # pick the spatial positional embedding and interpolate it to the current patch grid
    pos_embed = self._get_pos_embed(H, W)
    x = x + pos_embed
    x = self.pos_drop(x)
    for blk in self.blocks:
        x = blk(x)
    # x = self.norm(x)
    # drop the cls and dist tokens, then reshape the patch tokens back into a spatial map
    spatial = rearrange(x[:, 2:], "b (h w) c -> b c h w", h=H, w=W)
    return x[:, 0], x[:, 1], spatial
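
# Illustrative sketch (an assumption, not part of the original file): forward_features calls
# self._get_pos_embed(H, W), which is defined elsewhere in the model. A standalone version of
# the usual recipe, assuming the positional embedding is laid out as [cls, dist, spatial] over a
# square patch grid, keeps the token embeddings untouched and bicubically resizes the spatial
# part to the current H x W grid.
import torch.nn.functional as F


def get_pos_embed_sketch(pos_embed, H, W):
    token_pos = pos_embed[:, :2]      # cls + dist position embeddings, kept as-is
    spatial_pos = pos_embed[:, 2:]    # (1, grid * grid, C) spatial position embeddings
    grid = int(spatial_pos.shape[1] ** 0.5)
    spatial_pos = spatial_pos.reshape(1, grid, grid, -1).permute(0, 3, 1, 2)
    spatial_pos = F.interpolate(spatial_pos, size=(H, W), mode="bicubic", align_corners=False)
    spatial_pos = spatial_pos.permute(0, 2, 3, 1).reshape(1, H * W, -1)
    return torch.cat((token_pos, spatial_pos), dim=1)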
def apply_rotary_pos_emb(q, k, sinu_pos):
    # sinu_pos is (1, n, 2 * d) with the sin half followed by the cos half along the last dim
    sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
    sin, cos = sinu_pos.unbind(dim=-2)
    # duplicate each value so sin/cos line up with the (even, odd) feature pairs rotated below
    sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    return q, k
def rotate_every_two(x):
    # rotate adjacent feature pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x = rearrange(x, "... (d j) -> ... d j", j=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d j -> ... (d j)")
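
# Usage sketch: only apply_rotary_pos_emb and rotate_every_two above come from this file; the
# sinusoidal table below is an assumed construction. It is laid out as (sin | cos) along the
# last dimension, which is the shape apply_rotary_pos_emb expects, and is then used to rotate
# q and k in place of adding absolute position embeddings.
def fixed_pos_embedding_sketch(seq_len, dim):
    # dim is the per-head feature size and must be even; returns a (1, seq_len, dim) table
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.einsum("n,d->nd", torch.arange(seq_len).float(), inv_freq)  # (seq_len, dim // 2)
    return torch.cat((freqs.sin(), freqs.cos()), dim=-1)[None]                # sin half, then cos half


q = torch.randn(4, 197, 64)  # (batch, tokens, head_dim)
k = torch.randn(4, 197, 64)
q, k = apply_rotary_pos_emb(q, k, fixed_pos_embedding_sketch(197, 64))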