示例#1
0
    def stoch_mat(self, A, zero_diagonal=False, do_dropout=True, do_sinkhorn=False):
        ''' Affinity -> Stochastic Matrix

        Args:
            A: affinity tensor; it is normalised over its last dimension.
            zero_diagonal: if True, first pass ``A`` through ``self.zeroout_diag``.
            do_dropout: if True and ``self.edgedrop_rate > 0``, randomly chosen
                entries are overwritten with ``-1e20`` so they receive
                (numerically) zero probability.
            do_sinkhorn: if True, normalise with ``utils.sinkhorn_knopp`` on the
                exponentiated, temperature-scaled affinities instead of a softmax.

        Returns:
            The normalised matrix (softmax over ``dim=-1`` by default).

        Note:
            Edge dropout writes into ``A`` in place, so the tensor handed in by
            the caller is modified (unless ``self.zeroout_diag`` returned a
            fresh tensor -- not visible from here).
        '''
        if zero_diagonal:
            A = self.zeroout_diag(A)

        drop_edges = do_dropout and self.edgedrop_rate > 0
        if drop_edges:
            # Each entry is dropped independently with probability edgedrop_rate.
            drop_mask = torch.rand_like(A) < self.edgedrop_rate
            A.masked_fill_(drop_mask, -1e20)

        scaled = A / self.temperature
        if not do_sinkhorn:
            return F.softmax(scaled, dim=-1)

        return utils.sinkhorn_knopp(
            scaled.exp(), tol=0.01, max_iter=100, verbose=False)
示例#2
0
    def forward(self, x, just_feats=False,):
        '''
        Input is B x T x N*C x H x W, where either
           N>1 -> list of patches of images
           N=1 -> list of images

        Args:
            x: input clip tensor laid out as described above (C is fixed to 3).
            just_feats: when True, skip the walk/loss computation and return
                features only.

        Returns:
            ``(q, loss, diags)`` where ``q`` is the node tensor produced by
            ``self.pixels_to_nodes``, ``loss`` is the average of the per-walk
            cross-entropy losses, and ``diags`` maps ``"{H} xent <walk>"`` and
            ``"{H} acc <walk>"`` (``H`` = input frame height) to diagnostic
            scalars.
            With ``just_feats=True`` the result is ``(q, mm)`` when N > 1,
            otherwise ``(q, q reshaped so its last axis becomes h x w)``.
        '''
        B, T, C, H, W = x.shape
        _N, C = C//3, 3

        #################################################################
        # Pixels to Nodes
        #################################################################
        x = x.transpose(1, 2).view(B, _N, C, T, H, W)
        q, mm = self.pixels_to_nodes(x)
        B, C, T, N = q.shape

        if just_feats:
            # `np.int` was removed in NumPy 1.24; convert to builtin ints so
            # the sizes can be passed straight to `view`.
            h, w = (int(v) for v in
                    np.ceil(np.array(x.shape[-2:]) / self.map_scale))
            return (q, mm) if _N > 1 else (q, q.view(*q.shape[:-1], h, w))

        #################################################################
        # Compute walks
        #################################################################
        walks = dict()
        As = self.affinity(q[:, :, :-1], q[:, :, 1:])
        A12s = [self.stoch_mat(As[:, i], do_dropout=True) for i in range(T-1)]

        #################################################################
        # Palindromes
        #################################################################
        if not self.sk_targets:
            A21s = [self.stoch_mat(As[:, i].transpose(-1, -2), do_dropout=True)
                    for i in range(T-1)]

            for i in range(1, len(A12s)):
                # Forward transitions up to step i followed by the matching
                # backward transitions in reverse order.
                chain = A12s[:i+1] + A21s[:i+1][::-1]
                aa = chain[0]
                for step in chain[1:]:
                    # `flip` selects left-multiplication, otherwise the chain
                    # is accumulated by right-multiplication.
                    aa = step @ aa if self.flip else aa @ step

                tag = f"l{i}" if self.flip else f"r{i}"
                walks[f"cyc {tag}"] = [aa, self.xent_targets(aa)]

        #################################################### Sinkhorn-Knopp Target (experimental)
        else:
            # The original referenced an undefined name `A` here; the affinity
            # stack computed above is `As`.
            a12 = A12s[0]
            at = self.stoch_mat(As[:, 0], do_dropout=False, do_sinkhorn=True)
            for i in range(1, len(A12s)):
                a12 = a12 @ A12s[i]
                at = self.stoch_mat(As[:, i], do_dropout=False, do_sinkhorn=True) @ at
                with torch.no_grad():
                    targets = utils.sinkhorn_knopp(
                        at, tol=0.001, max_iter=10, verbose=False
                    ).argmax(-1).flatten()
                walks[f"sk {i}"] = [a12, targets]

        #################################################################
        # Compute loss
        #################################################################
        # Seed with a zero so `sum` works even when there are no walks.
        xents = [torch.tensor([0.]).to(self.args.device)]
        diags = dict()

        for name, (walk, target) in walks.items():
            logits = torch.log(walk+EPS).flatten(0, -2)
            loss = self.xent(logits, target).mean()
            acc = (torch.argmax(logits, dim=-1) == target).float().mean()
            diags.update({f"{H} xent {name}": loss.detach(),
                          f"{H} acc {name}": acc})
            xents += [loss]

        #################################################################
        # Visualizations
        #################################################################
        if (np.random.random() < 0.02) and (self.vis is not None): # and False:
            with torch.no_grad():
                self.visualize_frame_pair(x, q, mm)
                if _N > 1: # and False:
                    self.visualize_patches(x, q)

        # Divide by the number of walks (the seed zero is not counted).
        loss = sum(xents) / max(1, len(xents)-1)

        return q, loss, diags
