Пример #1
0
    def update_data(self, b_edge_ids, edge_features, diff_to_gt, gt_edge_weights, node_labeling, raw, angles, gt):
        """Load a new batch of graph data into the environment state.

        Collates the per-sample edge lists into one batched graph, finds dense
        subgraphs of size ``self.args.s_subgraph``, and initialises current edge
        weights (global and per-subgraph) to 0.5.

        Args:
            b_edge_ids: per-sample edge-id tensors; each is transposed with
                ``transpose(0, 1)`` before the subgraph search, so presumably
                shaped (2, n_edges) -- TODO confirm.
            edge_features: per-sample edge feature tensors; column 0 becomes the
                initial edge weight.
            diff_to_gt: unused here.
            gt_edge_weights: per-sample ground-truth edge weights.
            node_labeling: superpixel labels; sample ``i`` is compared against
                node ids ``0 .. n_nodes[i] - 1``. Stored via ``squeeze()``, which
                drops *every* singleton dim (including the batch dim when the
                batch size is 1).
            raw: raw input, stored as-is.
            angles: edge angles, stored as-is.
            gt: ground-truth segmentation, stored as-is.
        """
        self.gt_seg = gt
        self.raw = raw
        self.init_sp_seg = node_labeling.squeeze()

        self.n_nodes = [eids.max() + 1 for eids in b_edge_ids]
        edge_lists = [eids.transpose(0, 1).cpu().numpy() for eids in b_edge_ids]
        dense_sgs, sep_sgs = find_dense_subgraphs(edge_lists, self.args.s_subgraph)
        self.sep_subgraphs = torch.from_numpy(sep_sgs.astype(np.int64))
        self.b_edge_ids, (self.n_offs, self.e_offs) = collate_edges(b_edge_ids)

        # Shift every sample's subgraph node ids by that sample's node offset so
        # they refer to nodes of the collated (batched) graph.
        shifted = []
        for sample_idx, sg in enumerate(dense_sgs):
            sg_tensor = torch.from_numpy(sg.astype(np.int64)).to(self.device)
            shifted.append(sg_tensor.view(-1, 2).transpose(0, 1) + self.n_offs[sample_idx])

        # Running column offsets of each sample's block inside the concatenation.
        self.sg_offs = [0]
        for sg_tensor in shifted:
            self.sg_offs.append(self.sg_offs[-1] + sg_tensor.shape[-1])
        self.b_subgraphs = torch.cat(shifted, 1)

        self.b_subgraph_indices = get_edge_indices(self.b_edge_ids, self.b_subgraphs)
        self.b_gt_edge_weights = torch.cat(gt_edge_weights, 0)
        sg_gt = self.b_gt_edge_weights[self.b_subgraph_indices]
        self.sg_gt_edge_weights = sg_gt.view(-1, self.args.s_subgraph)
        self.b_edge_angles = angles
        self.sg_current_edge_weights = torch.ones_like(self.sg_gt_edge_weights) / 2

        initial_weights = [feats[:, 0] for feats in edge_features]
        self.b_initial_edge_weights = torch.cat(initial_weights, dim=0)
        self.b_current_edge_weights = torch.ones_like(self.b_initial_edge_weights) / 2

        # Per sample, the (CPU) coordinates of every pixel of each superpixel.
        # Masks are built one at a time rather than all kept alive at once.
        self.sp_indices = [
            [(node_labeling[b] == node).nonzero().cpu() for node in range(n_node)]
            for b, n_node in enumerate(self.n_nodes)
        ]
