def execute_action(self, actions, logg_vals=None, post_stats=False, post_images=False, tau=None,
                   train=True):
    """Apply the edge actions, compute the reward and return it with the next state.

    The squeezed actions are stored in ``self.current_edge_weights`` and handed to
    ``self.get_current_soln``, whose first result is stored in
    ``self.current_soln``. The reward is then computed with
    ``self.reward_function`` and, if requested, logged to wandb.

    Args:
        actions: edge actions. They are sliced per sample with ``self.e_offs``
            (presumably one row per edge -- confirm against the caller).
        logg_vals: optional dict of additional values to log under the
            train/validation tag. Only used when ``post_stats`` is true.
        post_stats: log the accumulated mean reward to wandb.
        post_images: additionally log a histogram of the edge weights and an
            overview figure. Only used when ``post_stats`` is true, because the
            logging tag is only defined in that case.
        tau: unused, kept for interface compatibility.
        train: selects the ``"train/"`` or ``"validation/"`` logging tag.

    Returns:
        A tuple ``(reward, state)``. ``reward`` is a two-element list holding the
        per-object reward of this call and the value ``self.last_final_reward``
        had before this call. ``state`` is a ``State`` built from the current
        node embeddings, edge ids, edge/superpixel features, the object masks
        returned by ``get_current_soln`` and the ground-truth edge weights.

    Raises:
        AssertionError: if ``self.cfg.reward_function`` names neither an
            ``artificial_cells`` nor a ``leptin_data`` reward.
    """
    self.current_edge_weights = actions.squeeze()

    (self.current_soln, obj_edge_ind_critic, obj_node_mask_critic,
     obj_edge_mask_actor) = self.get_current_soln(self.current_edge_weights)

    reward_name = self.cfg.reward_function
    if 'artificial_cells' not in reward_name and 'leptin_data' not in reward_name:
        # Raised explicitly: a bare ``assert False`` is stripped under ``python -O``
        # and would leave ``reward`` unbound further down.
        raise AssertionError(f"unsupported reward function: {reward_name!r}")

    # one slice of actions per sample, delimited by consecutive edge offsets
    split_actions = [
        actions[self.e_offs[i - 1]:self.e_offs[i]].squeeze(-1)
        for i in range(1, len(self.e_offs))
    ]
    sp_reward = self.reward_function(self.current_soln.long(), self.init_sp_seg.long(),
                                     dir_edges=self.dir_edge_ids, edge_score=False, res=50,
                                     sp_cmrads=self.sp_rads, actions=split_actions)
    # Mask-weighted average of the superpixel rewards for every row of the mask
    # (rows are presumably objects/sub-graphs -- TODO confirm).
    object_weights = obj_node_mask_critic.sum(1)
    reward = [(sp_reward[None] * obj_node_mask_critic).sum(1) / object_weights]
    # The previous final reward is appended before it is refreshed below.
    reward.append(self.last_final_reward)
    self.counter += 1
    self.last_final_reward = reward[0].mean()

    total_reward = sum(part.mean().item() for part in reward)

    if post_stats:
        tag = "train/" if train else "validation/"
        wandb.log({tag + "avg_return": total_reward})
        if post_images:
            wandb.log({tag + "pred_mean": wandb.Histogram(
                self.current_edge_weights.view(-1).cpu().numpy())})
            fig, axes = plt.subplots(2, 3, sharex='col', sharey='row',
                                     gridspec_kw={'hspace': 0, 'wspace': 0})
            axes[0, 0].imshow(self.gt_seg[-1].cpu().squeeze(), cmap=random_label_cmap(),
                              interpolation="none")
            axes[0, 0].set_title('gt')
            # Channel 0 of the last sample is shown as the raw image, channel 1 as
            # the edge map; this indexing requires ``self.raw`` to be 4-D.
            axes[0, 1].imshow(self.raw[-1, 0].cpu().squeeze())
            axes[0, 1].set_title('raw image')
            axes[0, 2].imshow(self.raw[-1, 1].cpu().squeeze())
            axes[0, 2].set_title('edge sp')
            axes[1, 0].imshow(self.init_sp_seg[-1].cpu(), cmap=random_label_cmap(),
                              interpolation="none")
            axes[1, 0].set_title('superpixels', y=-0.15)
            axes[1, 1].imshow(pca_project(self.embeddings[-1].detach().cpu()))
            axes[1, 1].set_title('embed', y=-0.15)
            axes[1, 2].imshow(self.current_soln[-1].cpu(), cmap=random_label_cmap(),
                              interpolation="none")
            axes[1, 2].set_title('prediction', y=-0.15)
            wandb.log({tag: [wandb.Image(fig, caption="state")]})
            plt.close('all')
        if logg_vals is not None:
            for key, val in logg_vals.items():
                wandb.log({tag + key: val})

    self.acc_reward.append(total_reward)
    return reward, State(self.current_node_embeddings, self.edge_ids, self.edge_features,
                         self.sp_feat, obj_edge_ind_critic, obj_node_mask_critic,
                         obj_edge_mask_actor, self.gt_edge_weights)
def validate(self):
    """Evaluate the current model on the validation set and log the results to wandb.

    For every validation sample the environment is updated and reset, the model
    is run without gradients while ``self.model_mtx`` is held, and the sigmoid of
    the predicted location is executed as the action. The accumulated reward,
    the mean of ``self.segm_metric`` over all samples and the values dumped by
    ``self.clst_metric`` are logged. The accumulated reward also drives both
    learning-rate schedulers and decides whether a new best checkpoint is
    written to the wandb run directory. For the samples listed in
    ``self.cfg.store_indices`` an overview figure is logged as well.
    """
    env = MulticutEmbeddingsEnv(self.cfg, self.device)
    if self.cfg.verbose:
        print("\n\n###### start validate ######", end='')
    self.model.eval()
    n_examples = len(self.val_dset)
    is_dice_reward = self.cfg.reward_function == "sub_graph_dice"

    self.clst_metric.reset()
    map_scores = []
    # One dict per stored sample. The dataset index is kept with the data so the
    # figures stay correctly labelled even if ``store_indices`` is unsorted,
    # contains duplicates or names indices beyond the dataset.
    stored = []
    dloader = iter(
        DataLoader(self.val_dset, batch_size=1, shuffle=False, pin_memory=True, num_workers=0))
    acc_reward = 0
    for it in range(n_examples):
        update_env_data(env, dloader, self.val_dset, self.device,
                        with_gt_edges="sub_graph_dice" in self.cfg.reward_function)
        env.reset()
        state = env.get_state()

        self.model_mtx.acquire()
        try:
            distr, _, _, _, _, node_features, embeddings = self.forwarder.forward(
                self.model, state, State, self.device, grad=False, post_data=False,
                get_node_feats=True, get_embeddings=True)
        finally:
            self.model_mtx.release()
        action = torch.sigmoid(distr.loc)
        reward = env.execute_action(action, tau=0.0, train=False)
        rew = reward[-1].item() if is_dice_reward else reward[-2].item()
        acc_reward += rew
        rl_labels = env.current_soln.cpu().numpy()[0]
        gt_seg = env.gt_seg[0].cpu().numpy()
        if self.cfg.verbose:
            print(
                f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(rew, 5)}",
                end='')

        if it in self.cfg.store_indices:
            # Look up the features of the first sample's nodes for every pixel
            # of its superpixel segmentation and move channels to the front.
            node_features = node_features[:env.n_offs[1]][
                env.init_sp_seg[0].long()].permute(2, 0, 1).cpu()
            stored.append({
                "index": it,
                "feats": pca_project(node_features, n_comps=3),
                "emb": pca_project(embeddings[0].cpu(), n_comps=3),
                "n_emb": pca_project(node_features[:self.cfg.dim_embeddings], n_comps=3),
                "raw": env.raw[0].cpu().permute(1, 2, 0).squeeze(),
                "sp": env.init_sp_seg[0].cpu(),
                "gt": gt_seg,
                "rl": rl_labels,
                "edge_ids": env.edge_ids,
                "reward": reward[-1],
                "action": action,
            })

        map_scores.append(self.segm_metric(rl_labels, gt_seg))
        self.clst_metric(rl_labels, gt_seg)

    splits, merges, are, arp, arr = self.clst_metric.dump()
    # NOTE(review): acc_reward is logged without an explicit step while the other
    # metrics use step=self.global_counter -- confirm this is intended.
    wandb.log({"validation/acc_reward": acc_reward})
    wandb.log({"validation/mAP": np.mean(map_scores)}, step=self.global_counter)
    wandb.log({"validation/UnderSegmVI": splits}, step=self.global_counter)
    wandb.log({"validation/OverSegmVI": merges}, step=self.global_counter)
    wandb.log({"validation/ARE": are}, step=self.global_counter)
    wandb.log({"validation/ARP": arp}, step=self.global_counter)
    wandb.log({"validation/ARR": arr}, step=self.global_counter)

    # learning-rate scheduling is driven by the accumulated validation reward
    self.optimizers.critic_shed.step(acc_reward)
    self.optimizers.actor_shed.step(acc_reward)

    if acc_reward > self.best_val_reward:
        self.best_val_reward = acc_reward
        wandb.run.summary["validation/acc_reward"] = acc_reward
        torch.save(self.model.state_dict(),
                   os.path.join(wandb.run.dir, "best_checkpoint_agent.pth"))
    if self.cfg.verbose:
        print("\n###### finish validate ######\n", end='')

    # colour map whose "bad" (NaN) entries are fully transparent, so boundary
    # pixels set to NaN below let the edge frame underneath show through
    label_cm = random_label_cmap(zeroth=1.0)
    label_cm.set_bad(alpha=0)

    for sample in stored:
        raw = sample["raw"]
        fig, axs = plt.subplots(2, 4 if is_dice_reward else 5,
                                sharex='col', figsize=(9, 5), sharey='row',
                                gridspec_kw={'hspace': 0, 'wspace': 0})
        axs[0, 0].imshow(sample["gt"], cmap=random_label_cmap(), interpolation="none")
        axs[0, 0].set_title('gt', y=1.05, size=10)
        axs[0, 0].axis('off')

        if raw.ndim == 3:
            if raw.shape[-1] > 2:
                axs[0, 1].imshow(raw[..., :3], cmap="gray")
            else:
                axs[0, 1].imshow(raw[..., 0], cmap="gray")
        else:
            # Was drawn into axs[1, 1], which the node-embedding panel then
            # overwrote, leaving the 'raw image' panel empty for 2-D input.
            axs[0, 1].imshow(raw, cmap="gray")
        axs[0, 1].set_title('raw image', y=1.05, size=10)
        axs[0, 1].axis('off')

        if raw.ndim == 3:
            if raw.shape[-1] > 1:
                axs[0, 2].imshow(raw[..., -1], cmap="gray")
            else:
                axs[0, 2].imshow(raw[..., 0], cmap="gray")
        else:
            axs[0, 2].imshow(raw, cmap="gray")
        axs[0, 2].set_title('plantseg', y=1.05, size=10)
        axs[0, 2].axis('off')

        axs[0, 3].imshow(sample["sp"], cmap=random_label_cmap(), interpolation="none")
        axs[0, 3].set_title('superpixels', y=1.05, size=10)
        axs[0, 3].axis('off')

        axs[1, 0].imshow(sample["feats"])
        axs[1, 0].set_title('features', y=-0.15, size=10)
        axs[1, 0].axis('off')
        axs[1, 1].imshow(sample["n_emb"])
        axs[1, 1].set_title('node embeddings', y=-0.15, size=10)
        axs[1, 1].axis('off')
        axs[1, 2].imshow(sample["emb"])
        axs[1, 2].set_title('embeddings', y=-0.15, size=10)
        axs[1, 2].axis('off')
        axs[1, 3].imshow(sample["rl"], cmap=random_label_cmap(), interpolation="none")
        axs[1, 3].set_title('prediction', y=-0.15, size=10)
        axs[1, 3].axis('off')

        if not is_dice_reward:
            sp_batch = sample["sp"][None].float()
            sample_edges = sample["edge_ids"].cpu()
            frame_rew, _, bnd_mask = get_colored_edges_in_sseg(
                sp_batch, sample_edges, sample["reward"].cpu())
            frame_act, _, _ = get_colored_edges_in_sseg(
                sp_batch, sample_edges, 1 - sample["action"].cpu().squeeze())

            # dilate the boundary mask and every colour channel of both frames
            bnd_mask = torch.from_numpy(dilation(bnd_mask.cpu().numpy()))
            frame_rew = np.stack(
                [dilation(frame_rew.cpu().numpy()[..., c]) for c in range(3)], -1)
            frame_act = np.stack(
                [dilation(frame_act.cpu().numpy()[..., c]) for c in range(3)], -1)

            # ``np.float`` was removed from NumPy; the builtin ``float`` is the
            # type it used to alias.
            overlay = sample["rl"].squeeze().astype(float)
            overlay[bnd_mask] = np.nan

            axs[1, 4].imshow(frame_rew, interpolation="none")
            axs[1, 4].imshow(overlay, cmap=label_cm, alpha=0.8, interpolation="none")
            axs[1, 4].set_title("rewards", y=-0.2)
            axs[1, 4].axis('off')

            axs[0, 4].imshow(frame_act, interpolation="none")
            axs[0, 4].imshow(overlay, cmap=label_cm, alpha=0.8, interpolation="none")
            axs[0, 4].set_title("actions", y=1.05)
            axs[0, 4].axis('off')

        wandb.log(
            {"validation/sample_" + str(sample["index"]): [wandb.Image(fig, caption="sample images")]},
            step=self.global_counter)
        plt.close('all')
def validate(self):
    """Evaluate the current model on the validation set and log the results to wandb.

    For every validation sample the environment is updated and reset, the model
    is run without gradients while ``self.model_mtx`` is held, and the sigmoid of
    the predicted location is executed as the action. The result of
    ``matching`` (IoU criterion) is accumulated for every threshold in ``taus``:
    the counts ``fp``/``tp``/``fn`` are summed over all samples, every other
    reported value is averaged over the samples. The metric curves, the
    accumulated reward and one overview figure per sample are logged, and a
    checkpoint is written to the wandb run directory whenever the accumulated
    reward exceeds ``self.best_val_reward``.
    """
    env = MulticutEmbeddingsEnv(self.fe_ext, self.cfg, self.device)
    if self.cfg.verbose:
        print("\n\n###### start validate ######", end='')
    self.model.eval()
    n_examples = len(self.val_dset)
    taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    rl_scores, keys = [], None
    ex_raws, ex_sps, ex_gts, ex_embeds, ex_rl = [], [], [], [], []
    dloader = iter(
        DataLoader(self.val_dset, batch_size=1, shuffle=True, pin_memory=True, num_workers=0))
    acc_reward = 0
    for it in range(n_examples):
        update_env_data(env, dloader, self.val_dset, self.device,
                        with_gt_edges="sub_graph_dice" in self.cfg.reward_function)
        env.reset()
        state = env.get_state()

        self.model_mtx.acquire()
        try:
            distr, _ = self.forwarder.forward(self.model, state, State, self.device,
                                              grad=False, post_data=False)
        finally:
            self.model_mtx.release()
        action = torch.sigmoid(distr.loc)
        reward, _ = env.execute_action(action, None, post_images=True, tau=0.0, train=False)
        acc_reward += reward[-1].item()
        if self.cfg.verbose:
            print(
                f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(reward[-1].item(), 5)}",
                end='')

        embeddings = env.embeddings[0].cpu().numpy()
        gt_seg = env.gt_seg[0].cpu().numpy()
        rl_labels = env.current_soln.cpu().numpy()[0]

        ex_embeds.append(pca_project(embeddings, n_comps=3))
        ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze())
        ex_sps.append(env.init_sp_seg[0].cpu())
        ex_gts.append(gt_seg)
        ex_rl.append(rl_labels)

        _rl_scores = matching(gt_seg, rl_labels, thresh=taus, criterion='iou',
                              report_matches=False)
        # One float vector per threshold; the first field of each result is
        # skipped, exactly as for ``keys`` below.
        sample_scores = [
            np.array([float(v) for v in list(score._asdict().values())[1:]])
            for score in _rl_scores
        ]
        if keys is None:
            keys = list(_rl_scores[0]._asdict().keys())[1:]
            rl_scores = sample_scores
        else:
            rl_scores = [total + cur for total, cur in zip(rl_scores, sample_scores)]

    # Counts stay sums; everything else becomes a mean over the validation
    # samples. The divisor used to be a hard-coded 10, which is only correct for
    # a validation set of exactly ten samples.
    div = np.ones_like(rl_scores[0])
    for i, key in enumerate(keys):
        if key not in ('fp', 'tp', 'fn'):
            div[i] = n_examples
    rl_scores = [dict(zip(keys, scores / div)) for scores in rl_scores]

    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    plt.subplots_adjust(hspace=.5)
    panels = (
        (axs[0], ('precision', 'recall', 'accuracy', 'f1'), 'Metric value', (.8, 1.65)),
        (axs[1], ('fp', 'tp', 'fn'), 'Number #', (.87, 1.6)),
    )
    for ax, metric_names, ylabel, legend_anchor in panels:
        for m in metric_names:
            values = [s[m] for s in rl_scores]
            data = [[tau, val] for tau, val in zip(taus, values)]
            table = wandb.Table(data=data, columns=["IoU_threshold", m])
            wandb.log({
                "validation/" + m:
                wandb.plot.line(table, "IoU_threshold", m, stroke=None, title=m)
            })
            ax.plot(taus, values, '.-', lw=2, label=m)
        ax.set_ylabel(ylabel)
        ax.grid()
        ax.legend(bbox_to_anchor=legend_anchor, loc='upper left', fontsize='xx-small')
        ax.set_title('RL method')
        ax.set_xlabel(r'IoU threshold $\tau$')

    wandb.log({"validation/metrics": [wandb.Image(fig, caption="metrics")]})
    wandb.log({"validation_reward": acc_reward})
    plt.close('all')

    if acc_reward > self.best_val_reward:
        self.best_val_reward = acc_reward
        wandb.run.summary["validation_reward"] = acc_reward
        torch.save(self.model.state_dict(),
                   os.path.join(wandb.run.dir, "best_checkpoint_agent.pth"))
    if self.cfg.verbose:
        print("\n###### finish validate ######\n", end='')

    for gt, raw, sp, embed, pred in zip(ex_gts, ex_raws, ex_sps, ex_embeds, ex_rl):
        fig, axs = plt.subplots(2, 3, sharex='col', sharey='row',
                                gridspec_kw={'hspace': 0, 'wspace': 0})
        axs[0, 0].imshow(gt, cmap=random_label_cmap(), interpolation="none")
        axs[0, 0].set_title('gt')
        axs[0, 0].axis('off')

        if raw.ndim == 3:
            axs[0, 1].imshow(raw[..., 0])
        else:
            axs[0, 1].imshow(raw)
        axs[0, 1].set_title('raw image')
        axs[0, 1].axis('off')

        axs[0, 2].imshow(sp, cmap=random_label_cmap(), interpolation="none")
        axs[0, 2].set_title('superpixels')
        axs[0, 2].axis('off')

        axs[1, 0].imshow(embed)
        axs[1, 0].set_title('pc proj 1-3', y=-0.15)
        axs[1, 0].axis('off')

        # second channel if there is one, otherwise fall back to the raw image
        if raw.ndim == 3:
            if raw.shape[-1] > 1:
                axs[1, 1].imshow(raw[..., 1])
            else:
                axs[1, 1].imshow(raw[..., 0])
        else:
            axs[1, 1].imshow(raw)
        axs[1, 1].set_title('sp edge', y=-0.15)
        axs[1, 1].axis('off')

        axs[1, 2].imshow(pred, cmap=random_label_cmap(), interpolation="none")
        axs[1, 2].set_title('prediction', y=-0.15)
        axs[1, 2].axis('off')

        wandb.log({"validation/samples": [wandb.Image(fig, caption="sample images")]})
        plt.close('all')
def _print_mean_std(title, sp_gt, ours):
    """Print one metric section: header, then ``mean; std`` for both methods."""
    print(f"\n{title}: ")
    print(f"sp gt : {round(sp_gt[0], 4)}; {round(sp_gt[1], 4)}")
    print(f"ours : {round(ours[0], 4)}; {round(ours[1], 4)}")


def validate_and_compare_to_clustering(model, env, distance, device, cfg):
    """Evaluate the agent on the validation data and print segmentation metrics.

    A ``SpgDset`` is built from ``cfg``; for every sample the environment is
    updated and reset, the model is run without gradients and the sigmoid of
    the predicted location is executed as the action. Two segmentations are
    scored against the ground truth with ``SBD`` and ``ClusterMetrics``:

    * "ours": the environment's solution after the action, and
    * "sp gt": the result of ``project_overseg_to_seg`` applied to the initial
      superpixels and the ground truth (presumably the best segmentation the
      superpixels allow -- confirm against that helper).

    Mean and standard deviation of SBD, VI merge, VI split, ARE, ARP and ARR are
    printed for both.

    This function used to call ``exit()`` after printing, terminating the whole
    interpreter; it now returns normally. The plotting loop that followed was
    unreachable and indexed result lists of baselines that are no longer
    computed, so it has been removed.

    Args:
        model: the agent model; switched to eval mode.
        env: environment used to run the agent.
        distance: unused, kept for interface compatibility.
        device: device passed to the data update and the forward pass.
        cfg: configuration providing ``val_data_dir``, ``patch_manager``,
            ``val_data_keys`` and ``s_subgraph``.
    """
    model.eval()
    dset = SpgDset(cfg.val_data_dir, dict_to_attrdict(cfg.patch_manager),
                   dict_to_attrdict(cfg.val_data_keys), max(cfg.s_subgraph))
    dloader = iter(DataLoader(dset))
    forwarder = Forwarder()

    clst_metric_rl = ClusterMetrics()
    metric_sp_gt = ClusterMetrics()
    sbd = SBD()
    sbd_rl, sbd_sp_gt = [], []

    for _ in range(len(dset)):
        update_env_data(env, dloader, dset, device, with_gt_edges=False)
        env.reset()
        state = env.get_state()
        distr, *_ = forwarder.forward(model, state, State, device, grad=False, post_data=False,
                                      get_node_feats=True, get_embeddings=True)
        action = torch.sigmoid(distr.loc)
        # executed for its effect on ``env.current_soln``; the reward is not used
        env.execute_action(action, tau=0.0, train=False)

        gt_seg = env.gt_seg[0].cpu().numpy()
        rl_labels = env.current_soln.cpu().numpy()[0]
        sp_gt_labels = project_overseg_to_seg(
            env.init_sp_seg[0], torch.from_numpy(gt_seg).to(device)).cpu().numpy()

        sbd_rl.append(sbd(gt_seg, rl_labels))
        clst_metric_rl(rl_labels, gt_seg)

        sbd_sp_gt.append(sbd(gt_seg, sp_gt_labels))
        metric_sp_gt(sp_gt_labels, gt_seg)

    sbd_sp_gt = np.array(sbd_sp_gt)
    sbd_rl = np.array(sbd_rl)
    _print_mean_std("SBD", (sbd_sp_gt.mean(), sbd_sp_gt.std()), (sbd_rl.mean(), sbd_rl.std()))

    # both dumps unpack as (VI split, VI merge, ARE, ARP, ARR)
    rl_mean = tuple(clst_metric_rl.dump())
    sp_gt_mean = tuple(metric_sp_gt.dump())
    rl_std = tuple(clst_metric_rl.dump_std())
    sp_gt_std = tuple(metric_sp_gt.dump_std())

    sections = (("VI merge", 1), ("VI split", 0), ("ARE", 2), ("ARP", 3), ("ARR", 4))
    for title, idx in sections:
        _print_mean_std(title, (sp_gt_mean[idx], sp_gt_std[idx]), (rl_mean[idx], rl_std[idx]))