示例#3
0
    def forward(
        self,
        x,
        d=None,
        RT=None,
        K=None,
        just_feats=False,
    ):
        '''
        Input is B x V x T x N*C x H x W when ``d`` is given (multi-view),
        otherwise B x T x N*C x H x W, where either
           N>1 -> list of patches of images
           N=1 -> list of images

        Args:
            x: input clip tensor laid out as described above (C is fixed to 3).
            d: optional 5-D tensor unpacked as B x V x T x Hd x Wd; presumably
                per-view depth maps -- TODO confirm against the data loader.
                Supplying it enables the cross-view loss.
            RT, K: per-view tensors forwarded to ``utils.view_swap``
                (presumably camera extrinsics / intrinsics -- TODO confirm).
                Required whenever ``d`` is given.
            just_feats: when True, skip the walk/loss computation and return
                features only.

        Returns:
            ``(q, loss, diags)`` where ``q`` is the node tensor of the first
            view, ``loss`` is the average per-walk cross-entropy plus the
            cross-view term, and ``diags`` maps ``"{H} xent <walk>"`` and
            ``"{H} acc <walk>"`` (``H`` = input frame height) to diagnostic
            scalars.
            With ``just_feats=True`` the result is ``(q, mm)`` when N > 1,
            otherwise ``(q, q reshaped so its last axis becomes h x w)``.
        '''
        multi_view = d is not None

        #################################################################
        # Pixels to Nodes
        #################################################################
        if multi_view:
            B, V, T, C, H, W = x.shape
            _N, C = C // 3, 3
            _, _, _, Hd, Wd = d.shape
            # Put the view axis first so views can be selected with [v].
            d, RT, K = d.transpose(0, 1), RT.transpose(0, 1), K.transpose(0, 1)
            x = x.transpose(2, 3).view(B, V, _N, C, T, H, W).transpose(0, 1)
            x_main = x[0]
            # q node embedding, mm node feature embedding
            q, mm = self.pixels_to_nodes(x_main)
            q_back, mm_back = self.pixels_to_nodes(x[1])
            # NOTE(review): the third view is encoded but its features are not
            # used anywhere below.
            q_top, mm_top = self.pixels_to_nodes(x[2])
        else:
            B, T, C, H, W = x.shape
            _N, C = C // 3, 3
            x = x.transpose(1, 2).view(B, _N, C, T, H, W)
            x_main = x
            q, mm = self.pixels_to_nodes(x)

        B, C, T, N = q.shape

        #################################################################
        # Cross-view consistency
        #################################################################
        view_loss = 0
        if multi_view:
            # Iterate over every node (the original hard-coded 49 nodes).
            for i in range(N):
                coor_x, coor_y = self.patch_index_to_pixel(i, Hd, Wd, H, W, N)
                coor_z = d[0][..., coor_x][..., coor_y]
                features = q[..., i].transpose(1, 2)
                for b in range(B):
                    for t in range(T):
                        # Map the node's coordinate from view 0 into view 1.
                        coor_back = utils.view_swap(coor_x, coor_y,
                                                    coor_z[b][t], RT[0][b],
                                                    RT[1][b], K[0][b])
                        index_back = self.coord_to_index(
                            coor_back, Hd, Wd, H, W, N)
                        # Only indices that name an existing node contribute.
                        if 0 <= index_back < N:
                            other = q_back[..., index_back].transpose(1, 2)
                            view_loss += self.sigmoid(
                                torch.dist(features[b][t], other[b][t], 2))
            view_loss /= int(B) * int(T)
            view_loss /= 10

        if just_feats:
            # `np.int` was removed in NumPy 1.24; convert to builtin ints so
            # the sizes can be passed straight to `view`.
            h, w = (int(v) for v in
                    np.ceil(np.array(x_main.shape[-2:]) / self.map_scale))
            return (q, mm) if _N > 1 else (q, q.view(*q.shape[:-1], h, w))

        #################################################################
        # Compute walks
        #################################################################
        walks = dict()
        As = self.affinity(q[:, :, :-1], q[:, :, 1:])
        A12s = [
            self.stoch_mat(As[:, i], do_dropout=True) for i in range(T - 1)
        ]

        #################################################### Palindromes
        if not self.sk_targets:
            A21s = [
                self.stoch_mat(As[:, i].transpose(-1, -2), do_dropout=True)
                for i in range(T - 1)
            ]

            for i in range(1, len(A12s)):
                # Forward transitions up to step i followed by the matching
                # backward transitions in reverse order.
                chain = A12s[:i + 1] + A21s[:i + 1][::-1]
                aa = chain[0]
                for step in chain[1:]:
                    # `flip` selects left-multiplication, otherwise the chain
                    # is accumulated by right-multiplication.
                    aa = step @ aa if self.flip else aa @ step

                tag = f"l{i}" if self.flip else f"r{i}"
                walks[f"cyc {tag}"] = [aa, self.xent_targets(aa)]

        #################################################### Sinkhorn-Knopp Target (experimental)
        else:
            # The original referenced an undefined name `A` here; the affinity
            # stack computed above is `As`.
            a12 = A12s[0]
            at = self.stoch_mat(As[:, 0], do_dropout=False, do_sinkhorn=True)
            for i in range(1, len(A12s)):
                a12 = a12 @ A12s[i]
                at = self.stoch_mat(
                    As[:, i], do_dropout=False, do_sinkhorn=True) @ at
                with torch.no_grad():
                    targets = utils.sinkhorn_knopp(
                        at, tol=0.001, max_iter=10,
                        verbose=False).argmax(-1).flatten()
                walks[f"sk {i}"] = [a12, targets]

        #################################################################
        # Compute loss
        #################################################################
        # Seed with a zero so `sum` works even when there are no walks.
        xents = [torch.tensor([0.]).to(self.args.device)]
        diags = dict()

        for name, (walk, target) in walks.items():
            logits = torch.log(walk + EPS).flatten(0, -2)
            loss = self.xent(logits, target).mean()
            acc = (torch.argmax(logits, dim=-1) == target).float().mean()
            diags.update({
                f"{H} xent {name}": loss.detach(),
                f"{H} acc {name}": acc
            })
            xents += [loss]

        #################################################################
        # Visualizations
        #################################################################
        # NOTE(review): `np.random.random() < 1` is always true, so this runs
        # on every call whenever `self.vis` is set -- confirm this is intended.
        if (np.random.random() < 1) and (self.vis is not None):  # and False:
            with torch.no_grad():
                # Pass the same clip batch that produced `q` (first view in
                # multi-view mode, the whole input otherwise).
                self.visualize_frame_pair(x_main, q, mm)
                if _N > 1:  # and False:
                    self.visualize_patches(x_main, q)

        # Divide by the number of walks (the seed zero is not counted).
        loss = sum(xents) / max(1, len(xents) - 1)
        loss += view_loss
        # NOTE(review): prints on every forward pass; consider routing this
        # through `diags` or a logger instead.
        print(loss, view_loss)
        return q, loss, diags