Пример #2
0
    def update_data(self, raw, gt, edge_ids, gt_edges, sp_seg, rags, edge_features, *args, **kwargs):
        """Load a new batch into the environment state.

        Computes embeddings of ``raw`` with gradients disabled, averages them per
        superpixel, derives edge angles and superpixel features, and collates the
        per-sample edge lists into one batched graph. If ground-truth edge weights
        are supplied, the ground-truth solution is computed as well.

        Args:
            raw: input batch; ``raw.shape[0]`` is taken as the batch size.
            gt: ground-truth segmentation; dim 1 is squeezed.
            edge_ids: per-sample edge-id tensors, presumably shaped (2, n_edges)
                -- TODO confirm.
            gt_edges: per-sample ground-truth edge weights, or ``None``.
            sp_seg: superpixel segmentation; dim 1 is squeezed.
            rags: region adjacency graphs, stored as-is.
            edge_features: per-sample edge feature tensors; their concatenation is
                stored in ``self.edge_feats`` and its first two columns are
                appended to the edge angles to form ``self.edge_features``.
            *args, **kwargs: ignored.
        """
        bs = raw.shape[0]

        self.rags = rags
        self.gt_seg, self.init_sp_seg = gt.squeeze(1), sp_seg.squeeze(1)
        self.raw = raw
        with torch.set_grad_enabled(False):
            self.embeddings = self.embedding_net(raw)
        # Aggregate the embeddings over each superpixel.
        self.current_node_embeddings = torch.cat(
            [self.embedding_net.get_mean_sp_embedding_chunked(embed, sp, chunks=40)
             for embed, sp in zip(self.embeddings, self.init_sp_seg)], dim=0)

        edge_angles, sp_feat, sp_rads = zip(
            *[get_angles_smass_in_rag(edge_ids[i], self.init_sp_seg[i]) for i in range(bs)])
        edge_angles = torch.cat(edge_angles).unsqueeze(-1)
        self.sp_feat, self.sp_rads = torch.cat(sp_feat), torch.cat(sp_rads)

        # Both edge directions: the original edges followed by their reversals.
        self.dir_edge_ids = [
            torch.cat([_edge_ids, torch.stack([_edge_ids[1], _edge_ids[0]], dim=0)], dim=1)
            for _edge_ids in edge_ids
        ]
        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)

        # Concatenate the edge features once and reuse the result. Assigned on
        # every call so it never holds a previous batch's features.
        self.edge_feats = torch.cat(edge_features, 0)

        self.gt_edge_weights = gt_edges
        if self.gt_edge_weights is not None:
            self.gt_edge_weights = torch.cat(self.gt_edge_weights)
            # NOTE: without gt_edges, gt_soln keeps its value from a previous call.
            self.gt_soln, _, _, _ = self.get_current_soln(self.gt_edge_weights)

        self.edge_features = torch.cat([edge_angles, self.edge_feats[:, :2]], 1)
        self.current_edge_weights = torch.ones(self.edge_ids.shape[1], device=self.edge_ids.device) / 2
