def stoch_mat(self, A, zero_diagonal=False, do_dropout=True, do_sinkhorn=False):
    ''' Affinity -> Stochastic Matrix '''

    if zero_diagonal:
        A = self.zeroout_diag(A)

    # Edge dropout: give a random subset of edges a large negative affinity
    # so they receive (near-)zero probability after the softmax.
    if do_dropout and self.edgedrop_rate > 0:
        A[torch.rand_like(A) < self.edgedrop_rate] = -1e20

    if do_sinkhorn:
        return utils.sinkhorn_knopp((A / self.temperature).exp(),
                                    tol=0.01, max_iter=100, verbose=False)

    # Row-normalize with a temperature-scaled softmax.
    return F.softmax(A / self.temperature, dim=-1)
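
# For reference, a minimal sketch of the normalization `utils.sinkhorn_knopp` is
# assumed to perform above: alternating row/column normalization of a non-negative
# (square) affinity matrix until it is approximately doubly stochastic. The name,
# stopping rule, and epsilon here are illustrative, not the repository's own code.
def sinkhorn_knopp_sketch(K, tol=0.01, max_iter=100, verbose=False):
    for it in range(max_iter):
        K = K / (K.sum(dim=-1, keepdim=True) + 1e-8)   # normalize rows
        K = K / (K.sum(dim=-2, keepdim=True) + 1e-8)   # normalize columns
        err = (K.sum(dim=-1) - 1).abs().max()          # row sums drift after the column step
        if verbose:
            print(f"sinkhorn iter {it}: max row-sum error {err.item():.4f}")
        if err < tol:
            break
    return K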
def forward(self, x, just_feats=False):
    '''
    Input is B x T x N*C x H x W, where either
       N>1 -> list of patches of images
       N=1 -> list of images
    '''
    B, T, C, H, W = x.shape
    _N, C = C // 3, 3

    #################################################################
    # Pixels to Nodes
    #################################################################
    x = x.transpose(1, 2).view(B, _N, C, T, H, W)
    q, mm = self.pixels_to_nodes(x)   # q: node embeddings, mm: node feature maps
    B, C, T, N = q.shape

    if just_feats:
        h, w = np.ceil(np.array(x.shape[-2:]) / self.map_scale).astype(int)
        return (q, mm) if _N > 1 else (q, q.view(*q.shape[:-1], h, w))

    #################################################################
    # Compute walks
    #################################################################
    walks = dict()
    As = self.affinity(q[:, :, :-1], q[:, :, 1:])
    A12s = [self.stoch_mat(As[:, i], do_dropout=True) for i in range(T - 1)]

    #################################################################
    # Palindromes
    #################################################################
    if not self.sk_targets:
        A21s = [self.stoch_mat(As[:, i].transpose(-1, -2), do_dropout=True) for i in range(T - 1)]
        AAs = []
        for i in list(range(1, len(A12s))):
            g = A12s[:i + 1] + A21s[:i + 1][::-1]
            aar = aal = g[0]
            for _a in g[1:]:
                aar, aal = aar @ _a, _a @ aal
            AAs.append((f"l{i}", aal) if self.flip else (f"r{i}", aar))

        for i, aa in AAs:
            walks[f"cyc {i}"] = [aa, self.xent_targets(aa)]

    #################################################### Sinkhorn-Knopp Target (experimental)
    else:
        a12, at = A12s[0], self.stoch_mat(As[:, 0], do_dropout=False, do_sinkhorn=True)
        for i in range(1, len(A12s)):
            a12 = a12 @ A12s[i]
            at = self.stoch_mat(As[:, i], do_dropout=False, do_sinkhorn=True) @ at
            with torch.no_grad():
                targets = utils.sinkhorn_knopp(at, tol=0.001, max_iter=10,
                                               verbose=False).argmax(-1).flatten()
            walks[f"sk {i}"] = [a12, targets]

    #################################################################
    # Compute loss
    #################################################################
    xents = [torch.tensor([0.]).to(self.args.device)]
    diags = dict()

    for name, (A, target) in walks.items():
        logits = torch.log(A + EPS).flatten(0, -2)
        loss = self.xent(logits, target).mean()
        acc = (torch.argmax(logits, dim=-1) == target).float().mean()
        diags.update({f"{H} xent {name}": loss.detach(),
                      f"{H} acc {name}": acc})
        xents += [loss]

    #################################################################
    # Visualizations
    #################################################################
    if (np.random.random() < 0.02) and (self.vis is not None):  # and False:
        with torch.no_grad():
            self.visualize_frame_pair(x, q, mm)
            if _N > 1:  # and False:
                self.visualize_patches(x, q)

    loss = sum(xents) / max(1, len(xents) - 1)

    return q, loss, diags
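
# The palindrome walks above treat each node's own index as the correct end
# point of the round-trip, and `self.affinity` compares node embeddings of
# consecutive frames. A hedged sketch of both helpers, assuming dot-product
# affinities and per-node identity targets; the actual methods may cache
# targets or scale the affinities differently.
def affinity_sketch(x1, x2):
    # x1, x2: B x C x T x N node embeddings of consecutive frame pairs
    # -> B x T x N x N affinity per frame pair (dot product over channels)
    return torch.einsum('bctn,bctm->btnm', x1, x2)

def xent_targets_sketch(A):
    # A: B x N x N cycle transition matrix. A palindrome walk should return
    # to its starting node, so the target label of node j is simply j,
    # repeated for every batch element (matching logits.flatten(0, -2)).
    B, N = A.shape[0], A.shape[-1]
    return torch.arange(N, device=A.device).repeat(B)   # shape (B*N,)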
def forward(self, x, d=None, RT=None, K=None, just_feats=False):
    '''
    Input is B x V x T x N*C x H x W, where either
       N>1 -> list of patches of images
       N=1 -> list of images
    '''
    if d is not None:
        B, V, T, C, H, W = x.shape
        _N, C = C // 3, 3
        Bd, Vd, Td, Hd, Wd = d.shape
        # Put the view dimension first for depth, extrinsics and intrinsics.
        d, RT, K = d.transpose(0, 1), RT.transpose(0, 1), K.transpose(0, 1)
        x = x.transpose(2, 3).view(B, V, _N, C, T, H, W).transpose(0, 1)
        q, mm = self.pixels_to_nodes(x[0])            # q: node embeddings, mm: node feature maps
        q_back, mm_back = self.pixels_to_nodes(x[1])
        q_top, mm_top = self.pixels_to_nodes(x[2])
    else:
        B, T, C, H, W = x.shape
        _N, C = C // 3, 3
        x = x.transpose(1, 2).view(B, _N, C, T, H, W)
        q, mm = self.pixels_to_nodes(x)

    #################################################################
    # Pixels to Nodes
    #################################################################
    B, C, T, N = q.shape
    # q shape  [B, 128, 4, 49]
    # mm shape [B, 49, 512, 4, 8, 8]

    view_loss = 0

    # Untested attempt to minimize the distance between each node and its
    # nearest node in another view:
    # for i in range(49):
    #     features = q[..., i].transpose(1, 2)  # B x T x C
    #     diff_back = torch.sum(torch.pow(features - q_back[..., 0].transpose(1, 2), 2), dim=-1)
    #     diff_top = torch.sum(torch.pow(features - q_top[..., 0].transpose(1, 2), 2), dim=-1)
    #     for j in range(49):
    #         features_back = q_back[..., j].transpose(1, 2)
    #         features_top = q_top[..., j].transpose(1, 2)
    #         dist_back = torch.sum(torch.pow(features - features_back, 2), dim=-1)
    #         dist_top = torch.sum(torch.pow(features - features_top, 2), dim=-1)
    #         diff_back = torch.min(diff_back, dist_back)
    #         diff_top = torch.min(diff_top, dist_top)
    #     view_loss += torch.sum(diff_back) + torch.sum(diff_top)
    # view_loss = self.sigmoid(view_loss / B / 2)

    if d is not None:
        for i in range(49):
            coor_x, coor_y = self.patch_index_to_pixel(i, Hd, Wd, H, W, N)
            coor_z = d[0][..., coor_x][..., coor_y]   # depth at this pixel for every (b, t)
            features = q[..., i].transpose(1, 2)      # B x T x C
            for b in range(B):
                for t in range(T):
                    # Reproject the reference-view pixel into the back view.
                    coor_back = utils.view_swap(coor_x, coor_y, coor_z[b][t],
                                                RT[0][b], RT[1][b], K[0][b])
                    # coor_top = utils.view_swap(coor_x, coor_y, coor_z[b][t], RT[0][b], RT[2][b], K[0][b])
                    index_back = self.coord_to_index(coor_back, Hd, Wd, H, W, N)
                    # index_top = self.coord_to_index(coor_top, Hd, Wd, H, W, N)
                    if index_back < N and index_back >= 0:
                        view_loss += self.sigmoid(torch.dist(
                            features[b][t],
                            q_back[..., index_back].transpose(1, 2)[b][t], 2))
                    # if index_top < N and index_top >= 0:
                    #     view_loss += self.sigmoid(torch.dist(features[b][t], q_top[..., index_top].transpose(1, 2)[b][t], 2))
        view_loss /= int(B) * int(T)
        view_loss /= 10

    if just_feats:
        h, w = np.ceil(np.array(x[0].shape[-2:]) / self.map_scale).astype(int)
        return (q, mm) if _N > 1 else (q, q.view(*q.shape[:-1], h, w))

    #################################################################
    # Compute walks
    #################################################################
    walks = dict()
    As = self.affinity(q[:, :, :-1], q[:, :, 1:])
    A12s = [self.stoch_mat(As[:, i], do_dropout=True) for i in range(T - 1)]
    # As:      [4, 3, 49, 49] = B x (T-1) x N x N
    # A12s[0]: [4, 49, 49]

    #################################################### Palindromes
    if not self.sk_targets:
        A21s = [self.stoch_mat(As[:, i].transpose(-1, -2), do_dropout=True) for i in range(T - 1)]
        AAs = []
        for i in list(range(1, len(A12s))):
            g = A12s[:i + 1] + A21s[:i + 1][::-1]
            aar = aal = g[0]
            for _a in g[1:]:
                aar, aal = aar @ _a, _a @ aal
            AAs.append((f"l{i}", aal) if self.flip else (f"r{i}", aar))

        for i, aa in AAs:
            walks[f"cyc {i}"] = [aa, self.xent_targets(aa)]
    #################################################### Sinkhorn-Knopp Target (experimental)
    else:
        a12, at = A12s[0], self.stoch_mat(As[:, 0], do_dropout=False, do_sinkhorn=True)
        for i in range(1, len(A12s)):
            a12 = a12 @ A12s[i]
            at = self.stoch_mat(As[:, i], do_dropout=False, do_sinkhorn=True) @ at
            with torch.no_grad():
                targets = utils.sinkhorn_knopp(at, tol=0.001, max_iter=10,
                                               verbose=False).argmax(-1).flatten()
            walks[f"sk {i}"] = [a12, targets]

    #################################################################
    # Compute loss
    #################################################################
    xents = [torch.tensor([0.]).to(self.args.device)]
    diags = dict()

    for name, (A, target) in walks.items():
        logits = torch.log(A + EPS).flatten(0, -2)
        loss = self.xent(logits, target).mean()
        acc = (torch.argmax(logits, dim=-1) == target).float().mean()
        diags.update({f"{H} xent {name}": loss.detach(),
                      f"{H} acc {name}": acc})
        xents += [loss]

    #################################################################
    # Visualizations
    #################################################################
    if (np.random.random() < 1) and (self.vis is not None):  # and False:
        with torch.no_grad():
            self.visualize_frame_pair(x[0], q, mm)
            if _N > 1:  # and False:
                self.visualize_patches(x[0], q)

    loss = sum(xents) / max(1, len(xents) - 1)
    loss += view_loss
    print(loss, view_loss)

    return q, loss, diags
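
# `utils.view_swap` is assumed to reproject a pixel with known depth from the
# reference camera into another view using the depth z, per-view extrinsics RT,
# and intrinsics K. A sketch under that assumption (treating RT as a 4x4
# camera-to-world transform and K as a 3x3 intrinsics matrix); the real
# helper's conventions may differ.
def view_swap_sketch(x, y, z, RT_src, RT_dst, K):
    pix = torch.tensor([float(x), float(y), 1.0])
    cam_src = z * torch.linalg.inv(K) @ pix                  # back-project to 3D (source camera frame)
    world = RT_src @ torch.cat([cam_src, torch.ones(1)])     # source camera -> world
    cam_dst = torch.linalg.inv(RT_dst) @ world               # world -> destination camera
    proj = K @ cam_dst[:3]                                   # project onto the destination image plane
    return proj[0] / proj[2], proj[1] / proj[2]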
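
# A usage sketch for the multi-view forward above, with shapes inferred from
# the inline shape comments (V = 3 views, T = 4 frames, 49 patches per frame);
# the patch size, the depth/extrinsics/intrinsics layouts, and `model` itself
# are assumptions, not values taken from the repository.
def multiview_forward_example(model, device='cpu'):
    B, V, T, N, H, W = 2, 3, 4, 49, 64, 64
    x = torch.randn(B, V, T, N * 3, H, W, device=device)          # stacked RGB patches per frame
    d = torch.rand(B, V, T, H, W, device=device)                  # per-view depth maps
    RT = torch.eye(4, device=device).expand(B, V, 4, 4).clone()   # per-view extrinsics
    K = torch.eye(3, device=device).expand(B, V, 3, 3).clone()    # per-view intrinsics
    q, loss, diags = model(x, d=d, RT=RT, K=K)
    return q.shape, loss                                          # q: (B, 128, T, N) per the shape comments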