Example #1
    def __init__(self,
                 args,
                 device,
                 writer=None,
                 writer_counter=None,
                 win_event_counter=None):
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0

        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = False

        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        elif self.args.reward_function == 'global_sparse':
            self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)
Example #2
    def __init__(self, embedding_net, cfg, device, writer=None, writer_counter=None):
        super(EmbeddingSpaceEnvNodeBased, self).__init__()
        self.embedding_net = embedding_net
        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.last_final_reward = torch.tensor([0.0])
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
        self.step_encoder = TemporalSineEncoding(max_step=cfg.trn.max_episode_length,
                                                 size=cfg.fe.n_embedding_features)

        if self.cfg.trn.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward()
        else:
            self.reward_function = UnSupervisedReward(env=self)

        self.cluster_policy = nagglo.cosineDistNodeAndEdgeWeightedClusterPolicy
Example #3
    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
        super(SpGcnEnv, self).__init__()

        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False

        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif self.args.reward_function == 'defining_rules':
            self.reward_function = HoughCircles(env=self, range_num=[8, 10],
                                                range_rad=[max(self.args.data_shape) // 18,
                                                           max(self.args.data_shape) // 15], min_hough_confidence=0.7)
        elif self.args.reward_function == 'defining_rules_lg':
            self.reward_function = HoughCircles_lg(env=self, range_num=[8, 10],
                                                range_rad=[max(self.args.data_shape) // 18,
                                                           max(self.args.data_shape) // 15], min_hough_confidence=0.7)
        # elif self.args.reward_function == 'focal':
        #     self.reward_function = FocalReward(env=self)
        # elif self.args.reward_function == 'global_sparse':
        #     self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)
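
For a concrete sense of the Hough-circle parameters configured above, here is a small worked example; the data_shape value is a hypothetical stand-in for args.data_shape and is not taken from the source:

# Hypothetical configuration value used only for illustration.
data_shape = (256, 256)
range_rad = [max(data_shape) // 18,   # 256 // 18 = 14 px, lower bound of the circle radius
             max(data_shape) // 15]   # 256 // 15 = 17 px, upper bound of the circle radius
range_num = [8, 10]                   # presumably the expected range for the number of circles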
Example #4
    def __init__(self, cfg, device, writer=None, writer_counter=None):
        super(SpGcnEnv, self).__init__()

        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)

        if self.cfg.sac.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.cfg.sac.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif self.cfg.sac.reward_function == 'defining_rules_edge_based':
            self.reward_function = HoughCircles(
                env=self,
                range_num=[8, 10],
                range_rad=[
                    max(self.cfg.sac.data_shape) // 18,
                    max(self.cfg.sac.data_shape) // 15
                ],
                min_hough_confidence=0.7)
        elif self.cfg.sac.reward_function == 'defining_rules_sp_based':
            self.reward_function = HoughCirclesOnSp(
                env=self,
                range_num=[8, 10],
                range_rad=[
                    max(self.cfg.sac.data_shape) // 18,
                    max(self.cfg.sac.data_shape) // 15
                ],
                min_hough_confidence=0.7)
        elif self.cfg.sac.reward_function == 'defining_rules_lg':
            assert False
        else:
            self.reward_function = UnSupervisedReward(env=self)
Example #5
class EmbeddingSpaceEnvNodeBased():

    State = collections.namedtuple("State", ["node_embeddings", "edge_ids", "edge_angles", "sup_masses", "subgraph_indices", "sep_subgraphs", "subgraphs", "gt_edge_weights"])

    def __init__(self, embedding_net, cfg, device, writer=None, writer_counter=None):
        super(EmbeddingSpaceEnvNodeBased, self).__init__()
        self.embedding_net = embedding_net
        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.last_final_reward = torch.tensor([0.0])
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
        self.step_encoder = TemporalSineEncoding(max_step=cfg.trn.max_episode_length,
                                                 size=cfg.fe.n_embedding_features)

        if self.cfg.trn.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward()
        else:
            self.reward_function = UnSupervisedReward(env=self)

        self.cluster_policy = nagglo.cosineDistNodeAndEdgeWeightedClusterPolicy
        # self.cluster_policy = nagglo.nodeAndEdgeWeightedClusterPolicy

    def execute_action(self, actions, logg_vals=None, post_stats=False):

        self.current_node_embeddings += actions

        # normalize
        self.current_node_embeddings /= (torch.norm(self.current_node_embeddings, dim=-1, keepdim=True) + 1e-10)

        self.current_soln, node_labeling = self.get_soln_graph_clustering(self.current_node_embeddings)

        sg_edge_weights = []
        for i, sz in enumerate(self.cfg.trn.s_subgraph):
            sg_ne = node_labeling[self.subgraphs[i].view(2, -1, sz)]
            sg_edge_weights.append((sg_ne[0] == sg_ne[1]).float())

        reward = self.reward_function.get(sg_edge_weights, self.sg_gt_edges) #self.current_soln)
        reward.append(self.last_final_reward)

        self.counter += 1
        if self.counter >= self.cfg.trn.max_episode_length:
            self.done = True
            ne = node_labeling[self.edge_ids]
            edge_weights = ((ne[0] == ne[1]).float())
            self.last_final_reward = self.reward_function.get_global(edge_weights, self.gt_edge_weights)

        total_reward = 0
        for _rew in reward:
            total_reward += _rew.mean().item()
        total_reward /= len(self.cfg.trn.s_subgraph)

        if self.writer is not None and post_stats:
            self.writer.add_scalar("step/avg_return", total_reward, self.writer_counter.value())
            if self.writer_counter.value() % 20 == 0:
                fig, (a0, a1, a2, a3, a4) = plt.subplots(1, 5, sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0})
                a0.imshow(self.gt_seg[0].cpu().squeeze())
                a0.set_title('gt')
                a1.imshow(self.raw[0].cpu().permute(1,2,0).squeeze())
                a1.set_title('raw image')
                a2.imshow(cm.prism(self.init_sp_seg[0].cpu() / self.init_sp_seg[0].max().item()))
                a2.set_title('superpixels')
                a3.imshow(cm.prism(self.gt_soln[0].cpu()/self.gt_soln[0].max().item()))
                a3.set_title('gt')
                a4.imshow(cm.prism(self.current_soln[0].cpu()/self.current_soln[0].max().item()))
                a4.set_title('prediction')
                self.writer.add_figure("image/state", fig, self.writer_counter.value() // 10)
                # self.writer.add_figure("image/shift_proj", self.vis_node_actions(actions.cpu(), 0), self.writer_counter.value() // 10)
                self.embedding_net.post_pca(get_angles(self.embeddings)[0].cpu(), tag="image/pix_embedding_proj")
                self.embedding_net.post_pca(get_angles(self.current_node_embeddings[:self.n_offs[1]][self.init_sp_seg[0].long()].permute(2, 0, 1)[None])[0].cpu(),
                                            tag="image/node_embedding_proj")

            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward.append(total_reward)
        return self.get_state(), reward

    def get_state(self):
        temp_code = self.step_encoder(self.counter, self.current_node_embeddings.device)[None]
        state = self.current_node_embeddings + temp_code
        return self.State(state, self.edge_ids, self.edge_angles, self.sup_masses, self.subgraph_indices,
                          self.sep_subgraphs, self.subgraphs, self.gt_edge_weights)

    def update_data(self, edge_ids, gt_edges, sp_seg, raw, gt, **kwargs):
        bs = len(edge_ids)
        dev = edge_ids[0].device
        subgraphs, self.sep_subgraphs = [], []
        self.gt_seg = gt.squeeze(1)
        self.raw = raw
        self.init_sp_seg = sp_seg.squeeze(1)
        edge_angles, sup_masses, sup_com = zip(*[get_angles_smass_in_rag(edges, sp) for edges, sp in zip(edge_ids, self.init_sp_seg)])
        self.edge_angles, self.sup_masses, self.sup_com = torch.cat(edge_angles).unsqueeze(-1), torch.cat(sup_masses).unsqueeze(-1), torch.cat(sup_com)
        self.init_sp_seg_edge = torch.cat([(-self.max_p(-sp_seg) != sp_seg).float(), (self.max_p(sp_seg) != sp_seg).float()], 1)

        _subgraphs, _sep_subgraphs = find_dense_subgraphs([eids.transpose(0, 1).cpu().numpy() for eids in edge_ids], self.cfg.trn.s_subgraph)
        _subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1) for sg in _subgraphs]
        _sep_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1) for sg in _sep_subgraphs]

        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        self.dir_edge_ids = torch.cat([self.edge_ids, torch.stack([self.edge_ids[1], self.edge_ids[0]], dim=0)], dim=1)
        for i in range(len(self.cfg.trn.s_subgraph)):
            subgraphs.append(torch.cat([sg + self.n_offs[i] for i, sg in enumerate(_subgraphs[i*bs:(i+1)*bs])], -2).flatten(-2, -1))
            self.sep_subgraphs.append(torch.cat(_sep_subgraphs[i*bs:(i+1)*bs], -2).flatten(-2, -1))

        self.subgraphs = subgraphs
        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)

        self.gt_edge_weights = torch.cat(gt_edges)
        self.gt_soln = self.get_mc_soln(self.gt_edge_weights)
        self.sg_gt_edges = [self.gt_edge_weights[sg].view(-1, sz) for sz, sg in
                            zip(self.cfg.trn.s_subgraph, self.subgraph_indices)]

        self.embeddings = self.embedding_net(self.raw).detach()
        # get embedding agglomeration over each superpixel
        self.current_node_embeddings = torch.cat([self.embedding_net.get_mean_sp_embedding(embed, sp) for embed, sp
                                                  in zip(self.embeddings, self.init_sp_seg)], dim=0)

        return

    def get_batched_actions_from_global_graph(self, actions):
        b_actions = torch.zeros(size=(self.edge_ids.shape[1],))
        other = torch.zeros_like(self.subgraph_indices)
        for i in range(self.edge_ids.shape[1]):
            mask = (self.subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(), other.float()).sum() / num
        return b_actions

    def get_soln_free_clustering(self, node_features):
        labels = []
        node_labels = []
        for i, sp_seg in enumerate(self.init_sp_seg):
            single_node_features = node_features[self.n_offs[i]:self.n_offs[i+1]]
            z_linkage = linkage(single_node_features.cpu(), 'ward')
            node_labels.append(fcluster(z_linkage, self.cfg.gen.n_max_object, criterion='maxclust'))
            rag = elf.segmentation.features.compute_rag(np.expand_dims(sp_seg.cpu(), axis=0))
            labels.append(elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels[-1]).squeeze())

        return torch.from_numpy(np.stack(labels).astype(np.float64)).to(node_features.device), \
               torch.from_numpy(np.concatenate(node_labels).astype(np.float64)).to(node_features.device)

    def get_soln_graph_clustering(self, node_features):
        labels = []
        node_labels = []
        for i, sp_seg in enumerate(self.init_sp_seg):
            single_node_features = node_features[self.n_offs[i]:self.n_offs[i+1]].detach().cpu().numpy()
            rag = nifty.graph.undirectedGraph(single_node_features.shape[0])
            rag.insertEdges((self.edge_ids[:, self.e_offs[i]:self.e_offs[i+1]] - self.n_offs[i]).T.detach().cpu().numpy())

            edge_weights = np.ones(rag.numberOfEdges, dtype=np.int64)
            edge_sizes = np.ones(rag.numberOfEdges, dtype=np.int64)
            node_sizes = np.ones(rag.numberOfNodes, dtype=np.int64)

            policy = self.cluster_policy(
                graph=rag,
                edgeIndicators=edge_weights,
                edgeSizes=edge_sizes,
                nodeFeatures=single_node_features,
                nodeSizes=node_sizes,
                numberOfNodesStop=self.cfg.gen.n_max_object,
                beta=1,
                sizeRegularizer=0
            )
            clustering = nagglo.agglomerativeClustering(policy)
            clustering.run()

            node_labels.append(clustering.result())
            rag = elf.segmentation.features.compute_rag(np.expand_dims(sp_seg.cpu(), axis=0))
            labels.append(elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels[-1]).squeeze())
        return torch.from_numpy(np.stack(labels).astype(np.float64)).to(node_features.device), \
               torch.from_numpy(np.concatenate(node_labels).astype(np.float64)).to(node_features.device)

    def get_mc_soln(self, edge_weights):
        p_min = 0.001
        p_max = 1.
        segmentations = []
        for i in range(1, len(self.e_offs)):
            probs = edge_weights[self.e_offs[i-1]:self.e_offs[i]]
            edges = self.edge_ids[:, self.e_offs[i-1]:self.e_offs[i]] - self.n_offs[i-1]
            costs = (p_max - p_min) * probs + p_min
            # probabilities to costs
            costs = (torch.log((1. - costs) / costs)).detach().cpu().numpy()
            graph = nifty.graph.undirectedGraph(self.n_nodes[i-1])
            graph.insertEdges(edges.T.cpu().numpy())

            node_labels = elf.segmentation.multicut.multicut_kernighan_lin(graph, costs)

            mc_seg = torch.zeros_like(self.init_sp_seg[i-1])
            for j, lbl in enumerate(node_labels):
                mc_seg += (self.init_sp_seg[i-1] == j).float() * lbl

            segmentations.append(mc_seg)
        return torch.stack(segmentations, dim=0)

    def get_node_gt(self):
        b_node_seg = torch.zeros(self.n_offs[-1], device=self.gt_seg.device)
        for i, (sp_seg, gt) in enumerate(zip(self.init_sp_seg, self.gt_seg)):
            for node_it in range(self.n_nodes[i]):
                nums = torch.bincount(((sp_seg == node_it).long() * (gt.long() + 1)).view(-1))
                b_node_seg[node_it + self.n_offs[i]] = nums[1:].argmax() - 1
        return b_node_seg

    def vis_node_actions(self, shifts, sb=0):
        plt.clf()
        fig = plt.figure()
        shifts = shifts[self.n_offs[sb]:self.n_offs[sb+1]]
        n = shifts.shape[0]
        proj = pca_project_1d(shifts, 8)
        proj = np.concatenate((proj[:2], proj[2:4], proj[4:6], proj[6:8]), 1)
        colors = n*["b"] + n*["g"] + n*["r"] + n*["c"]
        com = np.round(self.sup_com[self.n_offs[sb]:self.n_offs[sb+1]].cpu())
        com = np.concatenate((com, )*4, 0)
        plt.imshow((self.gt_seg[sb]*(self.init_sp_seg_edge[sb, 0] == 0) + self.init_sp_seg_edge[sb, 0] * 10).cpu())
        plt.quiver(com[:, 1], com[:, 0], proj[0], proj[1], color=colors, width=0.005)
        return fig

    def reset(self):
        self.done = False
        self.acc_reward = []
        self.counter = 0
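
In execute_action above, the clustering result is turned into edge weights by checking, for every edge, whether its two endpoint nodes received the same cluster label. A small stand-alone sketch of just that step, with made-up labels and a toy edge list:

import torch

# Made-up clustering result for 5 nodes and a toy (2, n_edges) edge list, as in edge_ids.
node_labeling = torch.tensor([0., 0., 1., 1., 2.])
edge_ids = torch.tensor([[0, 1, 2, 3],
                         [1, 2, 3, 4]])

ne = node_labeling[edge_ids]             # endpoint labels, shape (2, n_edges)
edge_weights = (ne[0] == ne[1]).float()  # 1.0 where both endpoints share a cluster label
print(edge_weights)                      # tensor([1., 0., 1., 0.])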
Example #6
class SpGcnEnv(Environment):
    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
        super(SpGcnEnv, self).__init__()

        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False

        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif self.args.reward_function == 'defining_rules':
            self.reward_function = HoughCircles(env=self, range_num=[8, 10],
                                                range_rad=[max(self.args.data_shape) // 18,
                                                           max(self.args.data_shape) // 15], min_hough_confidence=0.7)
        elif self.args.reward_function == 'defining_rules_lg':
            self.reward_function = HoughCircles_lg(env=self, range_num=[8, 10],
                                                range_rad=[max(self.args.data_shape) // 18,
                                                           max(self.args.data_shape) // 15], min_hough_confidence=0.7)
        # elif self.args.reward_function == 'focal':
        #     self.reward_function = FocalReward(env=self)
        # elif self.args.reward_function == 'global_sparse':
        #     self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None, post_stats=False):
        last_diff = (self.sg_current_edge_weights - self.sg_gt_edge_weights).squeeze().abs()

        self.b_current_edge_weights = actions.clone()
        self.sg_current_edge_weights = actions[self.b_subgraph_indices].view(-1, self.args.s_subgraph)

        reward = self.reward_function.get(actions, self.get_current_soln())

        quality = (self.sg_current_edge_weights - self.sg_gt_edge_weights).squeeze().abs().sum().item()
        self.counter += 1
        if self.counter >= self.args.max_episode_length:
            self.done = True

        total_reward = torch.sum(reward[0]).item()

        if self.writer is not None and post_stats:
            self.writer.add_scalar("step/quality", quality, self.writer_counter.value())
            self.writer.add_scalar("step/avg_return_1", reward[0].mean(), self.writer_counter.value())
            self.writer.add_scalar("step/avg_return_2", reward[1].mean(), self.writer_counter.value())
            if self.writer_counter.value() % 80 == 0:
                self.writer.add_histogram("step/pred_mean", self.sg_current_edge_weights.view(-1).cpu().numpy(), self.writer_counter.value() // 80)
            self.writer.add_scalar("step/gt_mean", self.sg_gt_edge_weights.mean(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_std", self.sg_gt_edge_weights.std(), self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward = total_reward
        return self.get_state(), reward, quality

    # def get_state(self):
    #     return self.raw, self.b_edge_ids, self.sp_indices, self.b_edge_angles, self.b_subgraph_indices, self.sep_subgraphs, self.counter, self.b_gt_edge_weights

    def get_state(self):
        return self.raw, self.b_edge_ids, self.sp_indices, self.b_edge_angles, self.b_subgraph_indices, self.sep_subgraphs, self.e_offs, self.counter, self.b_gt_edge_weights

    def update_data(self, b_edge_ids, edge_features, diff_to_gt, gt_edge_weights, node_labeling, raw, angles, gt):
        self.gt_seg = gt
        self.raw = raw
        self.init_sp_seg = node_labeling.squeeze()

        self.n_nodes = [edge_ids.max() + 1 for edge_ids in b_edge_ids]
        b_subgraphs, sep_subgraphs = find_dense_subgraphs([edge_ids.transpose(0, 1).cpu().numpy() for edge_ids in b_edge_ids], self.args.s_subgraph)
        self.sep_subgraphs = torch.from_numpy(sep_subgraphs.astype(np.int64))
        self.b_edge_ids, (self.n_offs, self.e_offs) = collate_edges(b_edge_ids)
        b_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(self.device).view(-1, 2).transpose(0, 1) + self.n_offs[i] for i, sg in enumerate(b_subgraphs)]
        self.sg_offs = [0]
        for i in range(len(b_subgraphs)):
            self.sg_offs.append(self.sg_offs[-1] + b_subgraphs[i].shape[-1])
        self.b_subgraphs = torch.cat(b_subgraphs, 1)

        self.b_subgraph_indices = get_edge_indices(self.b_edge_ids, self.b_subgraphs)
        self.b_gt_edge_weights = torch.cat(gt_edge_weights, 0)
        self.sg_gt_edge_weights = self.b_gt_edge_weights[self.b_subgraph_indices].view(-1, self.args.s_subgraph)
        self.b_edge_angles = angles
        self.sg_current_edge_weights = torch.ones_like(self.sg_gt_edge_weights) / 2

        self.b_initial_edge_weights = torch.cat([edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.b_current_edge_weights = torch.ones_like(self.b_initial_edge_weights) / 2

        stacked_superpixels = [[node_labeling[i] == n for n in range(n_node)] for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[sp.nonzero().cpu() for sp in stacked_superpixel] for stacked_superpixel in stacked_superpixels]

        # self.b_penalize_diff_thresh = diff_to_gt * 4
        # plt.imshow(self.get_current_soln_pic(1));plt.show()
        # return

    def get_batched_actions_from_global_graph(self, actions):
        b_actions = torch.zeros(size=(self.b_edge_ids.shape[1],))
        other = torch.zeros_like(self.b_subgraph_indices)
        for i in range(self.b_edge_ids.shape[1]):
            mask = (self.b_subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(), other.float()).sum() / num
        return b_actions

    def get_current_soln(self):
        p_min = 0.001
        p_max = 1.
        segmentations = []
        for i in range(1, len(self.e_offs)):
            probs = self.b_current_edge_weights[self.e_offs[i-1]:self.e_offs[i]]
            # probs = self.b_gt_edge_weights[self.e_offs[i-1]:self.e_offs[i]]
            edges = self.b_edge_ids[:, self.e_offs[i-1]:self.e_offs[i]] - self.n_offs[i-1]
            costs = (p_max - p_min) * probs + p_min
            # probabilities to costs
            costs = (torch.log((1. - costs) / costs)).detach().cpu().numpy()
            graph = nifty.graph.undirectedGraph(self.n_nodes[i-1])
            graph.insertEdges(edges.T.cpu().numpy())

            node_labels = elf.segmentation.multicut.multicut_kernighan_lin(graph, costs)

            mc_seg = torch.zeros_like(self.init_sp_seg[i-1])
            for j, lbl in enumerate(node_labels):
                mc_seg += (self.init_sp_seg[i-1] == j).float() * lbl

            segmentations.append(mc_seg)
        return torch.stack(segmentations, dim=0)
        # return torch.rand_like(self.init_sp_seg)

    def get_current_soln_pic(self, b):
        b_actions = self.get_batched_actions_from_global_graph(self.sg_current_edge_weights.view(-1))
        b_gt = self.get_batched_actions_from_global_graph(self.sg_gt_edge_weights.view(-1))

        edge_ids = self.b_edge_ids[:, self.e_offs[b]: self.e_offs[b+1]] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()
        boundary_input = self.b_initial_edge_weights[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy()
        mc_seg1 = general.multicut_from_probas(self.init_sp_seg[b].squeeze().cpu(), edge_ids,
                                               self.b_initial_edge_weights[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(self.init_sp_seg[b].squeeze().cpu(), edge_ids,
                                              b_actions[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy(), boundary_input)
        gt_mc_seg = general.multicut_from_probas(self.init_sp_seg[b].squeeze().cpu(), edge_ids,
                                                 b_gt[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy(), boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg[b].squeeze().cpu() / self.init_sp_seg[b].cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        return np.concatenate((np.concatenate((mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def reset(self):
        self.done = False
        self.win = False
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
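
The probability-to-cost conversion inside get_current_soln above (and the get_mc_soln / get_current_soln variants of the other examples) is a logit transform of the rescaled edge weights. A minimal, self-contained sketch of only that step, with made-up edge probabilities:

import torch

# Made-up edge probabilities; in the environment these are the current edge weights.
p_min, p_max = 0.001, 1.0
probs = torch.tensor([0.05, 0.5, 0.95])

# Affine rescaling away from 0, then log((1 - p) / p): probabilities near 0 map to large
# positive costs, probabilities near 1 to large negative costs. Under the usual multicut
# convention (minimize the summed cost of cut edges), positive costs discourage cutting
# an edge and negative costs encourage it.
costs = (p_max - p_min) * probs + p_min
costs = torch.log((1. - costs) / costs)
print(costs)  # roughly tensor([ 2.9,  0.0, -2.9])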
Example #7
class SpGcnEnv(Environment):
    def __init__(self, cfg, device, writer=None, writer_counter=None):
        super(SpGcnEnv, self).__init__()

        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)

        if self.cfg.sac.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.cfg.sac.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif self.cfg.sac.reward_function == 'defining_rules_edge_based':
            self.reward_function = HoughCircles(
                env=self,
                range_num=[8, 10],
                range_rad=[
                    max(self.cfg.sac.data_shape) // 18,
                    max(self.cfg.sac.data_shape) // 15
                ],
                min_hough_confidence=0.7)
        elif self.cfg.sac.reward_function == 'defining_rules_sp_based':
            self.reward_function = HoughCirclesOnSp(
                env=self,
                range_num=[8, 10],
                range_rad=[
                    max(self.cfg.sac.data_shape) // 18,
                    max(self.cfg.sac.data_shape) // 15
                ],
                min_hough_confidence=0.7)
        elif self.cfg.sac.reward_function == 'defining_rules_lg':
            assert False
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None, post_stats=False):
        # last_diff = (self.sg_current_edge_weights - self.sg_gt_edge_weights).squeeze().abs()
        self.current_edge_weights = actions

        self.sg_current_edge_weights = []
        for i, sz in enumerate(self.cfg.sac.s_subgraph):
            self.sg_current_edge_weights.append(
                self.current_edge_weights[self.subgraph_indices[i].view(
                    -1, sz)])

        self.current_soln = self.get_current_soln(self.current_edge_weights)
        reward = self.reward_function.get(
            self.sg_current_edge_weights,
            self.sg_gt_edge_weights)  #self.current_soln)
        # reward = self.reward_function.get(actions, self.get_current_soln(self.gt_edge_weights))
        # reward = self.reward_function.get(actions=self.sg_current_edge_weights)

        self.counter += 1
        if self.counter >= self.cfg.trainer.max_episode_length:
            self.done = True

        total_reward = 0
        for _rew in reward:
            total_reward += _rew.mean().item()
        total_reward /= len(self.cfg.sac.s_subgraph)

        if self.writer is not None and post_stats:
            self.writer.add_scalar("step/avg_return", total_reward,
                                   self.writer_counter.value())
            if self.writer_counter.value() % 10 == 0:
                self.writer.add_histogram(
                    "step/pred_mean",
                    self.current_edge_weights.view(-1).cpu().numpy(),
                    self.writer_counter.value() // 10)
                fig, (a1, a2, a3, a4) = plt.subplots(1,
                                                     4,
                                                     sharex='col',
                                                     sharey='row',
                                                     gridspec_kw={
                                                         'hspace': 0,
                                                         'wspace': 0
                                                     })
                a1.imshow(self.raw[0].cpu().permute(1, 2, 0).squeeze(),
                          cmap='hot')
                a1.set_title('raw image')
                a2.imshow(
                    cm.prism(self.init_sp_seg[0].cpu() /
                             self.init_sp_seg[0].max().item()))
                a2.set_title('superpixels')
                a3.imshow(
                    cm.prism(self.gt_soln[0].cpu() /
                             self.gt_soln[0].max().item()))
                a3.set_title('gt')
                a4.imshow(
                    cm.prism(self.current_soln[0].cpu() /
                             self.current_soln[0].max().item()))
                a4.set_title('prediction')
                self.writer.add_figure("image/state", fig,
                                       self.writer_counter.value() // 10)
            self.writer.add_scalar("step/gt_mean",
                                   self.gt_edge_weights.mean().item(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/gt_std",
                                   self.gt_edge_weights.std().item(),
                                   self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val,
                                           self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward.append(total_reward)
        return self.get_state(), reward

    def get_state(self):
        return torch.cat([self.raw, self.init_sp_seg_edge], 1), self.init_sp_seg, self.edge_ids, self.sp_indices, \
               self.edge_angles, self.subgraph_indices, self.sep_subgraphs, self.counter, self.gt_edge_weights, self.e_offs

    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights,
                    sp_seg, raw, gt):
        bs = len(edge_ids)
        dev = edge_ids[0].device
        subgraphs, self.sep_subgraphs = [], []
        self.gt_seg = gt.squeeze(1)
        self.raw = raw
        self.init_sp_seg = sp_seg.squeeze(1)
        self.edge_angles = torch.cat([
            get_angles_in_rag(edges, sp)
            for edges, sp in zip(edge_ids, self.init_sp_seg)
        ]).unsqueeze(-1)
        self.init_sp_seg_edge = torch.cat(
            [(-self.max_p(-sp_seg) != sp_seg).float(),
             (self.max_p(sp_seg) != sp_seg).float()], 1)

        _subgraphs, _sep_subgraphs = find_dense_subgraphs(
            [eids.transpose(0, 1).cpu().numpy() for eids in edge_ids],
            self.cfg.sac.s_subgraph)
        _subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).transpose(-3, -1)
            for sg in _subgraphs
        ]
        _sep_subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).transpose(-3, -1)
            for sg in _sep_subgraphs
        ]

        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        for i in range(len(self.cfg.sac.s_subgraph)):
            subgraphs.append(
                torch.cat([
                    sg + self.n_offs[i]
                    for i, sg in enumerate(_subgraphs[i * bs:(i + 1) * bs])
                ], -1).flatten(-2, -1))
            self.sep_subgraphs.append(
                torch.cat(_sep_subgraphs[i * bs:(i + 1) * bs],
                          -1).flatten(-2, -1))

        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)
        self.gt_edge_weights = torch.cat(gt_edge_weights, 0)
        self.sg_gt_edge_weights = [
            self.gt_edge_weights[sg].view(-1, sz)
            for sz, sg in zip(self.cfg.sac.s_subgraph, self.subgraph_indices)
        ]

        self.sg_current_edge_weights = [
            torch.ones_like(sg) / 2 for sg in self.sg_gt_edge_weights
        ]

        self.initial_edge_weights = torch.cat(
            [edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.current_edge_weights = self.initial_edge_weights.clone()
        self.gt_soln = self.get_current_soln(self.gt_edge_weights)

        stacked_superpixels = [[sp_seg[i] == n for n in range(n_node)]
                               for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[
            torch.nonzero(sp, as_tuple=False) for sp in stacked_superpixel
        ] for stacked_superpixel in stacked_superpixels]

        # cs = self.get_current_soln(self.b_gt_edge_weights)
        # fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        # ax1.imshow(cm.prism(self.gt_seg[0].detach().cpu().numpy() / self.gt_seg[0].detach().cpu().numpy().max()));
        # ax1.set_title('gt')
        # ax2.imshow(cm.prism(self.init_sp_seg[0].detach().cpu().numpy() / self.init_sp_seg[0].detach().cpu().numpy().max()));
        # ax2.set_title('sp')
        # ax3.imshow(cm.prism(cs[0].detach().cpu().numpy() / cs[0].detach().cpu().numpy().max()));
        # ax3.set_title('mc')
        # plt.show()
        # a=1

        # self.b_penalize_diff_thresh = diff_to_gt * 4
        # plt.imshow(self.get_current_soln_pic(1));plt.show()
        return

    def get_batched_actions_from_global_graph(self, actions):
        b_actions = torch.zeros(size=(self.edge_ids.shape[1], ))
        other = torch.zeros_like(self.subgraph_indices)
        for i in range(self.edge_ids.shape[1]):
            mask = (self.subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(),
                                       other.float()).sum() / num
        return b_actions

    def get_current_soln(self, edge_weights):
        p_min = 0.001
        p_max = 1.
        segmentations = []
        for i in range(1, len(self.e_offs)):
            probs = edge_weights[self.e_offs[i - 1]:self.e_offs[i]]
            edges = self.edge_ids[:, self.e_offs[i - 1]:self.
                                  e_offs[i]] - self.n_offs[i - 1]
            costs = (p_max - p_min) * probs + p_min
            # probabilities to costs
            costs = (torch.log((1. - costs) / costs)).detach().cpu().numpy()
            graph = nifty.graph.undirectedGraph(self.n_nodes[i - 1])
            graph.insertEdges(edges.T.cpu().numpy())

            node_labels = elf.segmentation.multicut.multicut_kernighan_lin(
                graph, costs)

            mc_seg = torch.zeros_like(self.init_sp_seg[i - 1])
            for j, lbl in enumerate(node_labels):
                mc_seg += (self.init_sp_seg[i - 1] == j).float() * lbl

            segmentations.append(mc_seg)
        return torch.stack(segmentations, dim=0)
        # return torch.rand_like(self.init_sp_seg)

    def get_current_soln_pic(self, b):
        actions = self.get_batched_actions_from_global_graph(
            self.sg_current_edge_weights.view(-1))
        gt = self.get_batched_actions_from_global_graph(
            self.sg_gt_edge_weights.view(-1))

        edge_ids = self.edge_ids[:, self.e_offs[b]:self.
                                 e_offs[b + 1]] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()
        boundary_input = self.initial_edge_weights[self.e_offs[b]:self.
                                                   e_offs[b +
                                                          1]].cpu().numpy()
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            self.initial_edge_weights[self.e_offs[b]:self.e_offs[b + 1]].cpu(
            ).numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            actions[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            gt[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg[b].squeeze().cpu() /
                       self.init_sp_seg[b].cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        return np.concatenate((np.concatenate(
            (mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def reset(self):
        self.done = False
        self.acc_reward = []
        self.counter = 0
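
The init_sp_seg_edge tensor built in update_data marks superpixel boundaries by comparing each pixel's label against the minimum and maximum of its 3x3 neighborhood, where the min pool is expressed as -MaxPool2d(-x). A tiny stand-alone illustration on a toy label image (values are made up):

import torch

max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)

# Toy 4x4 superpixel labeling with four 2x2 superpixels, shaped (batch, channel, H, W).
sp_seg = torch.tensor([[0., 0., 1., 1.],
                       [0., 0., 1., 1.],
                       [2., 2., 3., 3.],
                       [2., 2., 3., 3.]])[None, None]

# A pixel is marked if its 3x3 neighborhood contains a smaller label (min pool differs)
# or a larger label (max pool differs); update_data keeps both maps as two channels.
inner = (-max_p(-sp_seg) != sp_seg).float()
outer = (max_p(sp_seg) != sp_seg).float()
sp_seg_edge = torch.cat([inner, outer], 1)   # shape (1, 2, 4, 4)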
Example #8
class SpGcnEnv(Environment):
    def __init__(self,
                 args,
                 device,
                 writer=None,
                 writer_counter=None,
                 win_event_counter=None):
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0

        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = False

        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        elif self.args.reward_function == 'global_sparse':
            self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None):
        last_diff = (self.current_edge_weights -
                     self.gt_edge_weights).squeeze().abs()

        self.current_edge_weights = actions.clone()

        reward = self.reward_function.get(last_diff, actions,
                                          self.get_current_soln()).to(
                                              self.device)

        quality = (self.current_edge_weights -
                   self.gt_edge_weights).squeeze().abs().sum().item()
        self.counter += 1
        if self.counter >= self.args.max_episode_length:
            if quality < self.stop_quality:
                # reward += 2
                self.win = True
            else:
                a = 1
                # reward -= 1

            self.done = True
            self.win_event_counter.increment()

        total_reward = torch.sum(reward).item()

        if self.writer is not None and self.done:
            self.writer.add_scalar("step/quality", quality,
                                   self.writer_counter.value())
            self.writer.add_scalar("step/stop_quality", self.stop_quality,
                                   self.writer_counter.value())
            self.writer.add_scalar("step/n_wins",
                                   self.win_event_counter.value(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/steps_needed", self.counter,
                                   self.writer_counter.value())
            self.writer.add_scalar("step/win_loose_ratio",
                                   (self.win_event_counter.value() + 1) /
                                   (self.writer_counter.value() + 1),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/pred_mean",
                                   self.current_edge_weights.mean(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/pred_std",
                                   self.current_edge_weights.std(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/gt_mean", self.gt_edge_weights.mean(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/gt_std", self.gt_edge_weights.std(),
                                   self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val,
                                           self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward = total_reward
        state_pixels = torch.stack(
            [self.raw, self.init_sp_seg,
             self.get_current_soln()], dim=0)
        return (state_pixels, self.edge_ids, self.sp_indices, self.edge_angles,
                self.counter), reward, quality

    def get_state(self):
        state_pixels = torch.stack(
            [self.raw, self.init_sp_seg,
             self.get_current_soln()], dim=0)
        return state_pixels, self.edge_ids, self.sp_indices, self.edge_angles, self.counter

    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights,
                    node_labeling, raw, nodes, angles, affinities, gt):
        self.gt_seg = gt
        self.affinities = affinities
        self.initial_edge_weights = edge_features[:, 0]
        self.edge_features = edge_features
        self.stacked_superpixels = [node_labeling == n for n in nodes]
        self.sp_indices = [sp.nonzero() for sp in self.stacked_superpixels]
        self.raw = raw
        self.penalize_diff_thresh = diff_to_gt * 4
        self.init_sp_seg = node_labeling.squeeze()
        self.edge_ids = edge_ids
        self.gt_edge_weights = gt_edge_weights
        self.edge_angles = angles
        self.current_edge_weights = torch.ones_like(gt_edge_weights) / 2

    def show_current_soln(self):
        affs = np.expand_dims(self.affinities, axis=1)
        boundary_input = np.mean(affs, axis=0)
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.initial_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.gt_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg.cpu() / self.init_sp_seg.cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        plt.imshow(
            np.concatenate((np.concatenate(
                (mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)),
                           1))
        plt.show()
        a = 1
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def get_current_soln(self):
        # affs = np.expand_dims(self.affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # mc_seg = general.multicut_from_probas(self.init_sp_seg.squeeze().cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
        #                                       self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
        # return torch.from_numpy(mc_seg.astype(np.float32))
        return torch.rand_like(self.init_sp_seg.squeeze())

    def get_rag_and_edge_feats(self, reward, edges):
        edge_indices = []
        seg = self.init_sp_seg.clone()
        for edge in self.edge_ids.t():
            n1, n2 = self.sp_indices[edge[0]], self.sp_indices[edge[1]]
            dis = torch.cdist(n1.float(), n2.float())
            dis = (dis <= 1).nonzero()
            inds_n1 = n1[dis[:, 0].unique()]
            inds_n2 = n2[dis[:, 1].unique()]
            edge_indices.append(torch.cat((inds_n1, inds_n2), 0))
        for indices in edge_indices:
            seg[indices[:, 0], indices[:, 1]] = 600
        seg = cm.prism(seg.cpu().numpy() / seg.cpu().numpy().max())
        plt.imshow(seg)
        plt.show()

    def reset(self):
        self.done = False
        self.win = False
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
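
get_rag_and_edge_feats above locates the pixels along the boundary between two superpixels by computing all pairwise distances between their pixel coordinates and keeping pairs at distance <= 1. A small sketch of that test with toy coordinates (not from the source):

import torch

n1 = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])   # pixel coordinates of superpixel 1
n2 = torch.tensor([[0, 2], [1, 2], [2, 2]])           # pixel coordinates of superpixel 2

dis = torch.cdist(n1.float(), n2.float())             # pairwise Euclidean distances
touching = (dis <= 1).nonzero()                       # index pairs of directly adjacent pixels
inds_n1 = n1[touching[:, 0].unique()]                 # boundary pixels on the n1 side
inds_n2 = n2[touching[:, 1].unique()]                 # boundary pixels on the n2 side
print(inds_n1.tolist(), inds_n2.tolist())             # [[0, 1], [1, 1]] [[0, 2], [1, 2]]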
Example #9
class SpGcnEnv(Environment):
    def __init__(self,
                 args,
                 device,
                 writer=None,
                 writer_counter=None,
                 win_event_counter=None):
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0

        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = False

        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        elif self.args.reward_function == 'global_sparse':
            self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None):
        last_diff = (self.b_current_edge_weights -
                     self.b_gt_edge_weights).squeeze().abs()

        self.b_current_edge_weights = actions.clone()

        reward = self.reward_function.get(last_diff, actions,
                                          self.get_current_soln()).to(
                                              self.device)

        quality = (self.b_current_edge_weights -
                   self.b_gt_edge_weights).squeeze().abs().sum().item()
        self.counter += 1
        if self.counter >= self.args.max_episode_length:
            if quality < self.stop_quality:
                # reward += 2
                self.win = True
            else:
                a = 1
                # reward -= 1

            self.done = True
            self.win_event_counter.increment()

        total_reward = torch.sum(reward).item()

        if self.writer is not None and self.done:
            self.writer.add_scalar("step/quality", quality,
                                   self.writer_counter.value())
            self.writer.add_scalar("step/stop_quality", self.stop_quality,
                                   self.writer_counter.value())
            self.writer.add_scalar("step/n_wins",
                                   self.win_event_counter.value(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/steps_needed", self.counter,
                                   self.writer_counter.value())
            self.writer.add_scalar("step/win_loose_ratio",
                                   (self.win_event_counter.value() + 1) /
                                   (self.writer_counter.value() + 1),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/pred_mean",
                                   self.b_current_edge_weights.mean(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/pred_std",
                                   self.b_current_edge_weights.std(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/gt_mean",
                                   self.b_gt_edge_weights.mean(),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/gt_std", self.b_gt_edge_weights.std(),
                                   self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val,
                                           self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward = total_reward
        return self.get_state(), reward, quality

    def get_state(self):
        state_pixels = torch.cat(
            [self.raw, self.init_sp_seg,
             self.get_current_soln()], dim=1)
        return state_pixels, self.b_edge_ids, self.sp_indices, self.b_edge_angles, self.counter, self.b_gt_edge_weights

    def update_data(self, b_edge_ids, edge_features, diff_to_gt,
                    gt_edge_weights, node_labeling, raw, angles, gt):
        self.gt_seg = gt
        self.raw = raw
        self.init_sp_seg = node_labeling

        self.n_nodes = [edge_ids.max() + 1 for edge_ids in b_edge_ids]
        # b_subgraphs = find_dense_subgraphs([edge_ids.transpose(0, 1).cpu().numpy() for edge_ids in b_edge_ids], self.args.s_subgraph)
        self.b_edge_ids, (self.n_offs, self.e_offs) = collate_edges(b_edge_ids)
        # b_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(self.device).view(-1, 2).transpose(0, 1) + self.n_offs[i] for i, sg in enumerate(b_subgraphs)]
        # self.sg_offs = [0]
        # for i in range(len(b_subgraphs)):
        #     self.sg_offs.append(self.sg_offs[-1] + b_subgraphs[i].shape[0])
        # self.b_subgraphs = torch.cat(b_subgraphs, 1)
        #
        # self.b_subgraph_indices = get_edge_indices(self.b_edge_ids, self.b_subgraphs)
        self.b_gt_edge_weights = torch.cat(gt_edge_weights, 0)
        # self.sg_gt_edge_weights = self.b_gt_edge_weights[self.b_subgraph_indices].view(-1, self.args.s_subgraph)
        self.b_edge_angles = angles
        self.b_current_edge_weights = torch.ones_like(
            self.b_gt_edge_weights) / 2
        # self.sg_current_edge_weights = torch.ones_like(self.sg_gt_edge_weights) / 2

        self.b_initial_edge_weights = torch.cat(
            [edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.b_edge_features = torch.cat(edge_features, dim=0)

        stacked_superpixels = [[node_labeling[i] == n for n in range(n_node)]
                               for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[sp.nonzero().cpu() for sp in stacked_superpixel]
                           for stacked_superpixel in stacked_superpixels]

        # self.b_penalize_diff_thresh = diff_to_gt * 4
        # plt.imshow(self.get_current_soln_pic(1));plt.show()
        # return

    def get_batched_actions_from_global_graph(self, actions):
        b_actions = torch.zeros(size=(self.b_edge_ids.shape[1], ))
        other = torch.zeros_like(self.b_subgraph_indices)
        for i in range(self.b_edge_ids.shape[1]):
            mask = (self.b_subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(),
                                       other.float()).sum() / num
        return b_actions

    def get_current_soln(self):
        # affs = np.expand_dims(self.affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # mc_seg = general.multicut_from_probas(self.init_sp_seg.squeeze().cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
        #                                       self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
        # return torch.from_numpy(mc_seg.astype(np.float32))
        return torch.rand_like(self.init_sp_seg)

    def get_current_soln_pic(self, b):
        b_actions = self.get_batched_actions_from_global_graph(
            self.sg_current_edge_weights.view(-1))
        b_gt = self.get_batched_actions_from_global_graph(
            self.sg_gt_edge_weights.view(-1))

        edge_ids = self.b_edge_ids[:, self.e_offs[b]:self.
                                   e_offs[b + 1]] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()
        boundary_input = self.b_initial_edge_weights[self.e_offs[b]:self.
                                                     e_offs[b +
                                                            1]].cpu().numpy()
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            self.b_initial_edge_weights[self.e_offs[b]:self.e_offs[b + 1]].cpu(
            ).numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            b_actions[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            b_gt[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg[b].squeeze().cpu() /
                       self.init_sp_seg[b].cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        return np.concatenate((np.concatenate(
            (mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def reset(self):
        self.done = False
        self.win = False
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
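
get_batched_actions_from_global_graph averages, for every global edge, the action values of all sub-graph slots that map to it; the loop above does this one edge at a time. The following sketch (toy tensors, not from the source) expresses the same computation in vectorized form and assumes every edge occurs in at least one sub-graph:

import torch

def average_subgraph_actions(actions, subgraph_indices, n_edges):
    # Sum the action values falling on each edge, then divide by how often the edge occurs.
    sums = torch.zeros(n_edges).index_add_(0, subgraph_indices.view(-1), actions.view(-1).float())
    counts = torch.zeros(n_edges).index_add_(0, subgraph_indices.view(-1),
                                             torch.ones_like(actions.view(-1), dtype=torch.float))
    return sums / counts

actions = torch.tensor([0.2, 0.8, 0.4, 0.6])   # one value per sub-graph edge slot
subgraph_indices = torch.tensor([0, 1, 1, 2])  # global edge index of each slot
print(average_subgraph_actions(actions, subgraph_indices, n_edges=3))
# tensor([0.2000, 0.6000, 0.6000])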
Example #10
class SpGcnEnv(Environment):
    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None, discrete_action_space=True):
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0

        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = discrete_action_space

        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions):
        last_diff = (self.state[0] - self.gt_edge_weights).squeeze().abs()
        if self.discrete_action_space:
            mask = (actions == 2).float() * (self.state[0] + self.args.action_agression)
            mask += (actions == 1).float() * (self.state[0] - self.args.action_agression)
            mask += (actions == 0).float() * self.state[0]
            self.state[0] = mask + 1e-10  # prevent the reinforcement loss from becoming too large
            self.state[0] = self.state[0].clamp(min=0, max=1)
        else:
            self.state[0] = actions.clone()

        reward = self.reward_function.get(last_diff, actions, self.get_current_soln()).to(self.device)

        # self.get_rag_and_edge_feats(reward, self.state[0])

        self.data_changed = torch.sum(torch.abs(self.state[0] - self.edge_features[:, 0])).cpu().item()
        penalize_change = 0
        quality = (self.state[0] - self.gt_edge_weights).squeeze().abs().sum().item()
        if self.counter > self.args.max_episode_length:
            # penalize_change = (self.penalize_diff_thresh - self.data_changed) / np.prod(self.state.size()) * 10

            if quality < self.stop_quality:
                reward += 2
                self.win = True
            else:
                reward -= 1

            self.done = True
            self.iteration += 1
            self.win_event_counter.increment()

        reward += (penalize_change * (actions != 0).float())

        total_reward = torch.sum(reward).item()
        self.counter += 1

        if self.writer is not None and self.done:
            self.writer.add_scalar("step/quality", quality, self.writer_counter.value())
            self.writer.add_scalar("step/stop_quality", self.stop_quality, self.writer_counter.value())
            self.writer.add_scalar("step/n_wins", self.win_event_counter.value(), self.writer_counter.value())
            self.writer.add_scalar("step/steps_needed", self.counter, self.writer_counter.value())
            self.writer.add_scalar("step/win_loose_ratio", (self.win_event_counter.value()+1) /
                                   (self.writer_counter.value()+1), self.writer_counter.value())
            self.writer.add_scalar("step/pred_mean", self.state[0].mean(), self.writer_counter.value())
            self.writer.add_scalar("step/pred_std", self.state[0].std(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_mean", self.gt_edge_weights.mean(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_std", self.gt_edge_weights.std(), self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward = total_reward
        self.state[1] = self.get_current_soln()
        return [self.state[0].clone(), self.state[1].clone()], reward, quality

    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles,
                    affinities, gt):
        self.gt_seg = gt
        self.affinities = affinities
        self.initial_edge_weights = edge_features[:, 0]
        self.edge_features = edge_features
        self.stacked_superpixels = [node_labeling == n for n in nodes]
        self.sp_indices = [sp.nonzero() for sp in self.stacked_superpixels]
        self.raw = raw
        self.penalize_diff_thresh = diff_to_gt * 4
        self.init_sp_seg = node_labeling.squeeze()
        self.edge_ids = edge_ids
        self.gt_edge_weights = gt_edge_weights
        self.edge_angles = angles
        self.state = [torch.ones_like(gt_edge_weights) / 2, None]
        self.state = [torch.ones_like(gt_edge_weights) / 2, self.get_current_soln()]

    def show_current_soln(self):
        affs = np.expand_dims(self.affinities, axis=1)
        boundary_input = np.mean(affs, axis=0)
        mc_seg1 = general.multicut_from_probas(self.init_sp_seg.cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
                                               self.initial_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(self.init_sp_seg.cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
                                              self.state[0].squeeze().cpu().numpy(), boundary_input)
        gt_mc_seg = general.multicut_from_probas(self.init_sp_seg.cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
                                                 self.gt_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg.cpu() / self.init_sp_seg.cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        plt.imshow(np.concatenate((np.concatenate((mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1));
        plt.show()
        a=1
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def get_current_soln(self):
        affs = np.expand_dims(self.affinities, axis=1)
        boundary_input = np.mean(affs, axis=0)
        mc_seg = general.multicut_from_probas(self.init_sp_seg.squeeze().cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
                                              self.state[0].squeeze().cpu().numpy(), boundary_input)
        return torch.from_numpy(mc_seg.astype(np.float64))

    def get_rag_and_edge_feats(self, reward, edges):
        edge_indices = []
        seg = self.init_sp_seg.clone()
        for edge in self.edge_ids.t():
            n1, n2 = self.sp_indices[edge[0]], self.sp_indices[edge[1]]
            dis = torch.cdist(n1.float(), n2.float())
            dis = (dis <= 1).nonzero()
            inds_n1 = n1[dis[:, 0].unique()]
            inds_n2 = n2[dis[:, 1].unique()]
            edge_indices.append(torch.cat((inds_n1, inds_n2), 0))
        for indices in edge_indices:
            seg[indices[:, 0], indices[:, 1]] = 600
        seg = cm.prism(seg.cpu().numpy() / seg.cpu().numpy().max())
        plt.imshow(seg)
        plt.show()
        a=1

    def reset(self):
        self.done = False
        self.win = False
        self.iteration = 0
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
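
Example #10 is the only variant with a discrete action space: action 0 keeps an edge weight, action 1 lowers it and action 2 raises it by a fixed step, after which the weights are clamped back into [0, 1]. A minimal stand-alone sketch of that update with made-up values (action_agression stands in for args.action_agression):

import torch

action_agression = 0.1                                  # hypothetical step size
edge_weights = torch.tensor([0.50, 0.50, 0.95, 0.05])   # current edge weights
actions = torch.tensor([0, 2, 2, 1])                    # 0: keep, 1: decrease, 2: increase

new_weights = (actions == 2).float() * (edge_weights + action_agression)
new_weights += (actions == 1).float() * (edge_weights - action_agression)
new_weights += (actions == 0).float() * edge_weights
new_weights = (new_weights + 1e-10).clamp(min=0, max=1)
print(new_weights)  # tensor([0.5000, 0.6000, 1.0000, 0.0000])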