Example #1
    def forward(self, x, adj_kv_indices, mask):
        b, n, d, h = *x.shape, self.heads
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h=h)

        # derive query, key, value
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
                      (q, k, v))

        # gather keys and values according to adjacency matrix
        k, v = map(lambda t: rearrange(t, 'b h n d -> (b h) n d'), (k, v))
        k = batched_index_select(k, flat_indices)
        v = batched_index_select(v, flat_indices)
        k, v = map(
            lambda t: rearrange(t, '(b h) (n a) d -> b h n a d', h=h, n=n),
            (k, v))

        # add null key / value, so a node can attend to nothing
        # (this trick also appears in the GNN literature under a different name)
        nk, nv = map(
            lambda t: rearrange(t, 'h d -> () h () () d').expand(
                b, -1, n, 1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim=-2)
        v = torch.cat((nv, v), dim=-2)
        mask = F.pad(mask, (1, 0), value=1)

        # similarity of each node to its neighbors
        sim = einsum('b h n d, b h n a d -> b h n a', q, k) * self.scale

        # mask out neighbors that are just padding
        mask_value = -torch.finfo(sim.dtype).max
        mask = rearrange(mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~mask, mask_value)

        # attention
        attn = sim.softmax(dim=-1)

        # dropout
        attn = self.dropout(attn)

        # get weighted average of the values of all neighbors
        out = einsum('b h n a, b h n a d -> b h n d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # combine output
        return self.to_out(out)
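
# `batched_index_select` above is a helper from the surrounding codebase that is not shown here.
# A minimal sketch of what such a gather-by-index helper typically looks like (an assumption, not
# the original implementation): for values of shape (b, n, d) and integer indices of shape (b, m),
# it returns the selected rows with shape (b, m, d), matching the usage in the snippet above.
def batched_index_select_sketch(values, indices):
    d = values.shape[-1]
    # expand indices over the feature dimension, then gather along the sequence dimension
    return values.gather(1, indices.unsqueeze(-1).expand(-1, -1, d))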
Example #2
    def __getitem__(self, index):
        image = self.data[index][self.FULL_IMAGE_INDEX]
        burst, flow = self.dynamic_fxn(image)
        noisy = self.noise_fxn(burst + 0.5)

        T = burst.shape[0]
        sburst = repeat(burst[T // 2], 'c h w -> t c h w', t=T)
        snoisy = self.noise_fxn(sburst)

        sample = {
            'burst': burst,
            'noisy': noisy,
            'flow': flow,
            'sburst': sburst,
            'snoisy': snoisy
        }
        return sample
Example #3
def scn_backbone_mask(scn_seq, boolean=True, l_aa=NUM_COORDS_PER_RES):
    """ Gets the boolean mask for N and CA positions. 
        Inputs: 
        * scn_seq: sequence as provided by Sidechainnet package
        * bool: whether to return as array of idxs or boolean values
        Outputs: (N_mask, CA_mask)
    """
    lengths = torch.arange(scn_seq.shape[-1] * l_aa)
    # repeat along the batch dimension if the input sequence is batched
    if len(scn_seq.shape) == 2:
        lengths = repeat(lengths, 'l -> b l', b=scn_seq.shape[0])
    # N is the first atom in every AA. CA is the 2nd.
    N_mask = lengths % l_aa == 0
    CA_mask = lengths % l_aa == 1
    if boolean:
        return N_mask, CA_mask
    return N_mask.nonzero(), CA_mask.nonzero()
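
# Usage sketch (not part of the original example): assumes the function above and its imports
# (torch, einops.repeat) are in scope, and that l_aa=14, i.e. SidechainNet's padded
# 14-atoms-per-residue coordinate layout (an assumption here, not stated in the snippet).
import torch

scn_seq = torch.randint(0, 21, (2, 10))   # hypothetical batch: 2 sequences of 10 residues
N_mask, CA_mask = scn_backbone_mask(scn_seq, boolean=True, l_aa=14)
# boolean masks over the flattened atom slots (10 * 14 = 140 per sequence)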
Example #4
    def forward(self, x):
        # no patch embedding needed
        #x = self.to_patch_embedding(img)
        x = self.lstm(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
Example #5
    def forward(self, img):
        p = self.patch_size

        x = rearrange(img,
                      'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=p,
                      p2=p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.transformer(x)

        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
Example #6
    def forward(self, x, einops_from, einops_to, **einops_dims):
        h = self.num_heads
        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        q *= self.scale

        # splice out the CLS token (first position)
        (cls_q, q_), (cls_k, k_), (cls_v,
                                   v_) = map(lambda t: (t[:, 0:1], t[:, 1:]),
                                             (q, k, v))

        # let CLS token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v)
        # rearrange across time or space
        q_, k_, v_ = map(
            lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **
                                einops_dims), (q_, k_, v_))

        # expand cls token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r),
                           (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        # attention
        out = attn(q_, k_, v_)

        # out_attn_scores = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim=1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        ## to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x
Example #7
def compute_pixel_difference(blocks, block_search_space):
    # -- vectorize search since single patch --
    R, B, T, N, C, PS1, PS2 = blocks.shape
    REF_N = get_ref_block_index(int(np.sqrt(N)))
    # print(cfg.nframes, T, cfg.nblocks, N, block_search_space.shape)
    assert (R == 1) and (B == 1), "single pixel's block and single sample please."
    expanded = blocks[:, :, np.arange(T), block_search_space]
    E = expanded.shape[2]
    R, B, E, T, C, H, W = expanded.shape
    PS = PS1

    ref = repeat(expanded[:, :, :, T // 2], 'r b e c h w -> r b e tile c h w', tile=T - 1)
    neighbors = torch.cat([expanded[:, :, :, :T // 2], expanded[:, :, :, T // 2 + 1:]], dim=3)
    delta = F.mse_loss(ref[..., PS // 2, PS // 2], neighbors[..., PS // 2, PS // 2], reduction='none')
    delta = delta.view(R, B, E, -1)
    delta = torch.mean(delta, dim=3)
    pix_diff = delta[0, 0]
    return pix_diff
Example #8
    def forward(self, img, mask = None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
Example #9
    def forward(self, x, mask=None):
        b, n, device = *x.shape, x.device

        x = self.token_emb(x)
        rel_pos_emb = self.rel_pos_emb(x) if exists(self.rel_pos_emb) else None

        cls_tokens = repeat(self.cls_token, 'd -> b () d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)

        if exists(mask):
            mask = F.pad(mask, (1, 0), value=True)

        pos_emb = self.pos_emb(x)
        x += rearrange(pos_emb, 'n d -> () n d')

        x = self.net(x)

        return self.norm(x[:, 0])
Example #10
def ref_flow_to_pix_torch(_flow, centers):
    # -- copy --
    pix = _flow.clone()

    # -- compute deltas to ref --
    nsamples, nframes, two = pix.shape
    ref_frame = nframes // 2

    # -- change from _spatial_ _object_ motion into _image coords_ _object_ motion
    pix[..., 1] = -pix[..., 1]

    # -- add locations --
    centers = repeat(centers, 's two -> s t two', t=nframes)

    # -- create pix --
    pix += centers

    return pix
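
# Usage sketch (not part of the original example): hypothetical shapes, assuming the function above
# and einops.repeat are in scope. _flow holds per-frame (dx, dy) offsets relative to the reference
# frame; centers holds the per-sample reference pixel that the repeat above broadcasts over frames.
import torch

flow = torch.randn(4, 5, 2)    # 4 samples, 5 frames, (dx, dy) offsets w.r.t. the reference frame
centers = torch.zeros(4, 2)    # per-sample reference pixel locations
pix = ref_flow_to_pix_torch(flow, centers)   # (4, 5, 2) absolute pixel coordinates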
Example #11
    def forward(self, x, pos_embedding=None, return_intermediate=False):
        # TODO before it was possible to choose the input format as an image or a sequence
        # if self.cfg["image_input"]:
        #     x = self.rearange_tensor(x)
        #     if pos_embedding is not None:
        #         pos_embedding = self.rearange_tensor(pos_embedding)

        b, n, _ = x.shape

        tokens_0 = repeat(self.token_0, '() n d -> b n d', b=b)

        if pos_embedding is not None:
            x = x + pos_embedding

        x = torch.cat((tokens_0, x), dim=1)

        x = self.dropout(x)

        x = self.transformer(x)

        if return_intermediate:
            output = {}
            output["intermediate"] = x
        x = x["output"]

        if self.cfg["pool"] == "token_0":
            x = x[:, 0]
        elif self.cfg["pool"] == "mean":
            x = x.mean(dim=1)
        elif self.cfg["pool"] == "none":
            pass
        else:
            raise NotImplementedError

        x = self.to_latent(x)
        x = self.layer_norm(x)

        if return_intermediate:
            output["output"] = x
            return output
        else:
            return x
Example #12
    def forward(self, x: torch.Tensor, mask: torch.LongTensor) -> torch.Tensor:
        """apply image positional encoding to feature

        Parameters
        ----------
        x : torch.Tensor
            [b, h, w, d]
        mask: torch.LongTensor
            [b, h, w]

        Returns
        -------
        torch.Tensor
            [b, h, w, d]
        """
        not_mask = ~mask.bool()  # cast first so the inversion is logical even if mask is a long tensor
        embed_y = not_mask.cumsum(1, dtype=torch.float32)
        embed_x = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            embed_y = embed_y / (embed_y[:, -1:, :] + eps) * self.scale
            embed_x = embed_x / (embed_x[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(
            0, self.half_d_model, 2, dtype=torch.float, device=self.device
        )
        inv_freq = 1.0 / (self.temperature ** (dim_t / self.half_d_model))

        # [b, h, w, d_model // 4]
        pos_x = torch.einsum("b h w, d -> b h w d", embed_x, inv_freq)
        pos_y = torch.einsum("b h w, d -> b h w d", embed_y, inv_freq)

        # [b, h, w, d_model // 2]
        sin_x, cos_x, sin_y, cos_y = map(
            lambda t: repeat(t, "b h w d -> b h w (d n)", n=2),
            (pos_x.sin(), pos_x.cos(), pos_y.sin(), pos_y.cos()),
        )

        # [b, h, w, d_model]
        sin = torch.cat((sin_x, sin_y), dim=-1)
        cos = torch.cat((cos_x, cos_y), dim=-1)

        x = (x * cos) + (rotate_every_two(x) * sin)
        return x
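
# Toy sketch (not part of the original code): the pattern repeat(t, "b h w d -> b h w (d n)", n=2)
# used above duplicates each frequency into adjacent (even, odd) channel pairs, which is the layout
# a rotate_every_two-style pairwise rotation expects.
import torch
from einops import repeat

t = torch.tensor([[[[0.1, 0.2]]]])                   # (b=1, h=1, w=1, d=2)
print(repeat(t, "b h w d -> b h w (d n)", n=2))      # tensor([[[[0.1000, 0.1000, 0.2000, 0.2000]]]])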
Example #13
    def forward(self, img):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.attn_layers(x)
        x = self.norm(x)

        if not exists(self.mlp_head):
            return x

        return self.mlp_head(x[:, 0])
Example #14
def shuffle_aligned_pixels(aligned, R):
    T, B, C, H, W = aligned.shape
    aligned = rearrange(aligned, 'n b c h w -> n b c (h w)')
    shuffled = repeat(aligned[0].clone(), 'b c hw -> r b c hw', r=R)
    hw_grid = torch.arange(H * W)
    for b in range(B):
        for r in range(R):
            indices = torch.randint(T, (H * W, )).long().to(aligned.device)
            for c in range(C):
                shuffled[r, b, c, :] = aligned[indices[r], b, c, hw_grid]
    shuffled = rearrange(shuffled, 'r b c (h w) -> r b c h w', h=H)
    # aligned = rearrange(aligned,'n b c (h w) -> n b c h w',h=H)
    # images = [shuffled,aligned]
    # cropped = crop_center_patch(images,3,128)
    # shuffled,aligned = images[0],images[1]
    # print_tensor_stats("shuffled - aligned",shuffled - aligned)
    # print_tensor_stats("aligned0 - aligned1",aligned[0] - aligned[1])
    # exit()
    return shuffled
Example #15
    def forward(self, img):
        x = self.conv(img)
        x = self.to_patch_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x, c = self.transformer(x)

        if self.with_lca:
            x = self.LCA(c)[:, 0]
        else:
            x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
Example #16
def compute_flownet_of(model, cfg, expanded):
    R, B, E, T, C, H, W = expanded.shape
    device = expanded.device
    model = model.to(device)

    samples = rearrange(expanded, 'r b e t c h w -> t (r c h w) (b e)')
    samples = samples.contiguous()  # speed up?
    t_ref = T // 2

    # -- shape to pairs --
    ref = repeat(samples[T // 2], 'd be -> tile d be', tile=T - 1)
    neighbors = torch.cat([samples[:T // 2], samples[T // 2 + 1:]], dim=0)
    pairs = torch.stack([ref, neighbors], dim=0)  # 2, T-1, D, BE
    pairs = rearrange(pairs, 'two tm1 d be -> (tm1 be) two d')

    # -- compute flow --
    flow = model(pairs)

    flows = model(samples)
Example #17
    def forward(self, img, return_sampled_token_ids=False):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x, token_ids = self.transformer(x)

        logits = self.mlp_head(x[:, 0])

        if return_sampled_token_ids:
            # remove CLS token and decrement by 1 to make -1 the padding
            token_ids = token_ids[:, 1:] - 1
            return logits, token_ids

        return logits
Example #18
    def forward(self, x, einops_from, einops_to, **einops_dims):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        q *= self.scale

        # splice out the classification token (first position)
        (cls_q, q_), (cls_k, k_), (cls_v,
                                   v_) = map(lambda t: (t[:, 0:1], t[:, 1:]),
                                             (q, k, v))

        # let the classification token attend to the keys / values of all patches across time and space
        cls_out = attn(cls_q, k, v)

        # rearrange across time or space
        q_, k_, v_ = map(
            lambda t: rearrange(t, f'{einops_from}->{einops_to}', **einops_dims
                                ), (q_, k_, v_))

        # expand cls token keys and values across time or space
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r),
                           (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        # attention
        out = attn(q_, k_, v_)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim=1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        return self.to_out(out)
Example #19
    def forward(self, src, tgt=None, mems=None, src_mask=None, tgt_mask=None):
        b, n, num_mem, device = *src.shape, self.num_mem, src.device
        mems = default(mems, lambda: self.get_initial_mem(b))

        enc = self.encoder(src, context=mems, src_mask=src_mask)

        if exists(self.decoder) and exists(tgt):
            dec_out = self.decoder(tgt,
                                   context=enc,
                                   src_mask=tgt_mask,
                                   tgt_mask=src_mask,
                                   return_loss=True)
        else:
            dec_out = torch.tensor(0., requires_grad=True, device=device)

        # update memory with attention
        mem_mask = torch.eye(num_mem, num_mem, device=device).bool()
        mem_mask = repeat(mem_mask, 'i j -> b i j', b=b)
        mem_mask = F.pad(mem_mask, (0, n), value=True)

        if exists(src_mask):
            src_mask = rearrange(src_mask, 'b j -> b () j')
            mem_enc_mask = F.pad(src_mask, (num_mem, 0), value=True)
            mem_mask &= mem_enc_mask

        for _ in range(self.num_mem_updates):
            prev_mems = mems
            updated_mems = self.mem_updater(mems,
                                            enc,
                                            mask=mem_mask,
                                            attend_self=True)

            next_mems = self.gru(rearrange(updated_mems, 'b n d -> (b n) d'),
                                 rearrange(prev_mems, 'b n d -> (b n) d'))

            mems = rearrange(next_mems, '(b n) d -> b n d', b=b)
            mems = self.mem_ff(mems)

        if not exists(self.decoder):
            return EncOnlyResults(enc, mems)

        return Results(enc, mems, dec_out)
Example #20
    def __call__(self, x):
        bs, dim = x.shape[0], x.shape[-1]
        latents = self.param(
            "latents", init.normal(), (self.n_latents, dim * self.ff_mult)
        )
        latent = repeat(latents, "n d -> b n d", b=bs)

        x = fourier_encode(x, self.n_fourier_features)
        x = rearrange(x, "b n ... -> b n (...)")

        cross_attn = partial(
            Attention,
            heads=self.cross_n_heads,
            head_features=self.cross_head_features,
            dropout=self.attn_dropout,
        )
        latent_attn = partial(
            Attention,
            heads=self.latent_n_heads,
            head_features=self.latent_head_features,
            dropout=self.attn_dropout,
        )
        ff = partial(FeedForward, mult=self.ff_mult, dropout=self.ff_dropout)
        if self.tie_layer_weights:
            ca = cross_attn(name="cross_attn")
            la = latent_attn(name="latent_attn")
            cf = ff(name="cross_ff")
            lf = ff(name="latent_ff")
            for i in range(self.depth):
                rz = ReZero(name=f"rezero_{i}")
                latent += rz(ca(latent, x))
                latent += rz(cf(latent))
                latent += rz(la(latent))
                latent += rz(lf(latent))
        else:
            for i in range(self.depth):
                rz = ReZero(name=f"rezero_{i}")
                latent += rz(cross_attn(name=f"cross_attn_{i}")(latent, x))
                latent += rz(ff(name=f"cross_ff_{i}")(latent))
                latent += rz(latent_attn(name=f"latent_attn_{i}")(latent))
                latent += rz(ff(name=f"latent_ff_{i}")(latent))
        return latent
Example #21
def seq_flow_to_pix_torch(_flow, centers):

    # -- copy --
    flow = _flow.clone()

    # -- compute deltas to ref --
    nsamples, nframes_minus_1, two = flow.shape
    nframes = nframes_minus_1 + 1
    ref_frame = nframes // 2

    # -- init pix --
    flip, csum = torch.fliplr, torch.cumsum
    zeros = torch.zeros((nsamples, 1, 2), device=flow.device)
    left_idx = slice(None, nframes // 2)
    right_idx = slice(nframes // 2, None)

    # -- change from _spatial_ _object_ motion into _image coords_ _object_ motion
    flow[..., 1] = -flow[..., 1]

    # -- swap dx and dy --
    r"""
    go from

        "x -> x -> x*" to get "sum(->,->), sum(->)"

    1. (1st flip) the first element is _further_ from ref than the left_idx[-1] element
    2. The cumulative sum goes from single arrow to sum of arrows
    3. (2nd flip) back to original order
    4. (negative) the origin of the offsets is the _ref_ location.
    """

    left = -flip(csum(flip(flow[:, left_idx]), 1))
    right = csum(flow[:, right_idx], 1)
    pix = torch.cat([left, zeros, right], dim=1)

    # -- add locations --
    centers = repeat(centers, 's two -> s t two', t=nframes)

    # -- create pix --
    pix += centers

    return pix
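
# Toy numeric sketch (not part of the original code), assuming the function above is in scope:
# with 3 frames and per-step flows d01, d12 (dy = 0, so the spatial -> image sign flip is a no-op),
# the flip/cumsum construction yields pix = [center - d01, center, center + d12].
import torch

flow = torch.tensor([[[1., 0.], [2., 0.]]])   # (1 sample, nframes - 1 = 2 steps, (dx, dy))
centers = torch.tensor([[10., 10.]])          # (1 sample, 2)
print(seq_flow_to_pix_torch(flow, centers))
# tensor([[[ 9., 10.], [10., 10.], [12., 10.]]])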
Example #22
    def forward(self, feats, coors, adj_mat=None, edges=None, mask=None):
        b = feats.shape[0]

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        if exists(edges) and exists(self.edge_emb):
            edges = self.edge_emb(edges)

        # create N-degrees adjacent matrix from 1st degree connections
        if exists(self.num_adj_degrees):
            assert exists(
                adj_mat
            ), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

            adj_indices = adj_mat.clone().long()

            for ind in range(self.num_adj_degrees - 1):
                degree = ind + 2

                next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
                next_degree_mask = (next_degree_adj_mat.float() -
                                    adj_mat.float()).bool()
                adj_indices.masked_fill_(next_degree_mask, degree)
                adj_mat = next_degree_adj_mat.clone()

            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat(
                    (edges, adj_emb), dim=-1) if exists(edges) else adj_emb

        for layer in self.layers:
            feats, coors = layer(feats,
                                 coors,
                                 adj_mat=adj_mat,
                                 edges=edges,
                                 mask=mask)

        return feats, coors
Example #23
    def forward(self,
                img,
                mask=None,
                distill_token=None,
                return_attention=False):
        assert distill_token is not None
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        if len(distill_token.shape) == 2:
            distill_token = distill_token.unsqueeze(1)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x = torch.cat((x, distill_token), dim=1)
        x += self.pos_embedding[:, :(n + 2)]
        x = self.dropout(x)

        x = self.transformer(x, mask, return_attention)
        if return_attention:
            return x[0], x[1]
        return x
Example #24
    def forward(self, img):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        batch_size, seq_len, _ = x.shape
        b, n = batch_size, seq_len

        class_tokens = repeat(self.class_token, '() n d -> b n d', b = b)

        out = torch.cat((class_tokens, x), dim=1)
        out += self.pos_embedding[:, :(n + 1)]
        out = self.dropout(out)

        out = self.transformer(out)

        out = out[:, 0]
        out = self.mlp(out)

        return out
Example #25
    def forward(self, video):
        b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

        n = (h // p) * (w // p)

        video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.to_patch_embedding(video)

        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x =  torch.cat((cls_token, tokens), dim = 1)
        x += self.pos_emb(torch.arange(x.shape[1], device = device))

        for (time_attn, spatial_attn, ff) in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f) + x
            x = ff(x) + x

        cls_token = x[:, 0]
        return self.to_out(cls_token)
Example #26
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, \
    normalize_data = True, device=None):
    """Generalized features = kernel_fn(W^T * data) or kernel_fn(data).

    Args:
        data: (:obj:`tensor`) Input query or key for which features are computed.
        projection_matrix: (:obj:`tensor`) Random matrix W used to compute features.
        kernel_fn: (:obj:`Callable`, `optional`, defaults to :class:`nn.ReLU()`) Basis function to
            produce features.
        kernel_epsilon: (:obj:`float`, `optional`, defaults to 1e-3) Bias added to produced
            features for numerical stability(?).
        normalize_data: (:obj:`bool`, `optional`, defaults to :obj:`True`) Whether or not to
            normalize data.

    Returns:
        data_prime: (:obj:`tensor`) Random features.

    Shape for inputs:
        - data: (batch, heads, seq_length, hidden_size)
        - projection_matrix: (nb_features, w_size)

    Shape for outputs:
        - data_prime: (batch, heads, seq_length, hidden_size) if `no_projection=True`.
            (batch, heads, seq_length, nb_features) if `no_projection=False`.
    """
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1]**-0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data),
                             projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon

    return data_prime.type_as(data)
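
# Usage sketch (not part of the original example): hypothetical Performer-style shapes, assuming the
# function above and its imports (torch, torch.nn, einops.repeat) are in scope.
import torch

q = torch.randn(2, 8, 128, 64)    # (batch, heads, seq_length, head_dim)
proj = torch.randn(256, 64)       # (nb_features, head_dim) random projection matrix
q_prime = generalized_kernel(q, projection_matrix=proj)   # -> (2, 8, 128, 256) random features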
Example #27
def homo_warp(src_feat, proj_mat, depth_values):
    """
    src_feat: (B, C, H, W)
    proj_mat: (B, 3, 4) equal to "src_proj @ ref_proj_inv"
    depth_values: (B, D, H, W)
    out: (B, C, D, H, W)
    """
    B, C, H, W = src_feat.shape
    D = depth_values.shape[1]
    device = src_feat.device

    R = proj_mat[:, :, :3] # (B, 3, 3)
    T = proj_mat[:, :, 3:] # (B, 3, 1)
    # create grid from the ref frame
    ref_grid = create_meshgrid(H, W, normalized_coordinates=False,
                               device=device) # (1, H, W, 2)
    ref_grid = rearrange(ref_grid, '1 h w c -> 1 c (h w)') # (1, 2, H*W)
    ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
    ref_grid = torch.cat((ref_grid, torch.ones_like(ref_grid[:,:1])), 1) # (B, 3, H*W)
    ref_grid_d = repeat(ref_grid, 'b c x -> b c (d x)', d=D) # (B, 3, D*H*W)
    src_grid_d = R @ ref_grid_d + T/rearrange(depth_values, 'b d h w -> b 1 (d h w)')
    del ref_grid_d, ref_grid, proj_mat, R, T, depth_values # release (GPU) memory
    
    # project negative depth pixels to somewhere outside the image
    negative_depth_mask = src_grid_d[:, 2:] <= 1e-7
    src_grid_d[:, 0:1][negative_depth_mask] = W
    src_grid_d[:, 1:2][negative_depth_mask] = H
    src_grid_d[:, 2:3][negative_depth_mask] = 1

    src_grid = src_grid_d[:, :2] / src_grid_d[:, 2:] # divide by depth (B, 2, D*H*W)
    del src_grid_d
    src_grid[:, 0] = src_grid[:, 0]/((W-1)/2) - 1 # scale to -1~1
    src_grid[:, 1] = src_grid[:, 1]/((H-1)/2) - 1 # scale to -1~1
    src_grid = rearrange(src_grid, 'b c (d h w) -> b d (h w) c', d=D, h=H, w=W)

    warped_src_feat = F.grid_sample(src_feat, src_grid,
                                    mode='bilinear', padding_mode='zeros',
                                    align_corners=True) # (B, C, D, H*W)
    warped_src_feat = rearrange(warped_src_feat, 'b c d (h w) -> b c d h w', h=H, w=W)

    return warped_src_feat
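
# Usage sketch (not part of the original example): hypothetical MVS-style shapes, assuming the
# function above, kornia's create_meshgrid, and einops are in scope.
import torch

src_feat = torch.randn(1, 32, 64, 80)                  # (B, C, H, W) source feature map
proj_mat = torch.eye(3, 4).unsqueeze(0)                # (B, 3, 4), i.e. src_proj @ ref_proj_inv
depth_values = torch.linspace(2.0, 4.0, 48)[None, :, None, None].repeat(1, 1, 64, 80)  # (B, D, H, W)
warped = homo_warp(src_feat, proj_mat, depth_values)   # -> (B, C, D, H, W) = (1, 32, 48, 64, 80)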
Example #28
    def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q = q * self.scale

        # splice out the classification token (first position)
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

        # let classification token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v, mask = cls_mask)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        # add rotary embeddings, if applicable
        if exists(rot_emb):
            q_, k_ = apply_rot_emb(q_, k_, rot_emb)

        # expand cls token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        # attention
        out = attn(q_, k_, v_, mask = mask)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim = 1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        # combine heads out
        return self.to_out(out)
Example #29
def reconstruct_ref_fullpersp(normalized_2d, coords3d_rel, validity_mask):
    """Reconstructs the reference point location.

    Args:
      normalized_2d: normalized image coordinates of the joints
         (without intrinsics applied), shape [batch_size, n_points, 2]
      coords3d_rel: 3D camera coordinate offsets relative to the unknown reference
         point which we want to reconstruct, shape [batch_size, n_points, 3]
      validity_mask: boolean mask of shape [batch_size, n_points] containing True
         where the point is reliable and should be used in the reconstruction

    Returns:
      The 3D reference point in camera coordinates, shape [batch_size, 3]
    """
    def rms_normalize(x):
        scale = tf.sqrt(tf.reduce_mean(tf.square(x)))
        normalized = x / scale
        return scale, normalized

    n_batch = tf.shape(normalized_2d)[0]
    n_points = normalized_2d.shape.as_list()[1]
    eyes = tf.tile(tf.expand_dims(tf.eye(2, 2), 0), [n_batch, n_points, 1])
    scale2d, reshaped2d = rms_normalize(
        tf.reshape(normalized_2d, [-1, n_points * 2, 1]))
    A = tf.concat([eyes, -reshaped2d], axis=2)

    rel_backproj = normalized_2d * coords3d_rel[:, :,
                                                2:] - coords3d_rel[:, :, :2]
    scale_rel_backproj, b = rms_normalize(
        tf.reshape(rel_backproj, [-1, n_points * 2, 1]))

    weights = tf.cast(validity_mask, tf.float32) + np.float32(1e-4)
    weights = einops.repeat(weights, 'b j -> b (j c) 1', c=2)

    ref = tf.linalg.lstsq(A * weights,
                          b * weights,
                          l2_regularizer=1e-2,
                          fast=True)
    ref = tf.concat([ref[:, :2], ref[:, 2:] / scale2d],
                    axis=1) * scale_rel_backproj
    return tf.squeeze(ref, axis=-1)
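
# Usage sketch (not part of the original example): hypothetical shapes, assuming the function above
# and its imports (tensorflow as tf, numpy as np, einops) are in scope.
import tensorflow as tf

normalized_2d = tf.random.normal([4, 17, 2])    # 4 poses, 17 joints, intrinsics-free image coords
coords3d_rel = tf.random.normal([4, 17, 3])     # root-relative 3D offsets per joint
validity_mask = tf.fill([4, 17], True)          # every joint treated as reliable here
ref_point = reconstruct_ref_fullpersp(normalized_2d, coords3d_rel, validity_mask)   # -> (4, 3)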
Example #30
def generalized_kernel(data,
                       *,
                       projection_matrix,
                       kernel_fn=nn.ReLU(),
                       kernel_epsilon=0.001,
                       normalize_data=True,
                       device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1]**-0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data),
                             projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime