def forward(self, x, adj_kv_indices, mask):
    b, n, d, h = *x.shape, self.heads
    flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h=h)

    # derive query, key, value
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

    # gather keys and values according to adjacency matrix
    k, v = map(lambda t: rearrange(t, 'b h n d -> (b h) n d'), (k, v))
    k = batched_index_select(k, flat_indices)
    v = batched_index_select(v, flat_indices)
    k, v = map(lambda t: rearrange(t, '(b h) (n a) d -> b h n a d', h=h, n=n), (k, v))

    # add null key / value, so a node can attend to nothing
    # (appears in the GNN literature under other names)
    nk, nv = map(
        lambda t: rearrange(t, 'h d -> () h () () d').expand(b, -1, n, 1, -1),
        (self.null_k, self.null_v))
    k = torch.cat((nk, k), dim=-2)
    v = torch.cat((nv, v), dim=-2)
    mask = F.pad(mask, (1, 0), value=1)

    # similarity of each node to its neighbors
    sim = einsum('b h n d, b h n a d -> b h n a', q, k) * self.scale

    # mask out neighbors that are just padding
    mask_value = -torch.finfo(sim.dtype).max
    mask = rearrange(mask.bool(), 'b n a -> b () n a')
    sim.masked_fill_(~mask, mask_value)

    # attention
    attn = sim.softmax(dim=-1)

    # dropout
    attn = self.dropout(attn)

    # get weighted average of the values of all neighbors
    out = einsum('b h n a, b h n a d -> b h n d', attn, v)
    out = rearrange(out, 'b h n d -> b n (h d)')

    # combine output
    return self.to_out(out)
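# `batched_index_select` is not defined in the snippet above; a minimal sketch
# consistent with the shapes it is called with (values (b, n, d), indices (b, m)):
def batched_index_select(values, indices):
    # gather rows of `values` along dim 1, one index set per batch element
    last_dim = values.shape[-1]
    return values.gather(1, indices[..., None].expand(*indices.shape, last_dim))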
def __getitem__(self, index):
    image = self.data[index][self.FULL_IMAGE_INDEX]
    burst, flow = self.dynamic_fxn(image)
    noisy = self.noise_fxn(burst + 0.5)
    T = burst.shape[0]
    sburst = repeat(burst[T // 2], 'c h w -> t c h w', t=T)
    snoisy = self.noise_fxn(sburst)
    sample = {
        'burst': burst,
        'noisy': noisy,
        'flow': flow,
        'sburst': sburst,
        'snoisy': snoisy
    }
    return sample
def scn_backbone_mask(scn_seq, boolean=True, l_aa=NUM_COORDS_PER_RES):
    """ Gets the boolean mask for N and CA positions.
        Inputs:
        * scn_seq: sequence as provided by the Sidechainnet package
        * boolean: whether to return boolean values or arrays of indices
        Outputs: (N_mask, CA_mask)
    """
    lengths = torch.arange(scn_seq.shape[-1] * l_aa)
    # repeat if needed:
    if len(lengths.shape) == 2:
        lengths = repeat(lengths, 'l -> b l', b=scn_seq.shape[0])
    # N is the first atom in every AA. CA is the 2nd.
    N_mask = lengths % l_aa == 0
    CA_mask = lengths % l_aa == 1
    if boolean:
        return N_mask, CA_mask
    return N_mask.nonzero(), CA_mask.nonzero()
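# Hypothetical usage check for scn_backbone_mask, assuming 14 atoms per residue
# (sidechainnet's convention) and a batch of two 3-residue sequences:
import torch

seq = torch.zeros(2, 3, dtype=torch.long)
N_mask, CA_mask = scn_backbone_mask(seq, boolean=True, l_aa=14)
# the masks cover 3 * 14 = 42 atom slots; N atoms sit at indices 0, 14, 28 and CA atoms at 1, 15, 29
assert N_mask.sum() == 3 and CA_mask.sum() == 3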
def forward(self, x):
    # no patch embedding needed
    # x = self.to_patch_embedding(img)
    # assuming self.lstm is an nn.LSTM, it returns (output, (h_n, c_n)); keep only the output
    x, _ = self.lstm(x)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)
    x = self.transformer(x)
    x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
    x = self.to_latent(x)
    return self.mlp_head(x)
def forward(self, img):
    p = self.patch_size
    x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
    x = self.patch_to_embedding(x)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.transformer(x)
    x = self.to_cls_token(x[:, 0])
    return self.mlp_head(x)
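# Quick shape check of the patch-extraction rearrange above; with a hypothetical
# 32x32 RGB image and patch size 16, each image becomes 4 patches of 16*16*3 = 768 values:
import torch
from einops import rearrange

img = torch.randn(1, 3, 32, 32)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
assert patches.shape == (1, 4, 768)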
def forward(self, x, einops_from, einops_to, **einops_dims):
    h = self.num_heads

    # project x to q, k, v values
    q, k, v = self.qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    q *= self.scale

    # splice out the CLS token (the first token)
    (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))

    # let CLS token attend to key / values of all patches across time and space
    cls_out = attn(cls_q, k, v)

    # rearrange across time or space
    q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

    # expand cls token keys and values across time or space and concat
    r = q_.shape[0] // cls_k.shape[0]
    cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))

    k_ = torch.cat((cls_k, k_), dim=1)
    v_ = torch.cat((cls_v, v_), dim=1)

    # attention
    out = attn(q_, k_, v_)

    # merge back time or space
    out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

    # concat back the cls token
    out = torch.cat((cls_out, out), dim=1)

    # merge back the heads
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

    # to out
    x = self.proj(out)
    x = self.proj_drop(x)
    return x
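# What the einops_from / einops_to arguments above do: regroup the flattened
# space-time tokens so attention runs within a frame (space) or within a single
# patch's trajectory (time). A tiny standalone check with hypothetical sizes f=2, n=3:
import torch
from einops import rearrange

t = torch.randn(1, 2 * 3, 4)                          # (b, (f n), d)
space = rearrange(t, 'b (f n) d -> (b f) n d', f=2)   # attend over patches within each frame
time = rearrange(t, 'b (f n) d -> (b n) f d', n=3)    # attend over frames for each patch
assert space.shape == (2, 3, 4) and time.shape == (3, 2, 4)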
def compute_pixel_difference(blocks, block_search_space):
    # -- vectorize search since single patch --
    R, B, T, N, C, PS1, PS2 = blocks.shape
    REF_N = get_ref_block_index(int(np.sqrt(N)))
    assert (R == 1) and (B == 1), "single pixel's block and single sample please."
    expanded = blocks[:, :, np.arange(T), block_search_space]
    E = expanded.shape[2]
    R, B, E, T, C, H, W = expanded.shape
    PS = PS1
    ref = repeat(expanded[:, :, :, T // 2], 'r b e c h w -> r b e tile c h w', tile=T - 1)
    neighbors = torch.cat([expanded[:, :, :, :T // 2], expanded[:, :, :, T // 2 + 1:]], dim=3)
    delta = F.mse_loss(ref[..., PS // 2, PS // 2], neighbors[..., PS // 2, PS // 2], reduction='none')
    delta = delta.view(R, B, E, -1)
    delta = torch.mean(delta, dim=3)
    pix_diff = delta[0, 0]
    return pix_diff
def forward(self, img, mask=None):
    p = self.patch_size
    x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
    x = self.patch_to_embedding(x)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)
    x = self.transformer(x, mask)
    x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
    x = self.to_latent(x)
    return self.mlp_head(x)
def forward(self, x, mask=None):
    b, n, device = *x.shape, x.device
    x = self.token_emb(x)
    rel_pos_emb = self.rel_pos_emb(x) if exists(self.rel_pos_emb) else None
    cls_tokens = repeat(self.cls_token, 'd -> b () d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    if exists(mask):
        mask = F.pad(mask, (1, 0), value=True)
    pos_emb = self.pos_emb(x)
    x += rearrange(pos_emb, 'n d -> () n d')
    x = self.net(x)
    return self.norm(x[:, 0])
def ref_flow_to_pix_torch(_flow, centers):
    # -- copy --
    pix = _flow.clone()

    # -- compute deltas to ref --
    nsamples, nframes, two = pix.shape
    ref_frame = nframes // 2

    # -- change from _spatial_ object motion into _image coords_ object motion --
    pix[..., 1] = -pix[..., 1]

    # -- add locations --
    centers = repeat(centers, 's two -> s t two', t=nframes)

    # -- create pix --
    pix += centers
    return pix
def forward(self, x, pos_embedding=None, return_intermediate=False):
    # TODO: it used to be possible to choose the input format as an image or a sequence
    # if self.cfg["image_input"]:
    #     x = self.rearange_tensor(x)
    #     if pos_embedding is not None:
    #         pos_embedding = self.rearange_tensor(pos_embedding)
    b, n, _ = x.shape
    tokens_0 = repeat(self.token_0, '() n d -> b n d', b=b)
    if pos_embedding is not None:
        x = x + pos_embedding
    x = torch.cat((tokens_0, x), dim=1)
    x = self.dropout(x)
    x = self.transformer(x)

    if return_intermediate:
        output = {}
        output["intermediate"] = x
        x = x["output"]

    if self.cfg["pool"] == "token_0":
        x = x[:, 0]
    elif self.cfg["pool"] == "mean":
        x = x.mean(dim=1)
    elif self.cfg["pool"] == "none":
        pass
    else:
        raise NotImplementedError

    x = self.to_latent(x)
    x = self.layer_norm(x)

    if return_intermediate:
        output["output"] = x
        return output
    return x
def forward(self, x: torch.Tensor, mask: torch.LongTensor) -> torch.Tensor:
    """apply image positional encoding to feature

    Parameters
    ----------
    x : torch.Tensor
        [b, h, w, d]
    mask : torch.LongTensor
        [b, h, w]

    Returns
    -------
    torch.Tensor
        [b, h, w, d]
    """
    not_mask = ~mask
    embed_y = not_mask.cumsum(1, dtype=torch.float32)
    embed_x = not_mask.cumsum(2, dtype=torch.float32)
    if self.normalize:
        eps = 1e-6
        embed_y = embed_y / (embed_y[:, -1:, :] + eps) * self.scale
        embed_x = embed_x / (embed_x[:, :, -1:] + eps) * self.scale

    dim_t = torch.arange(0, self.half_d_model, 2, dtype=torch.float, device=self.device)
    inv_freq = 1.0 / (self.temperature ** (dim_t / self.half_d_model))

    # [b, h, w, d_model // 4]
    pos_x = torch.einsum("b h w, d -> b h w d", embed_x, inv_freq)
    pos_y = torch.einsum("b h w, d -> b h w d", embed_y, inv_freq)

    # [b, h, w, d_model // 2]
    sin_x, cos_x, sin_y, cos_y = map(
        lambda t: repeat(t, "b h w d -> b h w (d n)", n=2),
        (pos_x.sin(), pos_x.cos(), pos_y.sin(), pos_y.cos()),
    )

    # [b, h, w, d_model]
    sin = torch.cat((sin_x, sin_y), dim=-1)
    cos = torch.cat((cos_x, cos_y), dim=-1)

    x = (x * cos) + (rotate_every_two(x) * sin)
    return x
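# Sanity check of the interleaving repeat used above: '(d n)' with n=2 duplicates
# each frequency into adjacent slots so it lines up with rotate_every_two's pairing:
import torch
from einops import repeat

t = torch.tensor([[[[1., 2.]]]])                      # (b=1, h=1, w=1, d=2)
out = repeat(t, 'b h w d -> b h w (d n)', n=2)
assert out.tolist() == [[[[1., 1., 2., 2.]]]]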
def forward(self, img):
    p = self.patch_size
    x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
    x = self.patch_to_embedding(x)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)
    x = self.attn_layers(x)
    x = self.norm(x)
    if not exists(self.mlp_head):
        return x
    return self.mlp_head(x[:, 0])
def shuffle_aligned_pixels(aligned, R):
    T, B, C, H, W = aligned.shape
    aligned = rearrange(aligned, 'n b c h w -> n b c (h w)')
    shuffled = repeat(aligned[0].clone(), 'b c hw -> r b c hw', r=R)
    hw_grid = torch.arange(H * W)
    for b in range(B):
        for r in range(R):
            indices = torch.randint(T, (H * W, )).long().to(aligned.device)
            for c in range(C):
                shuffled[r, b, c, :] = aligned[indices[r], b, c, hw_grid]
    shuffled = rearrange(shuffled, 'r b c (h w) -> r b c h w', h=H)
    return shuffled
def forward(self, img):
    x = self.conv(img)
    x = self.to_patch_embedding(x)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)
    x, c = self.transformer(x)
    if self.with_lca:
        x = self.LCA(c)[:, 0]
    else:
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
    x = self.to_latent(x)
    return self.mlp_head(x)
def compute_flownet_of(model, cfg, expanded):
    R, B, E, T, C, H, W = expanded.shape
    device = expanded.device
    model = model.to(device)
    samples = rearrange(expanded, 'r b e t c h w -> t (r c h w) (b e)')
    samples = samples.contiguous()  # speed up?
    t_ref = T // 2

    # -- shape to pairs --
    ref = repeat(samples[t_ref], 'd be -> tile d be', tile=T - 1)
    neighbors = torch.cat([samples[:t_ref], samples[t_ref + 1:]], dim=0)
    pairs = torch.stack([ref, neighbors], dim=0)  # 2, T-1, D, BE
    pairs = rearrange(pairs, 'two tm1 d be -> (tm1 be) two d')

    # -- compute flow --
    flow = model(pairs)
    return flow
def forward(self, img, return_sampled_token_ids=False):
    x = self.to_patch_embedding(img)
    b, n, _ = x.shape
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
    x = self.dropout(x)
    x, token_ids = self.transformer(x)
    logits = self.mlp_head(x[:, 0])
    if return_sampled_token_ids:
        # remove CLS token and decrement by 1 to make -1 the padding
        token_ids = token_ids[:, 1:] - 1
        return logits, token_ids
    return logits
def forward(self, x, einops_from, einops_to, **einops_dims):
    h = self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    q *= self.scale

    # splice out the classification token (the first token)
    (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))

    # let classification token attend to key / values of all patches across time and space
    cls_out = attn(cls_q, k, v)

    # rearrange across time or space
    q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

    # expand cls token keys and values across time or space
    r = q_.shape[0] // cls_k.shape[0]
    cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))

    k_ = torch.cat((cls_k, k_), dim=1)
    v_ = torch.cat((cls_v, v_), dim=1)

    # attention
    out = attn(q_, k_, v_)

    # merge back time or space
    out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

    # concat back the cls token
    out = torch.cat((cls_out, out), dim=1)

    # merge back the heads
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
def forward(self, src, tgt=None, mems=None, src_mask=None, tgt_mask=None):
    b, n, num_mem, device = *src.shape, self.num_mem, src.device
    mems = default(mems, lambda: self.get_initial_mem(b))

    enc = self.encoder(src, context=mems, src_mask=src_mask)

    if exists(self.decoder) and exists(tgt):
        dec_out = self.decoder(tgt, context=enc, src_mask=tgt_mask, tgt_mask=src_mask, return_loss=True)
    else:
        dec_out = torch.tensor(0., requires_grad=True, device=device)

    # update memory with attention
    mem_mask = torch.eye(num_mem, num_mem, device=device).bool()
    mem_mask = repeat(mem_mask, 'i j -> b i j', b=b)
    mem_mask = F.pad(mem_mask, (0, n), value=True)

    if exists(src_mask):
        src_mask = rearrange(src_mask, 'b j -> b () j')
        mem_enc_mask = F.pad(src_mask, (num_mem, 0), value=True)
        mem_mask &= mem_enc_mask

    for _ in range(self.num_mem_updates):
        prev_mems = mems
        updated_mems = self.mem_updater(mems, enc, mask=mem_mask, attend_self=True)

        next_mems = self.gru(rearrange(updated_mems, 'b n d -> (b n) d'), rearrange(prev_mems, 'b n d -> (b n) d'))
        mems = rearrange(next_mems, '(b n) d -> b n d', b=b)
        mems = self.mem_ff(mems)

    if not exists(self.decoder):
        return EncOnlyResults(enc, mems)

    return Results(enc, mems, dec_out)
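# Shape check of the memory-update mask built above, with hypothetical sizes
# num_mem=2, n=3, b=1: each memory slot attends to itself and to all n source tokens:
import torch
import torch.nn.functional as F
from einops import repeat

num_mem, n, b = 2, 3, 1
mem_mask = repeat(torch.eye(num_mem).bool(), 'i j -> b i j', b=b)
mem_mask = F.pad(mem_mask, (0, n), value=True)        # append the n source-token positions
assert mem_mask.shape == (1, num_mem, num_mem + n)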
def __call__(self, x):
    bs, dim = x.shape[0], x.shape[-1]
    latents = self.param("latents", init.normal(), (self.n_latents, dim * self.ff_mult))
    latent = repeat(latents, "n d -> b n d", b=bs)

    x = fourier_encode(x, self.n_fourier_features)
    x = rearrange(x, "b n ... -> b n (...)")

    cross_attn = partial(
        Attention,
        heads=self.cross_n_heads,
        head_features=self.cross_head_features,
        dropout=self.attn_dropout,
    )
    latent_attn = partial(
        Attention,
        heads=self.latent_n_heads,
        head_features=self.latent_head_features,
        dropout=self.attn_dropout,
    )
    ff = partial(FeedForward, mult=self.ff_mult, dropout=self.ff_dropout)

    if self.tie_layer_weights:
        ca = cross_attn(name="cross_attn")
        la = latent_attn(name="latent_attn")
        cf = ff(name="cross_ff")
        lf = ff(name="latent_ff")
        for i in range(self.depth):
            rz = ReZero(name=f"rezero_{i}")
            latent += rz(ca(latent, x))
            latent += rz(cf(latent))
            latent += rz(la(latent))
            latent += rz(lf(latent))
    else:
        for i in range(self.depth):
            rz = ReZero(name=f"rezero_{i}")
            latent += rz(cross_attn(name=f"cross_attn_{i}")(latent, x))
            latent += rz(ff(name=f"cross_ff_{i}")(latent))
            latent += rz(latent_attn(name=f"latent_attn_{i}")(latent))
            latent += rz(ff(name=f"latent_ff_{i}")(latent))

    return latent
def seq_flow_to_pix_torch(_flow, centers):
    # -- copy --
    flow = _flow.clone()

    # -- compute deltas to ref --
    nsamples, nframes_minus_1, two = flow.shape
    nframes = nframes_minus_1 + 1
    ref_frame = nframes // 2

    # -- init pix --
    flip, csum = torch.fliplr, torch.cumsum
    zeros = torch.zeros((nsamples, 1, 2), device=flow.device)
    left_idx = slice(None, nframes // 2)
    right_idx = slice(nframes // 2, None)

    # -- change from _spatial_ object motion into _image coords_ object motion --
    flow[..., 1] = -flow[..., 1]

    r"""
    go from "x -> x -> x*" to get "sum(->, ->), sum(->)"
    1. (1st flip) the first element is _further_ from ref than the left_idx[-1] element
    2. the cumulative sum goes from a single arrow to a sum of arrows
    3. (2nd flip) back to original order
    4. (negative) the origin starts from the _ref_ location
    """
    left = -flip(csum(flip(flow[:, left_idx]), 1))
    right = csum(flow[:, right_idx], 1)
    pix = torch.cat([left, zeros, right], dim=1)

    # -- add locations --
    centers = repeat(centers, 's two -> s t two', t=nframes)

    # -- create pix --
    pix += centers
    return pix
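# Numeric sanity check of the flip / cumsum accumulation above, with a hypothetical
# single sample over 3 frames (so 2 flows): the object moves +1px then +2px in x,
# so frame 0 lands one step before the reference and frame 2 two steps after it:
import torch

flow = torch.tensor([[[1., 0.], [2., 0.]]])           # (nsamples=1, nframes - 1=2, 2)
centers = torch.tensor([[10., 10.]])                  # (nsamples=1, 2)
pix = seq_flow_to_pix_torch(flow, centers)
assert pix.tolist() == [[[9., 10.], [10., 10.], [12., 10.]]]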
def forward(self, feats, coors, adj_mat=None, edges=None, mask=None):
    b = feats.shape[0]

    if exists(self.token_emb):
        feats = self.token_emb(feats)

    if exists(edges) and exists(self.edge_emb):
        edges = self.edge_emb(edges)

    # create N-degree adjacency matrix from 1st-degree connections
    if exists(self.num_adj_degrees):
        assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

        if len(adj_mat.shape) == 2:
            adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b=b)

        adj_indices = adj_mat.clone().long()

        for ind in range(self.num_adj_degrees - 1):
            degree = ind + 2

            next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
            next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
            adj_indices.masked_fill_(next_degree_mask, degree)
            adj_mat = next_degree_adj_mat.clone()

        if exists(self.adj_emb):
            adj_emb = self.adj_emb(adj_indices)
            edges = torch.cat((edges, adj_emb), dim=-1) if exists(edges) else adj_emb

    for layer in self.layers:
        feats, coors = layer(feats, coors, adj_mat=adj_mat, edges=edges, mask=mask)

    return feats, coors
def forward(self, img, mask=None, distill_token=None, return_attention=False):
    assert distill_token is not None
    x = self.to_patch_embedding(img)
    b, n, _ = x.shape

    if len(distill_token.shape) == 2:
        distill_token = distill_token.unsqueeze(1)

    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
    x = torch.cat((cls_tokens, x), dim=1)
    x = torch.cat((x, distill_token), dim=1)
    x += self.pos_embedding[:, :(n + 2)]
    x = self.dropout(x)
    x = self.transformer(x, mask, return_attention)
    if return_attention:
        return x[0], x[1]
    return x
def forward(self, img):
    p = self.patch_size
    x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
    x = self.patch_to_embedding(x)
    b, n, _ = x.shape
    class_tokens = repeat(self.class_token, '() n d -> b n d', b=b)
    out = torch.cat((class_tokens, x), dim=1)
    out += self.pos_embedding[:, :(n + 1)]
    out = self.dropout(out)
    out = self.transformer(out)
    out = out[:, 0]
    out = self.mlp(out)
    return out
def forward(self, video):
    b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
    assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

    n = (h // p) * (w // p)

    video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1=p, p2=p)
    tokens = self.to_patch_embedding(video)

    cls_token = repeat(self.cls_token, 'n d -> b n d', b=b)
    x = torch.cat((cls_token, tokens), dim=1)
    x += self.pos_emb(torch.arange(x.shape[1], device=device))

    for (time_attn, spatial_attn, ff) in self.layers:
        x = time_attn(x, 'b (f n) d', '(b n) f d', n=n) + x
        x = spatial_attn(x, 'b (f n) d', '(b f) n d', f=f) + x
        x = ff(x) + x

    cls_token = x[:, 0]
    return self.to_out(cls_token)
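# Shape check of the video patching above with hypothetical sizes: 2 frames of
# 32x32 RGB with patch size 16 give 2 * 4 = 8 tokens of 16*16*3 = 768 values each:
import torch
from einops import rearrange

video = torch.randn(1, 2, 3, 32, 32)                  # (b, f, c, h, w)
tokens = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1=16, p2=16)
assert tokens.shape == (1, 8, 768)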
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, \ normalize_data = True, device=None): """Generalized features = kernel_fn(W^T * data) or kernel_fn(data). Args: data: (:obj:`tensor`) Input query or key for which features are computed. projection_matrix: (:obj:`tensor`) Random matrix W used to compute features. kernel_fn: (:obj:`Callable`, 'optional`, defaults to :class:`nn.ReLU()`) Basis function to produce features. kernel_epsilon: (:obj:`float`, `optional`, defaults to 1e-3) Bias added to produced features for numerical stability(?). normalize_data: (:obj:`bool`, `optional`, defaults to :obj:`True`) Whether or not to normalize data. Returns: data_prime: (:obj:`tensor`) Random features. Shape for inputs: - data: (batch, heads, seq_length, hidden_size) - projection_matrix: (nb_features, w_size) Shape for outputs: - data_prime: (batch, heads, seq_length, hidden_size) if `no_projection=True`. (batch, heads, seq_length, nb_features) if `no_projection=False`. """ b, h, *_ = data.shape data_normalizer = (data.shape[-1]**-0.25) if normalize_data else 1. if projection_matrix is None: return kernel_fn(data_normalizer * data) + kernel_epsilon projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h) projection = projection.type_as(data) data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) data_prime = kernel_fn(data_dash) + kernel_epsilon return data_prime.type_as(data)
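# Hypothetical call: project 64-dim attention heads onto 16 random features
# (the `device` argument is unused by the function itself):
import torch

q = torch.randn(2, 4, 128, 64)                        # (batch, heads, seq_length, hidden_size)
proj = torch.randn(16, 64)                            # (nb_features, hidden_size)
feats = generalized_kernel(q, projection_matrix=proj)
assert feats.shape == (2, 4, 128, 16)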
def homo_warp(src_feat, proj_mat, depth_values):
    """
    src_feat: (B, C, H, W)
    proj_mat: (B, 3, 4) equal to "src_proj @ ref_proj_inv"
    depth_values: (B, D, H, W)
    out: (B, C, D, H, W)
    """
    B, C, H, W = src_feat.shape
    D = depth_values.shape[1]
    device = src_feat.device

    R = proj_mat[:, :, :3]  # (B, 3, 3)
    T = proj_mat[:, :, 3:]  # (B, 3, 1)

    # create grid from the ref frame
    ref_grid = create_meshgrid(H, W, normalized_coordinates=False, device=device)  # (1, H, W, 2)
    ref_grid = rearrange(ref_grid, '1 h w c -> 1 c (h w)')  # (1, 2, H*W)
    ref_grid = ref_grid.expand(B, -1, -1)  # (B, 2, H*W)
    ref_grid = torch.cat((ref_grid, torch.ones_like(ref_grid[:, :1])), 1)  # (B, 3, H*W)
    ref_grid_d = repeat(ref_grid, 'b c x -> b c (d x)', d=D)  # (B, 3, D*H*W)
    src_grid_d = R @ ref_grid_d + T / rearrange(depth_values, 'b d h w -> b 1 (d h w)')
    del ref_grid_d, ref_grid, proj_mat, R, T, depth_values  # release (GPU) memory

    # project negative-depth pixels to somewhere outside the image
    negative_depth_mask = src_grid_d[:, 2:] <= 1e-7
    src_grid_d[:, 0:1][negative_depth_mask] = W
    src_grid_d[:, 1:2][negative_depth_mask] = H
    src_grid_d[:, 2:3][negative_depth_mask] = 1

    src_grid = src_grid_d[:, :2] / src_grid_d[:, 2:]  # divide by depth (B, 2, D*H*W)
    del src_grid_d
    src_grid[:, 0] = src_grid[:, 0] / ((W - 1) / 2) - 1  # scale to -1~1
    src_grid[:, 1] = src_grid[:, 1] / ((H - 1) / 2) - 1  # scale to -1~1
    src_grid = rearrange(src_grid, 'b c (d h w) -> b d (h w) c', d=D, h=H, w=W)

    warped_src_feat = F.grid_sample(src_feat, src_grid, mode='bilinear',
                                    padding_mode='zeros', align_corners=True)  # (B, C, D, H*W)
    warped_src_feat = rearrange(warped_src_feat, 'b c d (h w) -> b c d h w', h=H, w=W)
    return warped_src_feat
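# Shape-only usage sketch for homo_warp with hypothetical sizes (create_meshgrid
# comes from kornia); an identity projection with unit depths leaves the grid in place:
import torch

B, C, H, W, D = 2, 8, 16, 16, 4
src_feat = torch.randn(B, C, H, W)
proj_mat = torch.eye(3, 4).expand(B, -1, -1)          # stand-in for src_proj @ ref_proj_inv
depth_values = torch.ones(B, D, H, W)
out = homo_warp(src_feat, proj_mat, depth_values)
assert out.shape == (B, C, D, H, W)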
def forward(self, x, einops_from, einops_to, mask=None, cls_mask=None, rot_emb=None, **einops_dims):
    h = self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    q = q * self.scale

    # splice out the classification token (the first token)
    (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

    # let classification token attend to key / values of all patches across time and space
    cls_out = attn(cls_q, k, v, mask=cls_mask)

    # rearrange across time or space
    q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

    # add rotary embeddings, if applicable
    if exists(rot_emb):
        q_, k_ = apply_rot_emb(q_, k_, rot_emb)

    # expand cls token keys and values across time or space and concat
    r = q_.shape[0] // cls_k.shape[0]
    cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))

    k_ = torch.cat((cls_k, k_), dim=1)
    v_ = torch.cat((cls_v, v_), dim=1)

    # attention
    out = attn(q_, k_, v_, mask=mask)

    # merge back time or space
    out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

    # concat back the cls token
    out = torch.cat((cls_out, out), dim=1)

    # merge back the heads
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

    # combine heads out
    return self.to_out(out)
def reconstruct_ref_fullpersp(normalized_2d, coords3d_rel, validity_mask):
    """Reconstructs the reference point location.

    Args:
        normalized_2d: normalized image coordinates of the joints
            (without intrinsics applied), shape [batch_size, n_points, 2]
        coords3d_rel: 3D camera coordinate offsets relative to the unknown
            reference point which we want to reconstruct,
            shape [batch_size, n_points, 3]
        validity_mask: boolean mask of shape [batch_size, n_points] containing
            True where the point is reliable and should be used in the
            reconstruction

    Returns:
        The 3D reference point in camera coordinates, shape [batch_size, 3]
    """

    def rms_normalize(x):
        scale = tf.sqrt(tf.reduce_mean(tf.square(x)))
        normalized = x / scale
        return scale, normalized

    n_batch = tf.shape(normalized_2d)[0]
    n_points = normalized_2d.shape.as_list()[1]
    eyes = tf.tile(tf.expand_dims(tf.eye(2, 2), 0), [n_batch, n_points, 1])
    scale2d, reshaped2d = rms_normalize(tf.reshape(normalized_2d, [-1, n_points * 2, 1]))
    A = tf.concat([eyes, -reshaped2d], axis=2)
    rel_backproj = normalized_2d * coords3d_rel[:, :, 2:] - coords3d_rel[:, :, :2]
    scale_rel_backproj, b = rms_normalize(tf.reshape(rel_backproj, [-1, n_points * 2, 1]))

    weights = tf.cast(validity_mask, tf.float32) + np.float32(1e-4)
    weights = einops.repeat(weights, 'b j -> b (j c) 1', c=2)

    ref = tf.linalg.lstsq(A * weights, b * weights, l2_regularizer=1e-2, fast=True)
    ref = tf.concat([ref[:, :2], ref[:, 2:] / scale2d], axis=1) * scale_rel_backproj
    return tf.squeeze(ref, axis=-1)
def generalized_kernel(data, *, projection_matrix, kernel_fn=nn.ReLU(),
                       kernel_epsilon=0.001, normalize_data=True, device=None):
    b, h, *_ = data.shape
    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime