def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
    """Initialize the environment and select a reward function from ``args``.

    Args:
        args: configuration namespace; ``args.reward_function`` picks the reward.
        device: torch device the environment should use.
        writer: optional tensorboard-style summary writer.
        writer_counter: optional shared counter for writer steps.
        win_event_counter: optional shared counter of winning episodes.
    """
    super(SpGcnEnv, self).__init__()
    self.stop_quality = 0
    self.reset()
    self.args = args
    self.device = device
    self.writer = writer
    self.writer_counter = writer_counter
    self.win_event_counter = win_event_counter
    self.discrete_action_space = False
    # Lazy factories: only the selected reward object is ever constructed,
    # exactly as in the original if/elif chain.
    reward_factories = {
        'fully_supervised': lambda: FullySupervisedReward(env=self),
        'object_level': lambda: ObjectLevelReward(env=self),
        'graph_dice': lambda: GraphDiceReward(env=self),
        'focal': lambda: FocalReward(env=self),
        'global_sparse': lambda: GlobalSparseReward(env=self),
    }
    make_reward = reward_factories.get(self.args.reward_function,
                                       lambda: UnSupervisedReward(env=self))
    self.reward_function = make_reward()
def __init__(self, embedding_net, cfg, device, writer=None, writer_counter=None):
    """Initialize the node-embedding environment.

    Args:
        embedding_net: network producing pixel/superpixel embeddings.
        cfg: configuration object (``cfg.trn`` / ``cfg.fe`` sections are read).
        device: torch device the environment should use.
        writer: optional tensorboard-style summary writer.
        writer_counter: optional shared counter for writer steps.
    """
    super(EmbeddingSpaceEnvNodeBased, self).__init__()
    self.embedding_net = embedding_net
    self.reset()
    self.cfg = cfg
    self.device = device
    self.writer = writer
    self.writer_counter = writer_counter
    self.last_final_reward = torch.tensor([0.0])
    # 3x3 max pool, stride 1, padding 1 (used elsewhere for edge detection on label maps).
    self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
    self.step_encoder = TemporalSineEncoding(max_step=cfg.trn.max_episode_length,
                                             size=cfg.fe.n_embedding_features)
    # Only one branch is evaluated, matching the original if/else.
    use_dice = self.cfg.trn.reward_function == 'sub_graph_dice'
    self.reward_function = SubGraphDiceReward() if use_dice else UnSupervisedReward(env=self)
    self.cluster_policy = nagglo.cosineDistNodeAndEdgeWeightedClusterPolicy
def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
    """Initialize the environment and select a reward function from ``args``.

    Args:
        args: configuration namespace; ``args.reward_function`` picks the reward
            and ``args.data_shape`` sizes the Hough-circle radius search range.
        device: torch device the environment should use.
        writer: optional tensorboard-style summary writer.
        writer_counter: optional shared counter for writer steps.
        win_event_counter: optional shared counter of winning episodes.
    """
    super(SpGcnEnv, self).__init__()
    self.reset()
    self.args = args
    self.device = device
    self.writer = writer
    self.writer_counter = writer_counter
    self.win_event_counter = win_event_counter
    self.discrete_action_space = False
    # Lazy factories keep construction (and the args.data_shape lookup)
    # inside the selected branch, exactly like the original if/elif chain.
    reward_factories = {
        'fully_supervised': lambda: FullySupervisedReward(env=self),
        'sub_graph_dice': lambda: SubGraphDiceReward(env=self),
        'defining_rules': lambda: HoughCircles(
            env=self, range_num=[8, 10],
            range_rad=[max(self.args.data_shape) // 18, max(self.args.data_shape) // 15],
            min_hough_confidence=0.7),
        'defining_rules_lg': lambda: HoughCircles_lg(
            env=self, range_num=[8, 10],
            range_rad=[max(self.args.data_shape) // 18, max(self.args.data_shape) // 15],
            min_hough_confidence=0.7),
    }
    make_reward = reward_factories.get(self.args.reward_function,
                                       lambda: UnSupervisedReward(env=self))
    self.reward_function = make_reward()
def __init__(self, cfg, device, writer=None, writer_counter=None):
    """Initialize the environment and select a reward function from ``cfg.sac``.

    Args:
        cfg: configuration object; ``cfg.sac.reward_function`` picks the reward
            and ``cfg.sac.data_shape`` sizes the Hough-circle radius range.
        device: torch device the environment should use.
        writer: optional tensorboard-style summary writer.
        writer_counter: optional shared counter for writer steps.
    """
    super(SpGcnEnv, self).__init__()
    self.reset()
    self.cfg = cfg
    self.device = device
    self.writer = writer
    self.writer_counter = writer_counter
    self.discrete_action_space = False
    # 3x3 max pool, stride 1, padding 1 (used elsewhere for edge detection on label maps).
    self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
    selected = self.cfg.sac.reward_function
    if selected == 'fully_supervised':
        self.reward_function = FullySupervisedReward(env=self)
    elif selected == 'sub_graph_dice':
        self.reward_function = SubGraphDiceReward(env=self)
    elif selected in ('defining_rules_edge_based', 'defining_rules_sp_based'):
        # Both Hough variants share the same construction arguments.
        hough_cls = HoughCircles if selected == 'defining_rules_edge_based' else HoughCirclesOnSp
        longest_side = max(self.cfg.sac.data_shape)
        self.reward_function = hough_cls(
            env=self,
            range_num=[8, 10],
            range_rad=[longest_side // 18, longest_side // 15],
            min_hough_confidence=0.7)
    elif selected == 'defining_rules_lg':
        # Unsupported for this environment (matches original behavior).
        assert False
    else:
        self.reward_function = UnSupervisedReward(env=self)
class EmbeddingSpaceEnvNodeBased():
    """RL-style environment whose actions shift superpixel node embeddings.

    Each step adds an action vector to the per-node embeddings, re-clusters the
    superpixel graph, and scores the implied edge weights against ground truth.

    Fix applied: the removed NumPy aliases ``np.float`` / ``np.int`` (deprecated
    in NumPy 1.20, removed in 1.24) were replaced by ``np.float64`` / ``np.int64``,
    which are what the aliases resolved to.
    """

    # Tuple bundling everything an agent needs as observation.
    State = collections.namedtuple("State", ["node_embeddings", "edge_ids", "edge_angles",
                                             "sup_masses", "subgraph_indices", "sep_subgraphs",
                                             "subgraphs", "gt_edge_weights"])

    def __init__(self, embedding_net, cfg, device, writer=None, writer_counter=None):
        """Store configuration, build the step encoder and pick the reward function."""
        super(EmbeddingSpaceEnvNodeBased, self).__init__()
        self.embedding_net = embedding_net
        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.last_final_reward = torch.tensor([0.0])
        # 3x3 max pool used to extract superpixel boundary masks in update_data.
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
        self.step_encoder = TemporalSineEncoding(max_step=cfg.trn.max_episode_length,
                                                 size=cfg.fe.n_embedding_features)
        if self.cfg.trn.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward()
        else:
            self.reward_function = UnSupervisedReward(env=self)
        self.cluster_policy = nagglo.cosineDistNodeAndEdgeWeightedClusterPolicy
        # self.cluster_policy = nagglo.nodeAndEdgeWeightedClusterPolicy

    def execute_action(self, actions, logg_vals=None, post_stats=False):
        """Apply an action (embedding shift), recluster, and compute rewards.

        Returns:
            (state, reward): new observation tuple and per-subgraph reward list.
        """
        self.current_node_embeddings += actions
        # normalize to unit length (eps avoids division by zero)
        self.current_node_embeddings /= (torch.norm(self.current_node_embeddings,
                                                    dim=-1, keepdim=True) + 1e-10)
        self.current_soln, node_labeling = self.get_soln_graph_clustering(self.current_node_embeddings)
        # Edge weight = 1 where both endpoints fall in the same cluster.
        sg_edge_weights = []
        for i, sz in enumerate(self.cfg.trn.s_subgraph):
            sg_ne = node_labeling[self.subgraphs[i].view(2, -1, sz)]
            sg_edge_weights.append((sg_ne[0] == sg_ne[1]).float())
        reward = self.reward_function.get(sg_edge_weights, self.sg_gt_edges)  # self.current_soln)
        reward.append(self.last_final_reward)
        self.counter += 1
        if self.counter >= self.cfg.trn.max_episode_length:
            # Episode end: compute a global reward used as bonus in the next episode.
            self.done = True
            ne = node_labeling[self.edge_ids]
            edge_weights = ((ne[0] == ne[1]).float())
            self.last_final_reward = self.reward_function.get_global(edge_weights, self.gt_edge_weights)
        total_reward = 0
        for _rew in reward:
            total_reward += _rew.mean().item()
        total_reward /= len(self.cfg.trn.s_subgraph)
        if self.writer is not None and post_stats:
            self.writer.add_scalar("step/avg_return", total_reward, self.writer_counter.value())
            if self.writer_counter.value() % 20 == 0:
                # Periodically log a qualitative figure of gt / raw / sp / predictions.
                fig, (a0, a1, a2, a3, a4) = plt.subplots(1, 5, sharex='col', sharey='row',
                                                         gridspec_kw={'hspace': 0, 'wspace': 0})
                a0.imshow(self.gt_seg[0].cpu().squeeze())
                a0.set_title('gt')
                a1.imshow(self.raw[0].cpu().permute(1, 2, 0).squeeze())
                a1.set_title('raw image')
                a2.imshow(cm.prism(self.init_sp_seg[0].cpu() / self.init_sp_seg[0].max().item()))
                a2.set_title('superpixels')
                a3.imshow(cm.prism(self.gt_soln[0].cpu() / self.gt_soln[0].max().item()))
                a3.set_title('gt')
                a4.imshow(cm.prism(self.current_soln[0].cpu() / self.current_soln[0].max().item()))
                a4.set_title('prediction')
                self.writer.add_figure("image/state", fig, self.writer_counter.value() // 10)
                # self.writer.add_figure("image/shift_proj", self.vis_node_actions(actions.cpu(), 0), self.writer_counter.value() // 10)
                self.embedding_net.post_pca(get_angles(self.embeddings)[0].cpu(),
                                            tag="image/pix_embedding_proj")
                self.embedding_net.post_pca(
                    get_angles(self.current_node_embeddings[:self.n_offs[1]]
                               [self.init_sp_seg[0].long()].permute(2, 0, 1)[None])[0].cpu(),
                    tag="image/node_embedding_proj")
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()
        self.acc_reward.append(total_reward)
        return self.get_state(), reward

    def get_state(self):
        """Return the current observation; embeddings carry a sine step encoding."""
        temp_code = self.step_encoder(self.counter, self.current_node_embeddings.device)[None]
        state = self.current_node_embeddings + temp_code
        return self.State(state, self.edge_ids, self.edge_angles, self.sup_masses,
                          self.subgraph_indices, self.sep_subgraphs, self.subgraphs,
                          self.gt_edge_weights)

    def update_data(self, edge_ids, gt_edges, sp_seg, raw, gt, **kwargs):
        """Load a new batch: collate edges, build dense subgraphs and gt targets."""
        bs = len(edge_ids)
        dev = edge_ids[0].device
        subgraphs, self.sep_subgraphs = [], []
        self.gt_seg = gt.squeeze(1)
        self.raw = raw
        self.init_sp_seg = sp_seg.squeeze(1)
        edge_angles, sup_masses, sup_com = zip(*[get_angles_smass_in_rag(edges, sp)
                                                 for edges, sp in zip(edge_ids, self.init_sp_seg)])
        self.edge_angles, self.sup_masses, self.sup_com = \
            torch.cat(edge_angles).unsqueeze(-1), torch.cat(sup_masses).unsqueeze(-1), torch.cat(sup_com)
        # Two-channel boundary mask of the superpixel segmentation (min/max-pool edges).
        self.init_sp_seg_edge = torch.cat([(-self.max_p(-sp_seg) != sp_seg).float(),
                                           (self.max_p(sp_seg) != sp_seg).float()], 1)
        _subgraphs, _sep_subgraphs = find_dense_subgraphs(
            [eids.transpose(0, 1).cpu().numpy() for eids in edge_ids], self.cfg.trn.s_subgraph)
        _subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1) for sg in _subgraphs]
        _sep_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(dev).permute(2, 0, 1)
                          for sg in _sep_subgraphs]
        self.n_nodes = [eids.max() + 1 for eids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        self.dir_edge_ids = torch.cat([self.edge_ids,
                                       torch.stack([self.edge_ids[1], self.edge_ids[0]], dim=0)], dim=1)
        for i in range(len(self.cfg.trn.s_subgraph)):
            # NOTE: the inner enumerate index intentionally shadows the outer `i`;
            # it enumerates the bs samples of this subgraph size, matching n_offs.
            subgraphs.append(torch.cat([sg + self.n_offs[i] for i, sg in
                                        enumerate(_subgraphs[i*bs:(i+1)*bs])], -2).flatten(-2, -1))
            self.sep_subgraphs.append(torch.cat(_sep_subgraphs[i*bs:(i+1)*bs], -2).flatten(-2, -1))
        self.subgraphs = subgraphs
        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)
        self.gt_edge_weights = torch.cat(gt_edges)
        self.gt_soln = self.get_mc_soln(self.gt_edge_weights)
        self.sg_gt_edges = [self.gt_edge_weights[sg].view(-1, sz)
                            for sz, sg in zip(self.cfg.trn.s_subgraph, self.subgraph_indices)]
        self.embeddings = self.embedding_net(self.raw).detach()
        # get embedding agglomeration over each superpixel
        self.current_node_embeddings = torch.cat(
            [self.embedding_net.get_mean_sp_embedding(embed, sp)
             for embed, sp in zip(self.embeddings, self.init_sp_seg)], dim=0)
        return

    def get_batched_actions_from_global_graph(self, actions):
        """Average subgraph-edge actions back onto each global edge."""
        b_actions = torch.zeros(size=(self.edge_ids.shape[1],))
        other = torch.zeros_like(self.subgraph_indices)
        for i in range(self.edge_ids.shape[1]):
            mask = (self.subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(), other.float()).sum() / num
        return b_actions

    def get_soln_free_clustering(self, node_features):
        """Cluster node features hierarchically (Ward) and project labels to pixels."""
        labels = []
        node_labels = []
        for i, sp_seg in enumerate(self.init_sp_seg):
            single_node_features = node_features[self.n_offs[i]:self.n_offs[i+1]]
            z_linkage = linkage(single_node_features.cpu(), 'ward')
            node_labels.append(fcluster(z_linkage, self.cfg.gen.n_max_object, criterion='maxclust'))
            rag = elf.segmentation.features.compute_rag(np.expand_dims(sp_seg.cpu(), axis=0))
            labels.append(elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels[-1]).squeeze())
        # np.float64 replaces the removed alias np.float (same dtype).
        return torch.from_numpy(np.stack(labels).astype(np.float64)).to(node_features.device), \
               torch.from_numpy(np.concatenate(node_labels).astype(np.float64)).to(node_features.device)

    def get_soln_graph_clustering(self, node_features):
        """Agglomeratively cluster the RAG down to n_max_object nodes; return pixel and node labels."""
        labels = []
        node_labels = []
        for i, sp_seg in enumerate(self.init_sp_seg):
            single_node_features = node_features[self.n_offs[i]:self.n_offs[i+1]].detach().cpu().numpy()
            rag = nifty.graph.undirectedGraph(single_node_features.shape[0])
            rag.insertEdges((self.edge_ids[:, self.e_offs[i]:self.e_offs[i+1]]
                             - self.n_offs[i]).T.detach().cpu().numpy())
            # Uniform weights/sizes: clustering is driven by node features only.
            # np.int64 replaces the removed alias np.int (same dtype on Linux).
            edge_weights = np.ones(rag.numberOfEdges, dtype=np.int64)
            edge_sizes = np.ones(rag.numberOfEdges, dtype=np.int64)
            node_sizes = np.ones(rag.numberOfNodes, dtype=np.int64)
            policy = self.cluster_policy(graph=rag, edgeIndicators=edge_weights,
                                         edgeSizes=edge_sizes, nodeFeatures=single_node_features,
                                         nodeSizes=node_sizes,
                                         numberOfNodesStop=self.cfg.gen.n_max_object,
                                         beta=1, sizeRegularizer=0)
            clustering = nagglo.agglomerativeClustering(policy)
            clustering.run()
            node_labels.append(clustering.result())
            rag = elf.segmentation.features.compute_rag(np.expand_dims(sp_seg.cpu(), axis=0))
            labels.append(elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels[-1]).squeeze())
        return torch.from_numpy(np.stack(labels).astype(np.float64)).to(node_features.device), \
               torch.from_numpy(np.concatenate(node_labels).astype(np.float64)).to(node_features.device)

    def get_mc_soln(self, edge_weights):
        """Solve a multicut per batch item from edge probabilities; return stacked label maps."""
        p_min = 0.001
        p_max = 1.
        segmentations = []
        for i in range(1, len(self.e_offs)):
            probs = edge_weights[self.e_offs[i-1]:self.e_offs[i]]
            edges = self.edge_ids[:, self.e_offs[i-1]:self.e_offs[i]] - self.n_offs[i-1]
            costs = (p_max - p_min) * probs + p_min
            # probabilities to costs
            costs = (torch.log((1. - costs) / costs)).detach().cpu().numpy()
            graph = nifty.graph.undirectedGraph(self.n_nodes[i-1])
            graph.insertEdges(edges.T.cpu().numpy())
            node_labels = elf.segmentation.multicut.multicut_kernighan_lin(graph, costs)
            mc_seg = torch.zeros_like(self.init_sp_seg[i-1])
            for j, lbl in enumerate(node_labels):
                mc_seg += (self.init_sp_seg[i-1] == j).float() * lbl
            segmentations.append(mc_seg)
        return torch.stack(segmentations, dim=0)

    def get_node_gt(self):
        """Majority-vote a ground-truth label for every superpixel node."""
        b_node_seg = torch.zeros(self.n_offs[-1], device=self.gt_seg.device)
        for i, (sp_seg, gt) in enumerate(zip(self.init_sp_seg, self.gt_seg)):
            for node_it in range(self.n_nodes[i]):
                # bincount of (gt+1) restricted to this superpixel; slot 0 is background.
                nums = torch.bincount(((sp_seg == node_it).long() * (gt.long() + 1)).view(-1))
                b_node_seg[node_it + self.n_offs[i]] = nums[1:].argmax() - 1
        return b_node_seg

    def vis_node_actions(self, shifts, sb=0):
        """Quiver plot of PCA-projected action shifts at superpixel centers (batch item sb)."""
        plt.clf()
        fig = plt.figure()
        shifts = shifts[self.n_offs[sb]:self.n_offs[sb+1]]
        n = shifts.shape[0]
        proj = pca_project_1d(shifts, 8)
        proj = np.concatenate((proj[:2], proj[2:4], proj[4:6], proj[6:8]), 1)
        colors = n*["b"] + n*["g"] + n*["r"] + n*["c"]
        com = np.round(self.sup_com[self.n_offs[sb]:self.n_offs[sb+1]].cpu())
        com = np.concatenate((com, )*4, 0)
        plt.imshow((self.gt_seg[sb]*(self.init_sp_seg_edge[sb, 0] == 0)
                    + self.init_sp_seg_edge[sb, 0] * 10).cpu())
        plt.quiver(com[:, 1], com[:, 0], proj[0], proj[1], color=colors, width=0.005)
        return fig

    def reset(self):
        """Start a fresh episode: clear done flag, rewards and step counter."""
        self.done = False
        self.acc_reward = []
        self.counter = 0