Пример #3
0
    def update_data(self, edge_ids, gt_edges, sp_seg, raw, gt, **kwargs):
        """Load a new batch into the environment state.

        Derives edge angles and superpixel masses/centres, boundary maps of the
        superpixel segmentation, and dense subgraphs for every size in
        ``self.cfg.trn.s_subgraph``. Collates the per-sample edge lists into one
        batched graph, computes the ground-truth solution, and aggregates the
        (gradient-free) embeddings of ``raw`` per superpixel.

        Args:
            edge_ids: per-sample edge-id tensors, presumably shaped (2, n_edges)
                -- TODO confirm. ``len(edge_ids)`` is the batch size.
            gt_edges: per-sample ground-truth edge weights.
            sp_seg: superpixel segmentation; dim 1 is squeezed for
                ``self.init_sp_seg``.
            raw: raw input; passed through ``self.embedding_net``.
            gt: ground-truth segmentation; dim 1 is squeezed.
            **kwargs: ignored.
        """
        bs = len(edge_ids)
        dev = edge_ids[0].device
        subgraphs, self.sep_subgraphs = [], []
        self.gt_seg = gt.squeeze(1)
        self.raw = raw
        self.init_sp_seg = sp_seg.squeeze(1)
        edge_angles, sup_masses, sup_com = zip(
            *[get_angles_smass_in_rag(edges, sp) for edges, sp in zip(edge_ids, self.init_sp_seg)])
        self.edge_angles = torch.cat(edge_angles).unsqueeze(-1)
        self.sup_masses = torch.cat(sup_masses).unsqueeze(-1)
        self.sup_com = torch.cat(sup_com)
        # Two boundary channels: label differs from -max_p(-x) and from max_p(x)
        # (presumably min/max pooling -- assumes self.max_p is a max-pool).
        self.init_sp_seg_edge = torch.cat(
            [(-self.max_p(-sp_seg) != sp_seg).float(), (self.max_p(sp_seg) != sp_seg).float()], 1)

        _subgraphs, _sep_subgraphs = find_dense_subgraphs(
            [eids.transpose(0, 1).cpu().numpy() for eids in edge_ids], self.cfg.trn.s_subgraph)
        _subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1) for sg in _subgraphs]
        _sep_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1) for sg in _sep_subgraphs]

        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        self.dir_edge_ids = torch.cat(
            [self.edge_ids, torch.stack([self.edge_ids[1], self.edge_ids[0]], dim=0)], dim=1)
        # The subgraph lists hold `bs` consecutive entries per subgraph size.
        for size_idx in range(len(self.cfg.trn.s_subgraph)):
            lo, hi = size_idx * bs, (size_idx + 1) * bs
            # Shift node ids by each sample's node offset into the collated graph.
            subgraphs.append(
                torch.cat([sg + self.n_offs[b] for b, sg in enumerate(_subgraphs[lo:hi])], -2).flatten(-2, -1))
            self.sep_subgraphs.append(torch.cat(_sep_subgraphs[lo:hi], -2).flatten(-2, -1))

        self.subgraphs = subgraphs
        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)

        self.gt_edge_weights = torch.cat(gt_edges)
        self.gt_soln = self.get_mc_soln(self.gt_edge_weights)
        self.sg_gt_edges = [self.gt_edge_weights[sg].view(-1, sz) for sz, sg in
                            zip(self.cfg.trn.s_subgraph, self.subgraph_indices)]

        # One-hot masks (one channel per superpixel label), then pixel coordinates per superpixel.
        stacked_superpixels = [
            torch.zeros((int(sp.max() + 1),) + sp.shape, device=self.device).scatter_(0, sp[None].long(), 1)
            for sp in self.init_sp_seg]
        self.sp_indices = [[torch.nonzero(sp, as_tuple=False) for sp in stacked_superpixel]
                           for stacked_superpixel in stacked_superpixels]

        # The embeddings are only used as fixed inputs, so skip building an
        # autograd graph (previously built and then discarded via .detach()).
        with torch.no_grad():
            self.embeddings = self.embedding_net(self.raw)
        # Aggregate the embeddings over each superpixel.
        self.current_node_embeddings = torch.cat([self.embedding_net.get_mean_sp_embedding(embed, sp) for embed, sp
                                                  in zip(self.embeddings, self.init_sp_seg)], dim=0)
Пример #4
0
    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights,
                    sp_seg, raw, gt):
        """Load a new batch into the environment state.

        Derives edge angles, superpixel boundary maps and dense subgraphs for
        every size in ``self.cfg.sac.s_subgraph``, collates the per-sample edge
        lists into one batched graph, and initialises the current edge weights
        from column 0 of the edge features.

        Args:
            edge_ids: per-sample edge-id tensors, presumably shaped (2, n_edges)
                -- TODO confirm. ``len(edge_ids)`` is the batch size.
            edge_features: per-sample edge feature tensors; column 0 is the
                initial edge weight.
            diff_to_gt: unused here.
            gt_edge_weights: per-sample ground-truth edge weights.
            sp_seg: superpixel segmentation; dim 1 is squeezed for
                ``self.init_sp_seg``.
            raw: raw input, stored as-is.
            gt: ground-truth segmentation; dim 1 is squeezed.
        """
        bs = len(edge_ids)
        dev = edge_ids[0].device
        subgraphs, self.sep_subgraphs = [], []
        self.gt_seg = gt.squeeze(1)
        self.raw = raw
        self.init_sp_seg = sp_seg.squeeze(1)
        self.edge_angles = torch.cat([
            get_angles_in_rag(edges, sp)
            for edges, sp in zip(edge_ids, self.init_sp_seg)
        ]).unsqueeze(-1)
        # Two boundary channels (assumes self.max_p is a max-pool, so
        # -max_p(-x) acts as a min-pool).
        self.init_sp_seg_edge = torch.cat(
            [(-self.max_p(-sp_seg) != sp_seg).float(),
             (self.max_p(sp_seg) != sp_seg).float()], 1)

        _subgraphs, _sep_subgraphs = find_dense_subgraphs(
            [eids.transpose(0, 1).cpu().numpy() for eids in edge_ids],
            self.cfg.sac.s_subgraph)
        # Bring the node-pair axis to the front while keeping the subgraph axis
        # ahead of the edge-within-subgraph axis, e.g. (n_sg, sz, 2) ->
        # (2, n_sg, sz) (input shape assumed -- TODO confirm). Flattening then
        # yields subgraph-major order, which `.view(-1, sz)` below requires.
        # The former transpose(-3, -1) produced (2, sz, n_sg), i.e.
        # position-major order, so each row of the view mixed several subgraphs.
        _subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1)
            for sg in _subgraphs
        ]
        _sep_subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1)
            for sg in _sep_subgraphs
        ]

        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        # The subgraph lists hold `bs` consecutive entries per subgraph size.
        for size_idx in range(len(self.cfg.sac.s_subgraph)):
            lo, hi = size_idx * bs, (size_idx + 1) * bs
            # Shift node ids by each sample's node offset into the collated
            # graph, then stack samples along the subgraph axis.
            subgraphs.append(
                torch.cat([
                    sg + self.n_offs[b]
                    for b, sg in enumerate(_subgraphs[lo:hi])
                ], -2).flatten(-2, -1))
            self.sep_subgraphs.append(
                torch.cat(_sep_subgraphs[lo:hi], -2).flatten(-2, -1))

        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)
        self.gt_edge_weights = torch.cat(gt_edge_weights, 0)
        # One row per subgraph, `sz` edge weights per row.
        self.sg_gt_edge_weights = [
            self.gt_edge_weights[sg].view(-1, sz)
            for sz, sg in zip(self.cfg.sac.s_subgraph, self.subgraph_indices)
        ]

        self.sg_current_edge_weights = [
            torch.ones_like(sg) / 2 for sg in self.sg_gt_edge_weights
        ]

        self.initial_edge_weights = torch.cat(
            [edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.current_edge_weights = self.initial_edge_weights.clone()
        self.gt_soln = self.get_current_soln(self.gt_edge_weights)

        # Pixel coordinates of each superpixel. Uses the unsqueezed sp_seg, so
        # index tuples include the singleton dim (if present).
        stacked_superpixels = [[sp_seg[i] == n for n in range(n_node)]
                               for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[
            torch.nonzero(sp, as_tuple=False) for sp in stacked_superpixel
        ] for stacked_superpixel in stacked_superpixels]
Пример #5
0
    def update_data(self, raw, gt, edge_ids, gt_edges, sp_seg, fe_grad, rags,
                    edge_features, *args, **kwargs):
        """Load a new batch into the environment state.

        Validates the superpixel labelling, derives edge angles and superpixel
        features, finds dense subgraphs for every size in ``self.cfg.s_subgraph``,
        and collates the per-sample edge lists into one batched graph. If
        ground-truth edge weights are supplied, it also computes the ground-truth
        solution and the per-subgraph ground-truth edges.

        Args:
            raw: input batch; ``raw.shape[0]`` is the batch size and ``raw.device``
                is used for the subgraph tensors.
            gt: ground-truth segmentation; dim 1 is squeezed.
            edge_ids: per-sample edge-id tensors, presumably shaped (2, n_edges)
                -- TODO confirm.
            gt_edges: per-sample ground-truth edge weights, or ``None``.
            sp_seg: superpixel segmentation; dim 1 is squeezed. Each sample's
                labels must be exactly ``0 .. max`` and ``max`` must exceed 60.
            fe_grad: unused here.
            rags: region adjacency graphs, stored as-is.
            edge_features: per-sample edge feature tensors; the first two columns
                are appended to the edge angles to form ``self.edge_features``.
            *args, **kwargs: ignored.

        Raises:
            AssertionError: if a sample's labels are not contiguous from 0, or its
                largest label is not above 60 (checks are skipped under ``-O``).
        """
        bs = raw.shape[0]
        dev = raw.device
        for _sp_seg in sp_seg:
            labels = _sp_seg.unique()
            expected = torch.arange(_sp_seg.max() + 1, device=dev)
            # Compare sizes first: an elementwise `==` between tensors of
            # different lengths raises a broadcasting RuntimeError instead of
            # failing the check.
            assert labels.numel() == expected.numel() and bool((labels == expected).all())
            # Required minimum superpixel count; the reason is not visible here.
            assert _sp_seg.max() > 60
        self.rags = rags
        self.gt_seg, self.init_sp_seg = gt.squeeze(1), sp_seg.squeeze(1)
        self.raw = raw

        # sp_rads, sp_cms and sp_masses stay per-sample tuples (not concatenated).
        edge_angles, sp_feat, self.sp_rads, self.sp_cms, self.sp_masses = \
            zip(*[get_angles_smass_in_rag(edge_ids[i], self.init_sp_seg[i]) for i in range(bs)])
        edge_angles, self.sp_feat = torch.cat(edge_angles).unsqueeze(
            -1), torch.cat(sp_feat)

        subgraphs, self.sep_subgraphs = [], []
        _subgraphs, _sep_subgraphs = find_dense_subgraphs(
            [eids.transpose(0, 1).cpu().numpy() for eids in edge_ids],
            self.cfg.s_subgraph)
        _subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1)
            for sg in _subgraphs
        ]
        _sep_subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1)
            for sg in _sep_subgraphs
        ]

        # Both edge directions: the original edges followed by their reversals.
        self.dir_edge_ids = [
            torch.cat(
                [_edge_ids,
                 torch.stack([_edge_ids[1], _edge_ids[0]], dim=0)],
                dim=1) for _edge_ids in edge_ids
        ]
        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        # The subgraph lists hold `bs` consecutive entries per subgraph size.
        for size_idx in range(len(self.cfg.s_subgraph)):
            lo, hi = size_idx * bs, (size_idx + 1) * bs
            # Shift node ids by each sample's node offset into the collated graph.
            subgraphs.append(
                torch.cat([
                    sg + self.n_offs[b]
                    for b, sg in enumerate(_subgraphs[lo:hi])
                ], -2).flatten(-2, -1))
            self.sep_subgraphs.append(
                torch.cat(_sep_subgraphs[lo:hi], -2).flatten(-2, -1))

        self.subgraphs = subgraphs
        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)

        # Superpixel labels shifted into the batched node-id space.
        self.batched_sp_seg = torch.stack(
            [sp + off for sp, off in zip(self.init_sp_seg, self.n_offs)], 0)

        self.gt_edge_weights = gt_edges
        if self.gt_edge_weights is not None:
            # NOTE: without gt_edges, gt_soln and sg_gt_edges keep their values
            # from a previous call.
            self.gt_edge_weights = torch.cat(self.gt_edge_weights)
            self.gt_soln = self.get_current_soln(self.gt_edge_weights)
            self.sg_gt_edges = [
                self.gt_edge_weights[sg].view(-1, sz)
                for sz, sg in zip(self.cfg.s_subgraph, self.subgraph_indices)
            ]

        self.edge_features = torch.cat(
            [edge_angles, torch.cat(edge_features, 0)[:, :2]], 1).float()
        self.current_edge_weights = torch.ones(self.edge_ids.shape[1],
                                               device=self.edge_ids.device) / 2