class SpGcnEnv(Environment):
    """Environment over a superpixel graph where actions are global edge weights;
    quality is measured as L1 distance between predicted and gt subgraph edges."""

    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
        """Store configuration and select the reward function from args.reward_function."""
        super(SpGcnEnv, self).__init__()
        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False
        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif self.args.reward_function == 'defining_rules':
            # Hough radius range derived from the longest side of the data shape.
            self.reward_function = HoughCircles(env=self, range_num=[8, 10], range_rad=[max(self.args.data_shape) // 18, max(self.args.data_shape) // 15], min_hough_confidence=0.7)
        elif self.args.reward_function == 'defining_rules_lg':
            self.reward_function = HoughCircles_lg(env=self, range_num=[8, 10], range_rad=[max(self.args.data_shape) // 18, max(self.args.data_shape) // 15], min_hough_confidence=0.7)
        # elif self.args.reward_function == 'focal':
        #     self.reward_function = FocalReward(env=self)
        # elif self.args.reward_function == 'global_sparse':
        #     self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None, post_stats=False):
        """Apply predicted edge weights, compute reward/quality and log stats.

        Returns:
            (state, reward, quality) — quality is the summed |pred - gt| over subgraph edges.
        """
        # last_diff is computed before overwriting predictions; currently unused below.
        last_diff = (self.sg_current_edge_weights - self.sg_gt_edge_weights).squeeze().abs()
        self.b_current_edge_weights = actions.clone()
        self.sg_current_edge_weights = actions[self.b_subgraph_indices].view(-1, self.args.s_subgraph)
        reward = self.reward_function.get(actions, self.get_current_soln())
        quality = (self.sg_current_edge_weights - self.sg_gt_edge_weights).squeeze().abs().sum().item()
        self.counter += 1
        if self.counter >= self.args.max_episode_length:
            self.done = True
        total_reward = torch.sum(reward[0]).item()
        if self.writer is not None and post_stats:
            self.writer.add_scalar("step/quality", quality, self.writer_counter.value())
            self.writer.add_scalar("step/avg_return_1", reward[0].mean(), self.writer_counter.value())
            self.writer.add_scalar("step/avg_return_2", reward[1].mean(), self.writer_counter.value())
            if self.writer_counter.value() % 80 == 0:
                # Less frequent, heavier logging: prediction histogram and gt stats.
                self.writer.add_histogram("step/pred_mean", self.sg_current_edge_weights.view(-1).cpu().numpy(), self.writer_counter.value() // 80)
                self.writer.add_scalar("step/gt_mean", self.sg_gt_edge_weights.mean(), self.writer_counter.value())
                self.writer.add_scalar("step/gt_std", self.sg_gt_edge_weights.std(), self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()
        self.acc_reward = total_reward
        return self.get_state(), reward, quality

    # def get_state(self):
    #     return self.raw, self.b_edge_ids, self.sp_indices, self.b_edge_angles, self.b_subgraph_indices, self.sep_subgraphs, self.counter, self.b_gt_edge_weights

    def get_state(self):
        """Return the observation tuple consumed by the agent."""
        return self.raw, self.b_edge_ids, self.sp_indices, self.b_edge_angles, self.b_subgraph_indices, self.sep_subgraphs, self.e_offs, self.counter, self.b_gt_edge_weights

    def update_data(self, b_edge_ids, edge_features, diff_to_gt, gt_edge_weights, node_labeling, raw, angles, gt):
        """Load a new batch: collate per-sample edges, build dense subgraphs and targets.

        NOTE(review): diff_to_gt is currently unused (see commented threshold below).
        """
        self.gt_seg = gt
        self.raw = raw
        self.init_sp_seg = node_labeling.squeeze()
        # Node count per sample = max node id + 1.
        self.n_nodes = [edge_ids.max() + 1 for edge_ids in b_edge_ids]
        b_subgraphs, sep_subgraphs = find_dense_subgraphs([edge_ids.transpose(0, 1).cpu().numpy() for edge_ids in b_edge_ids], self.args.s_subgraph)
        self.sep_subgraphs = torch.from_numpy(sep_subgraphs.astype(np.int64))
        self.b_edge_ids, (self.n_offs, self.e_offs) = collate_edges(b_edge_ids)
        # Shift every sample's subgraph node ids by its node offset into the collated graph.
        b_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(self.device).view(-1, 2).transpose(0, 1) + self.n_offs[i] for i, sg in enumerate(b_subgraphs)]
        self.sg_offs = [0]
        for i in range(len(b_subgraphs)):
            self.sg_offs.append(self.sg_offs[-1] + b_subgraphs[i].shape[-1])
        self.b_subgraphs = torch.cat(b_subgraphs, 1)
        self.b_subgraph_indices = get_edge_indices(self.b_edge_ids, self.b_subgraphs)
        self.b_gt_edge_weights = torch.cat(gt_edge_weights, 0)
        self.sg_gt_edge_weights = self.b_gt_edge_weights[self.b_subgraph_indices].view(-1, self.args.s_subgraph)
        self.b_edge_angles = angles
        # Initialize all predictions at 0.5 (maximum uncertainty).
        self.sg_current_edge_weights = torch.ones_like(self.sg_gt_edge_weights) / 2
        self.b_initial_edge_weights = torch.cat([edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.b_current_edge_weights = torch.ones_like(self.b_initial_edge_weights) / 2
        # Per-sample, per-node pixel index lists for superpixel lookup.
        stacked_superpixels = [[node_labeling[i] == n for n in range(n_node)] for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[sp.nonzero().cpu() for sp in stacked_superpixel] for stacked_superpixel in stacked_superpixels]
        # self.b_penalize_diff_thresh = diff_to_gt * 4
        # plt.imshow(self.get_current_soln_pic(1));plt.show()
        # return

    def get_batched_actions_from_global_graph(self, actions):
        """Average subgraph-edge actions back onto each global edge."""
        b_actions = torch.zeros(size=(self.b_edge_ids.shape[1],))
        other = torch.zeros_like(self.b_subgraph_indices)
        for i in range(self.b_edge_ids.shape[1]):
            mask = (self.b_subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(), other.float()).sum() / num
        return b_actions

    def get_current_soln(self):
        """Solve a multicut per batch item from the current edge weights."""
        p_min = 0.001
        p_max = 1.
        segmentations = []
        for i in range(1, len(self.e_offs)):
            probs = self.b_current_edge_weights[self.e_offs[i-1]:self.e_offs[i]]
            # probs = self.b_gt_edge_weights[self.e_offs[i-1]:self.e_offs[i]]
            edges = self.b_edge_ids[:, self.e_offs[i-1]:self.e_offs[i]] - self.n_offs[i-1]
            costs = (p_max - p_min) * probs + p_min
            # probabilities to costs
            costs = (torch.log((1. - costs) / costs)).detach().cpu().numpy()
            graph = nifty.graph.undirectedGraph(self.n_nodes[i-1])
            graph.insertEdges(edges.T.cpu().numpy())
            node_labels = elf.segmentation.multicut.multicut_kernighan_lin(graph, costs)
            # Paint each superpixel with its multicut cluster label.
            mc_seg = torch.zeros_like(self.init_sp_seg[i-1])
            for j, lbl in enumerate(node_labels):
                mc_seg += (self.init_sp_seg[i-1] == j).float() * lbl
            segmentations.append(mc_seg)
        return torch.stack(segmentations, dim=0)
        # return torch.rand_like(self.init_sp_seg)

    def get_current_soln_pic(self, b):
        """Return a 2x2 comparison image for batch item b: init mc / gt seg over curr mc / sp seg."""
        b_actions = self.get_batched_actions_from_global_graph(self.sg_current_edge_weights.view(-1))
        b_gt = self.get_batched_actions_from_global_graph(self.sg_gt_edge_weights.view(-1))
        edge_ids = self.b_edge_ids[:, self.e_offs[b]: self.e_offs[b+1]] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()
        boundary_input = self.b_initial_edge_weights[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy()
        mc_seg1 = general.multicut_from_probas(self.init_sp_seg[b].squeeze().cpu(), edge_ids, self.b_initial_edge_weights[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(self.init_sp_seg[b].squeeze().cpu(), edge_ids, b_actions[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy(), boundary_input)
        gt_mc_seg = general.multicut_from_probas(self.init_sp_seg[b].squeeze().cpu(), edge_ids, b_gt[self.e_offs[b]: self.e_offs[b+1]].cpu().numpy(), boundary_input)
        # Map label images to colors for display.
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg[b].squeeze().cpu() / self.init_sp_seg[b].cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        return np.concatenate((np.concatenate((mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def reset(self):
        """Start a fresh episode: clear flags, accumulated reward and step counter."""
        self.done = False
        self.win = False
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
class SpGcnEnv(Environment):
    """Environment over a superpixel graph (cfg.sac variant); actions are global edge
    weights, rewards are computed per subgraph size in cfg.sac.s_subgraph."""

    def __init__(self, cfg, device, writer=None, writer_counter=None):
        """Store configuration and select the reward function from cfg.sac.reward_function."""
        super(SpGcnEnv, self).__init__()
        self.reset()
        self.cfg = cfg
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.discrete_action_space = False
        # 3x3 max pool used for superpixel boundary extraction in update_data.
        self.max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
        if self.cfg.sac.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.cfg.sac.reward_function == 'sub_graph_dice':
            self.reward_function = SubGraphDiceReward(env=self)
        elif self.cfg.sac.reward_function == 'defining_rules_edge_based':
            # Hough radius range derived from the longest side of the data shape.
            self.reward_function = HoughCircles(
                env=self, range_num=[8, 10],
                range_rad=[
                    max(self.cfg.sac.data_shape) // 18,
                    max(self.cfg.sac.data_shape) // 15
                ], min_hough_confidence=0.7)
        elif self.cfg.sac.reward_function == 'defining_rules_sp_based':
            self.reward_function = HoughCirclesOnSp(
                env=self, range_num=[8, 10],
                range_rad=[
                    max(self.cfg.sac.data_shape) // 18,
                    max(self.cfg.sac.data_shape) // 15
                ], min_hough_confidence=0.7)
        elif self.cfg.sac.reward_function == 'defining_rules_lg':
            # Unsupported for this environment.
            assert False
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None, post_stats=False):
        """Apply predicted edge weights, compute subgraph rewards and log stats.

        Returns:
            (state, reward): new observation and list of per-subgraph-size rewards.
        """
        # last_diff = (self.sg_current_edge_weights - self.sg_gt_edge_weights).squeeze().abs()
        self.current_edge_weights = actions
        # Gather the global predictions into per-size subgraph views.
        self.sg_current_edge_weights = []
        for i, sz in enumerate(self.cfg.sac.s_subgraph):
            self.sg_current_edge_weights.append(
                self.current_edge_weights[self.subgraph_indices[i].view(-1, sz)])
        self.current_soln = self.get_current_soln(self.current_edge_weights)
        reward = self.reward_function.get(
            self.sg_current_edge_weights, self.sg_gt_edge_weights)  #self.current_soln)
        # reward = self.reward_function.get(actions, self.get_current_soln(self.gt_edge_weights))
        # reward = self.reward_function.get(actions=self.sg_current_edge_weights)
        self.counter += 1
        # NOTE(review): reads cfg.trainer here but cfg.sac elsewhere — confirm intended.
        if self.counter >= self.cfg.trainer.max_episode_length:
            self.done = True
        # Average the mean reward over all subgraph sizes.
        total_reward = 0
        for _rew in reward:
            total_reward += _rew.mean().item()
        total_reward /= len(self.cfg.sac.s_subgraph)
        if self.writer is not None and post_stats:
            self.writer.add_scalar("step/avg_return", total_reward, self.writer_counter.value())
            if self.writer_counter.value() % 10 == 0:
                # Periodic heavy logging: prediction histogram and qualitative figure.
                self.writer.add_histogram(
                    "step/pred_mean",
                    self.current_edge_weights.view(-1).cpu().numpy(),
                    self.writer_counter.value() // 10)
                fig, (a1, a2, a3, a4) = plt.subplots(1, 4, sharex='col', sharey='row',
                                                     gridspec_kw={'hspace': 0, 'wspace': 0})
                a1.imshow(self.raw[0].cpu().permute(1, 2, 0).squeeze(), cmap='hot')
                a1.set_title('raw image')
                a2.imshow(cm.prism(self.init_sp_seg[0].cpu() / self.init_sp_seg[0].max().item()))
                a2.set_title('superpixels')
                a3.imshow(cm.prism(self.gt_soln[0].cpu() / self.gt_soln[0].max().item()))
                a3.set_title('gt')
                a4.imshow(cm.prism(self.current_soln[0].cpu() / self.current_soln[0].max().item()))
                a4.set_title('prediction')
                self.writer.add_figure("image/state", fig, self.writer_counter.value() // 10)
                self.writer.add_scalar("step/gt_mean", self.gt_edge_weights.mean().item(),
                                       self.writer_counter.value())
                self.writer.add_scalar("step/gt_std", self.gt_edge_weights.std().item(),
                                       self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()
        self.acc_reward.append(total_reward)
        return self.get_state(), reward

    def get_state(self):
        """Return the observation tuple; raw image is concatenated with the sp boundary mask."""
        return torch.cat([self.raw, self.init_sp_seg_edge], 1), self.init_sp_seg, self.edge_ids, self.sp_indices, \
               self.edge_angles, self.subgraph_indices, self.sep_subgraphs, self.counter, self.gt_edge_weights, self.e_offs

    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights, sp_seg, raw, gt):
        """Load a new batch: collate per-sample edges, build dense subgraphs and targets.

        NOTE(review): diff_to_gt is currently unused (see commented threshold below).
        """
        bs = len(edge_ids)
        dev = edge_ids[0].device
        subgraphs, self.sep_subgraphs = [], []
        self.gt_seg = gt.squeeze(1)
        self.raw = raw
        self.init_sp_seg = sp_seg.squeeze(1)
        self.edge_angles = torch.cat([
            get_angles_in_rag(edges, sp) for edges, sp in zip(edge_ids, self.init_sp_seg)
        ]).unsqueeze(-1)
        # Two-channel boundary mask of the superpixel segmentation (min/max-pool edges).
        self.init_sp_seg_edge = torch.cat(
            [(-self.max_p(-sp_seg) != sp_seg).float(),
             (self.max_p(sp_seg) != sp_seg).float()], 1)
        # NOTE: the comprehension variable deliberately shadows the parameter name
        # `edge_ids` here; the outermost iterable is the original list.
        _subgraphs, _sep_subgraphs = find_dense_subgraphs(
            [edge_ids.transpose(0, 1).cpu().numpy() for edge_ids in edge_ids],
            self.cfg.sac.s_subgraph)
        _subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).transpose(-3, -1)
            for sg in _subgraphs
        ]
        _sep_subgraphs = [
            torch.from_numpy(sg.astype(np.int64)).to(dev).transpose(-3, -1)
            for sg in _sep_subgraphs
        ]
        self.n_nodes = [edge_ids.max() + 1 for edge_ids in edge_ids]
        self.edge_ids, (self.n_offs, self.e_offs) = collate_edges(edge_ids)
        for i in range(len(self.cfg.sac.s_subgraph)):
            # NOTE: inner enumerate index shadows the outer `i`; it enumerates the bs
            # samples of this subgraph size, matching the node offsets n_offs.
            subgraphs.append(
                torch.cat([
                    sg + self.n_offs[i]
                    for i, sg in enumerate(_subgraphs[i * bs:(i + 1) * bs])
                ], -1).flatten(-2, -1))
            self.sep_subgraphs.append(
                torch.cat(_sep_subgraphs[i * bs:(i + 1) * bs], -1).flatten(-2, -1))
        self.subgraph_indices = get_edge_indices(self.edge_ids, subgraphs)
        self.gt_edge_weights = torch.cat(gt_edge_weights, 0)
        self.sg_gt_edge_weights = [
            self.gt_edge_weights[sg].view(-1, sz)
            for sz, sg in zip(self.cfg.sac.s_subgraph, self.subgraph_indices)
        ]
        # Initialize all predictions at 0.5 (maximum uncertainty).
        self.sg_current_edge_weights = [
            torch.ones_like(sg) / 2 for sg in self.sg_gt_edge_weights
        ]
        self.initial_edge_weights = torch.cat(
            [edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.current_edge_weights = self.initial_edge_weights.clone()
        self.gt_soln = self.get_current_soln(self.gt_edge_weights)
        # Per-sample, per-node pixel index lists for superpixel lookup.
        stacked_superpixels = [[sp_seg[i] == n for n in range(n_node)]
                               for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[
            torch.nonzero(sp, as_tuple=False) for sp in stacked_superpixel
        ] for stacked_superpixel in stacked_superpixels]
        # cs = self.get_current_soln(self.b_gt_edge_weights)
        # fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        # ax1.imshow(cm.prism(self.gt_seg[0].detach().cpu().numpy() / self.gt_seg[0].detach().cpu().numpy().max()));
        # ax1.set_title('gt')
        # ax2.imshow(cm.prism(self.init_sp_seg[0].detach().cpu().numpy() / self.init_sp_seg[0].detach().cpu().numpy().max()));
        # ax2.set_title('sp')
        # ax3.imshow(cm.prism(cs[0].detach().cpu().numpy() / cs[0].detach().cpu().numpy().max()));
        # ax3.set_title('mc')
        # plt.show()
        # a=1
        # self.b_penalize_diff_thresh = diff_to_gt * 4
        # plt.imshow(self.get_current_soln_pic(1));plt.show()
        return

    def get_batched_actions_from_global_graph(self, actions):
        """Average subgraph-edge actions back onto each global edge."""
        b_actions = torch.zeros(size=(self.edge_ids.shape[1], ))
        other = torch.zeros_like(self.subgraph_indices)
        for i in range(self.edge_ids.shape[1]):
            mask = (self.subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(), other.float()).sum() / num
        return b_actions

    def get_current_soln(self, edge_weights):
        """Solve a multicut per batch item from the given edge probabilities."""
        p_min = 0.001
        p_max = 1.
        segmentations = []
        for i in range(1, len(self.e_offs)):
            probs = edge_weights[self.e_offs[i - 1]:self.e_offs[i]]
            edges = self.edge_ids[:, self.e_offs[i - 1]:self.e_offs[i]] - self.n_offs[i - 1]
            costs = (p_max - p_min) * probs + p_min
            # probabilities to costs
            costs = (torch.log((1. - costs) / costs)).detach().cpu().numpy()
            graph = nifty.graph.undirectedGraph(self.n_nodes[i - 1])
            graph.insertEdges(edges.T.cpu().numpy())
            node_labels = elf.segmentation.multicut.multicut_kernighan_lin(graph, costs)
            # Paint each superpixel with its multicut cluster label.
            mc_seg = torch.zeros_like(self.init_sp_seg[i - 1])
            for j, lbl in enumerate(node_labels):
                mc_seg += (self.init_sp_seg[i - 1] == j).float() * lbl
            segmentations.append(mc_seg)
        return torch.stack(segmentations, dim=0)
        # return torch.rand_like(self.init_sp_seg)

    def get_current_soln_pic(self, b):
        """Return a 2x2 comparison image for batch item b: init mc / gt seg over curr mc / sp seg."""
        actions = self.get_batched_actions_from_global_graph(self.sg_current_edge_weights.view(-1))
        gt = self.get_batched_actions_from_global_graph(self.sg_gt_edge_weights.view(-1))
        edge_ids = self.edge_ids[:, self.e_offs[b]:self.e_offs[b + 1]] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()
        boundary_input = self.initial_edge_weights[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy()
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            self.initial_edge_weights[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            actions[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            gt[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        # Map label images to colors for display.
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg[b].squeeze().cpu() / self.init_sp_seg[b].cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        return np.concatenate((np.concatenate(
            (mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def reset(self):
        """Start a fresh episode: clear done flag, reward list and step counter."""
        self.done = False
        self.acc_reward = []
        self.counter = 0
class SpGcnEnv(Environment):
    """RL environment over a superpixel graph (non-batched variant).

    The agent outputs per-edge merge probabilities; episode quality is the
    summed absolute difference to the ground-truth edge weights.
    """

    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
        """
        :param args: config namespace; fields used here: ``reward_function``,
            ``max_episode_length``.
        :param device: torch device rewards are moved to.
        :param writer: optional summary writer for per-episode scalars.
        :param writer_counter: shared counter used as the logging step.
        :param win_event_counter: shared counter of finished-episode events.
        """
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0  # quality threshold below which an episode counts as won
        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = False

        # choose the reward implementation from the config string
        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        elif self.args.reward_function == 'global_sparse':
            self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None):
        """Take *actions* as the new edge weights and advance one step.

        :param actions: new per-edge merge probabilities.
        :param logg_vals: optional dict of extra scalars logged on episode end.
        :return: (state tuple, reward tensor, quality float).
        """
        last_diff = (self.current_edge_weights - self.gt_edge_weights).squeeze().abs()
        self.current_edge_weights = actions.clone()
        reward = self.reward_function.get(last_diff, actions, self.get_current_soln()).to(self.device)
        quality = (self.current_edge_weights - self.gt_edge_weights).squeeze().abs().sum().item()
        self.counter += 1
        if self.counter >= self.args.max_episode_length:
            if quality < self.stop_quality:
                # reward += 2
                self.win = True
            else:
                a = 1  # no-op placeholder (kept: doc-only edit)
                # reward -= 1
            self.done = True
            # NOTE(review): increments on EVERY episode end, not only wins, so
            # "step/n_wins" counts episodes — confirm whether this is intended.
            self.win_event_counter.increment()
        total_reward = torch.sum(reward).item()

        # episode-end logging only
        if self.writer is not None and self.done:
            self.writer.add_scalar("step/quality", quality, self.writer_counter.value())
            self.writer.add_scalar("step/stop_quality", self.stop_quality, self.writer_counter.value())
            self.writer.add_scalar("step/n_wins", self.win_event_counter.value(), self.writer_counter.value())
            self.writer.add_scalar("step/steps_needed", self.counter, self.writer_counter.value())
            self.writer.add_scalar("step/win_loose_ratio",
                                   (self.win_event_counter.value() + 1) / (self.writer_counter.value() + 1),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/pred_mean", self.current_edge_weights.mean(), self.writer_counter.value())
            self.writer.add_scalar("step/pred_std", self.current_edge_weights.std(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_mean", self.gt_edge_weights.mean(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_std", self.gt_edge_weights.std(), self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward = total_reward
        state_pixels = torch.stack([self.raw, self.init_sp_seg, self.get_current_soln()], dim=0)
        return (state_pixels, self.edge_ids, self.sp_indices, self.edge_angles, self.counter), reward, quality

    def get_state(self):
        """Return the current observation (pixels + graph side information)."""
        state_pixels = torch.stack([self.raw, self.init_sp_seg, self.get_current_soln()], dim=0)
        return state_pixels, self.edge_ids, self.sp_indices, self.edge_angles, self.counter

    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights,
                    node_labeling, raw, nodes, angles, affinities, gt):
        """Load a new sample into the environment and reset the edge weights.

        Edge weights start at 0.5 (maximum uncertainty).
        """
        self.gt_seg = gt
        self.affinities = affinities
        self.initial_edge_weights = edge_features[:, 0]
        self.edge_features = edge_features
        # one boolean mask per superpixel, and its pixel coordinates
        self.stacked_superpixels = [node_labeling == n for n in nodes]
        self.sp_indices = [sp.nonzero() for sp in self.stacked_superpixels]
        self.raw = raw
        self.penalize_diff_thresh = diff_to_gt * 4
        self.init_sp_seg = node_labeling.squeeze()
        self.edge_ids = edge_ids
        self.gt_edge_weights = gt_edge_weights
        self.edge_angles = angles
        self.current_edge_weights = torch.ones_like(gt_edge_weights) / 2

    def show_current_soln(self):
        """Debug: plot initial-mc / current-mc / gt-mc / superpixel panels."""
        affs = np.expand_dims(self.affinities, axis=1)
        boundary_input = np.mean(affs, axis=0)
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.initial_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.gt_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg.cpu() / self.init_sp_seg.cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        plt.imshow(
            np.concatenate((np.concatenate(
                (mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1))
        plt.show()
        a = 1  # no-op placeholder (kept: doc-only edit)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def get_current_soln(self):
        """Return the current segmentation.

        NOTE(review): this is a stub — the real multicut call is commented out
        and a random image is returned instead; confirm before training.
        """
        # affs = np.expand_dims(self.affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # mc_seg = general.multicut_from_probas(self.init_sp_seg.squeeze().cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
        #                                       self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
        # return torch.from_numpy(mc_seg.astype(np.float32))
        return torch.rand_like(self.init_sp_seg.squeeze())

    def get_rag_and_edge_feats(self, reward, edges):
        """Debug: highlight pixels adjacent to each graph edge and plot them."""
        edge_indices = []
        seg = self.init_sp_seg.clone()
        for edge in self.edge_ids.t():
            n1, n2 = self.sp_indices[edge[0]], self.sp_indices[edge[1]]
            # pixel pairs from the two superpixels that touch (distance <= 1)
            dis = torch.cdist(n1.float(), n2.float())
            dis = (dis <= 1).nonzero()
            inds_n1 = n1[dis[:, 0].unique()]
            inds_n2 = n2[dis[:, 1].unique()]
            edge_indices.append(torch.cat((inds_n1, inds_n2), 0))
        for indices in edge_indices:
            seg[indices[:, 0], indices[:, 1]] = 600  # sentinel label for boundary pixels
        seg = cm.prism(seg.cpu().numpy() / seg.cpu().numpy().max())
        plt.imshow(seg)
        plt.show()

    def reset(self):
        """Reset per-episode bookkeeping."""
        self.done = False
        self.win = False
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
class SpGcnEnv(Environment):
    """RL environment over a superpixel graph (batched variant).

    Batch items are collated into one global graph; ``b_``-prefixed members
    hold the concatenated per-batch data, with ``n_offs``/``e_offs`` giving
    node/edge offsets per item.
    """

    def __init__(self, args, device, writer=None, writer_counter=None, win_event_counter=None):
        """
        :param args: config namespace; fields used here: ``reward_function``,
            ``max_episode_length``, ``s_subgraph`` (in commented code).
        :param device: torch device rewards are moved to.
        :param writer: optional summary writer for per-episode scalars.
        :param writer_counter: shared counter used as the logging step.
        :param win_event_counter: shared counter of finished-episode events.
        """
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0  # quality threshold below which an episode counts as won
        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = False

        # choose the reward implementation from the config string
        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        elif self.args.reward_function == 'global_sparse':
            self.reward_function = GlobalSparseReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions, logg_vals=None):
        """Take *actions* as the new batched edge weights and advance one step.

        :return: (state from :meth:`get_state`, reward tensor, quality float).
        """
        last_diff = (self.b_current_edge_weights - self.b_gt_edge_weights).squeeze().abs()
        self.b_current_edge_weights = actions.clone()
        reward = self.reward_function.get(last_diff, actions, self.get_current_soln()).to(self.device)
        quality = (self.b_current_edge_weights - self.b_gt_edge_weights).squeeze().abs().sum().item()
        self.counter += 1
        if self.counter >= self.args.max_episode_length:
            if quality < self.stop_quality:
                # reward += 2
                self.win = True
            else:
                a = 1  # no-op placeholder (kept: doc-only edit)
                # reward -= 1
            self.done = True
            # NOTE(review): increments on EVERY episode end, not only wins, so
            # "step/n_wins" counts episodes — confirm whether this is intended.
            self.win_event_counter.increment()
        total_reward = torch.sum(reward).item()

        # episode-end logging only
        if self.writer is not None and self.done:
            self.writer.add_scalar("step/quality", quality, self.writer_counter.value())
            self.writer.add_scalar("step/stop_quality", self.stop_quality, self.writer_counter.value())
            self.writer.add_scalar("step/n_wins", self.win_event_counter.value(), self.writer_counter.value())
            self.writer.add_scalar("step/steps_needed", self.counter, self.writer_counter.value())
            self.writer.add_scalar("step/win_loose_ratio",
                                   (self.win_event_counter.value() + 1) / (self.writer_counter.value() + 1),
                                   self.writer_counter.value())
            self.writer.add_scalar("step/pred_mean", self.b_current_edge_weights.mean(), self.writer_counter.value())
            self.writer.add_scalar("step/pred_std", self.b_current_edge_weights.std(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_mean", self.b_gt_edge_weights.mean(), self.writer_counter.value())
            self.writer.add_scalar("step/gt_std", self.b_gt_edge_weights.std(), self.writer_counter.value())
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    self.writer.add_scalar("step/" + key, val, self.writer_counter.value())
            self.writer_counter.increment()

        self.acc_reward = total_reward
        return self.get_state(), reward, quality

    def get_state(self):
        """Return the batched observation (channels concatenated on dim 1)."""
        state_pixels = torch.cat([self.raw, self.init_sp_seg, self.get_current_soln()], dim=1)
        return state_pixels, self.b_edge_ids, self.sp_indices, self.b_edge_angles, self.counter, self.b_gt_edge_weights

    def update_data(self, b_edge_ids, edge_features, diff_to_gt, gt_edge_weights,
                    node_labeling, raw, angles, gt):
        """Collate a new batch of graphs into the environment.

        Edge weights start at 0.5 (maximum uncertainty).
        NOTE(review): ``diff_to_gt`` is currently unused (its use is commented
        out below) — confirm whether it can be dropped from the signature.
        """
        self.gt_seg = gt
        self.raw = raw
        self.init_sp_seg = node_labeling
        # number of nodes per batch item, derived from the max edge endpoint
        self.n_nodes = [edge_ids.max() + 1 for edge_ids in b_edge_ids]
        # b_subgraphs = find_dense_subgraphs([edge_ids.transpose(0, 1).cpu().numpy() for edge_ids in b_edge_ids], self.args.s_subgraph)
        self.b_edge_ids, (self.n_offs, self.e_offs) = collate_edges(b_edge_ids)
        # b_subgraphs = [torch.from_numpy(sg.astype(np.int64)).to(self.device).view(-1, 2).transpose(0, 1) + self.n_offs[i] for i, sg in enumerate(b_subgraphs)]
        # self.sg_offs = [0]
        # for i in range(len(b_subgraphs)):
        #     self.sg_offs.append(self.sg_offs[-1] + b_subgraphs[i].shape[0])
        # self.b_subgraphs = torch.cat(b_subgraphs, 1)
        #
        # self.b_subgraph_indices = get_edge_indices(self.b_edge_ids, self.b_subgraphs)
        self.b_gt_edge_weights = torch.cat(gt_edge_weights, 0)
        # self.sg_gt_edge_weights = self.b_gt_edge_weights[self.b_subgraph_indices].view(-1, self.args.s_subgraph)
        self.b_edge_angles = angles
        self.b_current_edge_weights = torch.ones_like(self.b_gt_edge_weights) / 2
        # self.sg_current_edge_weights = torch.ones_like(self.sg_gt_edge_weights) / 2
        self.b_initial_edge_weights = torch.cat([edge_fe[:, 0] for edge_fe in edge_features], dim=0)
        self.b_edge_features = torch.cat(edge_features, dim=0)
        # per item: one boolean mask per superpixel, and its pixel coordinates
        stacked_superpixels = [[node_labeling[i] == n for n in range(n_node)]
                               for i, n_node in enumerate(self.n_nodes)]
        self.sp_indices = [[sp.nonzero().cpu() for sp in stacked_superpixel]
                           for stacked_superpixel in stacked_superpixels]
        # self.b_penalize_diff_thresh = diff_to_gt * 4
        # plt.imshow(self.get_current_soln_pic(1));plt.show()
        # return

    def get_batched_actions_from_global_graph(self, actions):
        """Average subgraph-level actions back onto the global edge set.

        NOTE(review): relies on ``self.b_subgraph_indices`` which is only set
        by commented-out code in :meth:`update_data` — verify before use.
        """
        b_actions = torch.zeros(size=(self.b_edge_ids.shape[1], ))
        other = torch.zeros_like(self.b_subgraph_indices)
        for i in range(self.b_edge_ids.shape[1]):
            mask = (self.b_subgraph_indices == i)
            num = mask.float().sum()
            b_actions[i] = torch.where(mask, actions.float(), other.float()).sum() / num
        return b_actions

    def get_current_soln(self):
        """Return the current segmentation.

        NOTE(review): this is a stub — the real multicut call is commented out
        and a random image is returned instead; confirm before training.
        """
        # affs = np.expand_dims(self.affinities, axis=1)
        # boundary_input = np.mean(affs, axis=0)
        # mc_seg = general.multicut_from_probas(self.init_sp_seg.squeeze().cpu(), self.edge_ids.cpu().t().contiguous().numpy(),
        #                                       self.current_edge_weights.squeeze().cpu().numpy(), boundary_input)
        # return torch.from_numpy(mc_seg.astype(np.float32))
        return torch.rand_like(self.init_sp_seg)

    def get_current_soln_pic(self, b):
        """Build a 2x2 comparison image for batch item *b* (see grid below)."""
        b_actions = self.get_batched_actions_from_global_graph(
            self.sg_current_edge_weights.view(-1))
        b_gt = self.get_batched_actions_from_global_graph(
            self.sg_gt_edge_weights.view(-1))
        # undo the batch node offset so edge ids are local to this item
        edge_ids = self.b_edge_ids[:, self.e_offs[b]:self.e_offs[b + 1]] - self.n_offs[b]
        edge_ids = edge_ids.cpu().t().contiguous().numpy()
        boundary_input = self.b_initial_edge_weights[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy()
        # multicut from the initial (network) edge weights
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            self.b_initial_edge_weights[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        # multicut from the agent's current edge weights
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            b_actions[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        # multicut from the ground-truth edge weights
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg[b].squeeze().cpu(), edge_ids,
            b_gt[self.e_offs[b]:self.e_offs[b + 1]].cpu().numpy(),
            boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg[b].squeeze().cpu() / self.init_sp_seg[b].cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        return np.concatenate((np.concatenate(
            (mc_seg1, mc_seg), 0), np.concatenate((gt_mc_seg, seg), 0)), 1)
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def reset(self):
        """Reset per-episode bookkeeping."""
        self.done = False
        self.win = False
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0
class SpGcnEnv(Environment):
    """RL environment over a superpixel graph (discrete-action variant).

    State is ``self.state = [edge_weights, current_segmentation]``. In the
    discrete action space an action per edge is 0 (keep), 1 (lower) or
    2 (raise) the merge probability by ``args.action_agression``; otherwise
    actions are taken directly as the new edge weights.
    """

    def __init__(self, args, device, writer=None, writer_counter=None,
                 win_event_counter=None, discrete_action_space=True):
        """
        :param args: config namespace; fields used here: ``reward_function``,
            ``action_agression``, ``max_episode_length``.
        :param device: torch device rewards are moved to.
        :param writer: optional summary writer for per-episode scalars.
        :param writer_counter: shared counter used as the logging step.
        :param win_event_counter: shared counter of won episodes.
        :param discrete_action_space: see class docstring.
        """
        super(SpGcnEnv, self).__init__()
        self.stop_quality = 0  # quality threshold below which an episode counts as won
        self.reset()
        self.args = args
        self.device = device
        self.writer = writer
        self.writer_counter = writer_counter
        self.win_event_counter = win_event_counter
        self.discrete_action_space = discrete_action_space

        # choose the reward implementation from the config string
        if self.args.reward_function == 'fully_supervised':
            self.reward_function = FullySupervisedReward(env=self)
        elif self.args.reward_function == 'object_level':
            self.reward_function = ObjectLevelReward(env=self)
        elif self.args.reward_function == 'graph_dice':
            self.reward_function = GraphDiceReward(env=self)
        elif self.args.reward_function == 'focal':
            self.reward_function = FocalReward(env=self)
        else:
            self.reward_function = UnSupervisedReward(env=self)

    def execute_action(self, actions):
        """Apply *actions* to the edge weights and advance one step.

        :return: ([edge_weights, segmentation], reward tensor, quality float)
            where quality is the summed absolute difference to the
            ground-truth edge weights.
        """
        last_diff = (self.state[0] - self.gt_edge_weights).squeeze().abs()
        if self.discrete_action_space:
            # 0 -> keep, 1 -> decrease, 2 -> increase the edge weight
            mask = (actions == 2).float() * (self.state[0] + self.args.action_agression)
            mask += (actions == 1).float() * (self.state[0] - self.args.action_agression)
            mask += (actions == 0).float() * self.state[0]
            self.state[0] = mask + 1e-10  # prevent the reinforcement loss from becoming too large
            self.state[0] = self.state[0].clamp(min=0, max=1)
        else:
            self.state[0] = actions.clone()

        reward = self.reward_function.get(last_diff, actions, self.get_current_soln()).to(self.device)
        # self.get_rag_and_edge_feats(reward, self.state[0])
        # total deviation from the initial (network) edge weights
        self.data_changed = torch.sum(torch.abs(self.state[0] - self.edge_features[:, 0])).cpu().item()
        penalize_change = 0
        quality = (self.state[0] - self.gt_edge_weights).squeeze().abs().sum().item()

        if self.counter > self.args.max_episode_length:
            # penalize_change = (self.penalize_diff_thresh - self.data_changed) / np.prod(self.state.size()) * 10
            if quality < self.stop_quality:
                reward += 2
                self.win = True
                # BUGFIX: only count a win when the episode actually reached
                # stop quality; previously every terminated episode incremented
                # the counter, rendering "step/n_wins" and the win/loose ratio
                # meaningless.
                self.win_event_counter.increment()
            else:
                reward -= 1
            self.done = True
            # BUGFIX: the iteration counter was incremented twice here
            # (duplicated statement); once per terminated episode is intended.
            self.iteration += 1

        reward += (penalize_change * (actions != 0).float())
        total_reward = torch.sum(reward).item()
        self.counter += 1

        # episode-end logging only
        if self.writer is not None and self.done:
            step = self.writer_counter.value()
            self.writer.add_scalar("step/quality", quality, step)
            self.writer.add_scalar("step/stop_quality", self.stop_quality, step)
            self.writer.add_scalar("step/n_wins", self.win_event_counter.value(), step)
            self.writer.add_scalar("step/steps_needed", self.counter, step)
            self.writer.add_scalar("step/win_loose_ratio",
                                   (self.win_event_counter.value() + 1) / (self.writer_counter.value() + 1),
                                   step)
            self.writer.add_scalar("step/pred_mean", self.state[0].mean(), step)
            self.writer.add_scalar("step/pred_std", self.state[0].std(), step)
            self.writer.add_scalar("step/gt_mean", self.gt_edge_weights.mean(), step)
            self.writer.add_scalar("step/gt_std", self.gt_edge_weights.std(), step)
            self.writer_counter.increment()

        self.acc_reward = total_reward
        self.state[1] = self.get_current_soln()
        return [self.state[0].clone(), self.state[1].clone()], reward, quality

    def update_data(self, edge_ids, edge_features, diff_to_gt, gt_edge_weights,
                    node_labeling, raw, nodes, angles, affinities, gt):
        """Load a new sample into the environment.

        Edge weights start at 0.5 (maximum uncertainty); the initial
        segmentation is computed from them.
        """
        self.gt_seg = gt
        self.affinities = affinities
        self.initial_edge_weights = edge_features[:, 0]
        self.edge_features = edge_features
        # one boolean mask per superpixel, and its pixel coordinates
        self.stacked_superpixels = [node_labeling == n for n in nodes]
        self.sp_indices = [sp.nonzero() for sp in self.stacked_superpixels]
        self.raw = raw
        self.penalize_diff_thresh = diff_to_gt * 4
        self.init_sp_seg = node_labeling.squeeze()
        self.edge_ids = edge_ids
        self.gt_edge_weights = gt_edge_weights
        self.edge_angles = angles
        # state[0] must exist before get_current_soln() can run (it reads
        # state[0]); then fill in the initial segmentation. Previously the
        # whole list was built twice.
        self.state = [torch.ones_like(gt_edge_weights) / 2, None]
        self.state[1] = self.get_current_soln()

    def show_current_soln(self):
        """Debug: plot initial-mc / current-mc / gt-mc / superpixel panels."""
        affs = np.expand_dims(self.affinities, axis=1)
        boundary_input = np.mean(affs, axis=0)
        mc_seg1 = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.initial_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.state[0].squeeze().cpu().numpy(), boundary_input)
        gt_mc_seg = general.multicut_from_probas(
            self.init_sp_seg.cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.gt_edge_weights.squeeze().cpu().numpy(), boundary_input)
        mc_seg = cm.prism(mc_seg / mc_seg.max())
        mc_seg1 = cm.prism(mc_seg1 / mc_seg1.max())
        seg = cm.prism(self.init_sp_seg.cpu() / self.init_sp_seg.cpu().max())
        gt_mc_seg = cm.prism(gt_mc_seg / gt_mc_seg.max())
        plt.imshow(np.concatenate((np.concatenate((mc_seg1, mc_seg), 0),
                                   np.concatenate((gt_mc_seg, seg), 0)), 1))
        plt.show()
        ####################
        # init mc # gt seg #
        ####################
        # curr mc # sp seg #
        ####################

    def get_current_soln(self):
        """Run multicut on the current edge weights and return the resulting
        label image as a float64 tensor."""
        affs = np.expand_dims(self.affinities, axis=1)
        boundary_input = np.mean(affs, axis=0)
        mc_seg = general.multicut_from_probas(
            self.init_sp_seg.squeeze().cpu(),
            self.edge_ids.cpu().t().contiguous().numpy(),
            self.state[0].squeeze().cpu().numpy(), boundary_input)
        # BUGFIX: np.float (alias of builtin float) was deprecated in NumPy
        # 1.20 and removed in 1.24 -> AttributeError at runtime. The builtin
        # float yields the identical float64 dtype.
        return torch.from_numpy(mc_seg.astype(float))

    def get_rag_and_edge_feats(self, reward, edges):
        """Debug: highlight pixels adjacent to each graph edge and plot them."""
        edge_indices = []
        seg = self.init_sp_seg.clone()
        for edge in self.edge_ids.t():
            n1, n2 = self.sp_indices[edge[0]], self.sp_indices[edge[1]]
            # pixel pairs from the two superpixels that touch (distance <= 1)
            dis = torch.cdist(n1.float(), n2.float())
            dis = (dis <= 1).nonzero()
            inds_n1 = n1[dis[:, 0].unique()]
            inds_n2 = n2[dis[:, 1].unique()]
            edge_indices.append(torch.cat((inds_n1, inds_n2), 0))
        for indices in edge_indices:
            seg[indices[:, 0], indices[:, 1]] = 600  # sentinel label for boundary pixels
        seg = cm.prism(seg.cpu().numpy() / seg.cpu().numpy().max())
        plt.imshow(seg)
        plt.show()

    def reset(self):
        """Reset per-episode bookkeeping."""
        self.done = False
        self.win = False
        self.iteration = 0
        self.acc_reward = 0
        self.last_reward = -inf
        self.counter = 0