コード例 #1
0
    def execute_action(self, actions, logg_vals=None, post_stats=False, post_images=False, tau=None, train=True):
        """Execute ``actions`` as new edge weights and return the resulting rewards and state.

        Args:
            actions: action tensor; ``actions.squeeze()`` becomes
                ``self.current_edge_weights`` and ``actions`` is also sliced per
                graph with ``self.e_offs`` before being passed to the reward function.
            logg_vals: optional mapping of extra values; each item is logged to
                wandb under the train/validation tag when ``post_stats`` is set.
            post_stats: if True, log the summed mean return (and ``logg_vals``).
            post_images: if True *and* ``post_stats``, additionally log a
                histogram of the edge weights and a 2x3 overview figure.
            tau: not used by this method.
            train: chooses the wandb tag prefix ("train/" vs "validation/").

        Returns:
            Tuple ``(reward, state)``. ``reward`` is a two-element list: the
            superpixel reward averaged under each row of the critic's object
            node mask, followed by the value ``self.last_final_reward`` held
            before this call. ``state`` is a ``State`` built from the current
            graph data and the object masks returned by ``get_current_soln``.

        Raises:
            NotImplementedError: if ``self.cfg.reward_function`` contains neither
                'artificial_cells' nor 'leptin_data'.
        """
        self.current_edge_weights = actions.squeeze()

        self.current_soln, obj_edge_ind_critic, obj_node_mask_critic, obj_edge_mask_actor = self.get_current_soln(self.current_edge_weights)

        if 'artificial_cells' not in self.cfg.reward_function and 'leptin_data' not in self.cfg.reward_function:
            # An explicit raise (instead of ``assert False``) still fires under
            # ``python -O``, where the assert would be stripped and the code
            # below would fail on an unbound ``reward``.
            raise NotImplementedError(f"unsupported reward function: {self.cfg.reward_function}")

        # per-graph slices of the actions, delimited by consecutive edge offsets
        split_actions = [actions[self.e_offs[i-1]:self.e_offs[i]].squeeze(-1) for i in range(1, len(self.e_offs))]
        sp_reward = self.reward_function(self.current_soln.long(), self.init_sp_seg.long(),
                                           dir_edges=self.dir_edge_ids, edge_score=False, res=50,
                                           sp_cmrads=self.sp_rads, actions=split_actions)
        # mask-weighted mean of the superpixel rewards, one value per mask row
        object_weights = obj_node_mask_critic.sum(1)
        reward = [(sp_reward[None] * obj_node_mask_critic).sum(1) / object_weights]
        reward.append(self.last_final_reward)
        self.counter += 1
        self.last_final_reward = reward[0].mean()

        total_reward = sum(_rew.mean().item() for _rew in reward)

        if post_stats:
            tag = "train/" if train else "validation/"
            wandb.log({tag + "avg_return": total_reward})
            if post_images:
                wandb.log({tag + "pred_mean": wandb.Histogram(self.current_edge_weights.view(-1).cpu().numpy())})
                # last image of the batch, moved to host memory once for plotting
                raw = self.raw[-1].cpu()
                fig, axes = plt.subplots(2, 3, sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0})
                axes[0, 0].imshow(self.gt_seg[-1].cpu().squeeze(), cmap=random_label_cmap(), interpolation="none")
                axes[0, 0].set_title('gt')
                if raw.ndim == 3:
                    # channel-first image: channel 0 is shown as 'raw image',
                    # channel 1 (when present) as 'edge sp'
                    axes[0, 1].imshow(raw[0].squeeze())
                    axes[0, 2].imshow(raw[1 if raw.shape[0] > 1 else 0].squeeze())
                else:
                    # single-channel image without a channel axis
                    axes[0, 1].imshow(raw)
                    axes[0, 2].imshow(raw)
                axes[0, 1].set_title('raw image')
                axes[0, 2].set_title('edge sp')
                axes[1, 0].imshow(self.init_sp_seg[-1].cpu(), cmap=random_label_cmap(), interpolation="none")
                axes[1, 0].set_title('superpixels', y=-0.15)
                axes[1, 1].imshow(pca_project(self.embeddings[-1].detach().cpu()))
                axes[1, 1].set_title('embed', y=-0.15)
                axes[1, 2].imshow(self.current_soln[-1].cpu(), cmap=random_label_cmap(), interpolation="none")
                axes[1, 2].set_title('prediction', y=-0.15)
                wandb.log({tag: [wandb.Image(fig, caption="state")]})
                plt.close('all')
            if logg_vals is not None:
                for key, val in logg_vals.items():
                    wandb.log({tag + key: val})

        self.acc_reward.append(total_reward)
        return reward, State(self.current_node_embeddings, self.edge_ids, self.edge_features, self.sp_feat,
                             obj_edge_ind_critic, obj_node_mask_critic, obj_edge_mask_actor, self.gt_edge_weights)
コード例 #2
0
ファイル: sac.py プロジェクト: paulhfu/RLForSeg
    def validate(self):
        """Validate the greedy policy on the validation set and log results to wandb.

        For every validation example the env is reset, the model is run without
        gradients (under ``self.model_mtx``) and the action
        ``sigmoid(distr.loc)`` is executed once. The accumulated reward, mAP and
        the clustering metrics returned by ``self.clst_metric.dump()`` are logged
        at step ``self.global_counter``. The accumulated reward also drives both
        LR schedulers and best-checkpoint saving. A figure is logged for each
        example whose index appears in ``self.cfg.store_indices``.
        """
        env = MulticutEmbeddingsEnv(self.cfg, self.device)
        if self.cfg.verbose:
            print("\n\n###### start validate ######", end='')
        self.model.eval()
        n_examples = len(self.val_dset)

        self.clst_metric.reset()
        map_scores = []
        # ``stored_ids`` records the dataset index of every stored example, so each
        # figure is labelled with the index it actually shows even when
        # ``store_indices`` is unsorted, repeats entries or exceeds the dataset.
        stored_ids, ex_raws, ex_sps, ex_gts, ex_feats, ex_emb, ex_n_emb, ex_rl, edge_ids, rewards, actions = [
            [] for _ in range(11)
        ]
        dloader = iter(
            DataLoader(self.val_dset,
                       batch_size=1,
                       shuffle=False,
                       pin_memory=True,
                       num_workers=0))
        acc_reward = 0

        for it in range(n_examples):
            update_env_data(env,
                            dloader,
                            self.val_dset,
                            self.device,
                            with_gt_edges="sub_graph_dice"
                            in self.cfg.reward_function)
            env.reset()
            state = env.get_state()

            self.model_mtx.acquire()
            try:
                distr, _, _, _, _, node_features, embeddings = self.forwarder.forward(
                    self.model,
                    state,
                    State,
                    self.device,
                    grad=False,
                    post_data=False,
                    get_node_feats=True,
                    get_embeddings=True)
            finally:
                self.model_mtx.release()
            action = torch.sigmoid(distr.loc)
            reward = env.execute_action(action, tau=0.0, train=False)
            # which entry of the returned rewards is accumulated depends on the reward function
            rew = reward[-1].item() if self.cfg.reward_function == "sub_graph_dice" else reward[-2].item()
            acc_reward += rew
            rl_labels = env.current_soln.cpu().numpy()[0]
            gt_seg = env.gt_seg[0].cpu().numpy()
            if self.cfg.verbose:
                print(
                    f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(rew, 5)}",
                    end='')

            if it in self.cfg.store_indices:
                # project node features back onto the superpixel image, channels first
                node_features = node_features[:env.n_offs[1]][
                    env.init_sp_seg[0].long()].permute(2, 0, 1).cpu()

                stored_ids.append(it)
                ex_feats.append(pca_project(node_features, n_comps=3))
                ex_emb.append(pca_project(embeddings[0].cpu(), n_comps=3))
                ex_n_emb.append(
                    pca_project(node_features[:self.cfg.dim_embeddings],
                                n_comps=3))
                ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze())
                ex_sps.append(env.init_sp_seg[0].cpu())
                ex_gts.append(gt_seg)
                ex_rl.append(rl_labels)
                edge_ids.append(env.edge_ids)
                rewards.append(reward[-1])
                actions.append(action)

            map_scores.append(self.segm_metric(rl_labels, gt_seg))
            self.clst_metric(rl_labels, gt_seg)

        splits, merges, are, arp, arr = self.clst_metric.dump()
        # All scalars share the explicit step: a step-less wandb.log advances
        # wandb's own step counter, after which calls with a smaller explicit
        # step are ignored by wandb.
        wandb.log({
            "validation/acc_reward": acc_reward,
            "validation/mAP": np.mean(map_scores),
            "validation/UnderSegmVI": splits,
            "validation/OverSegmVI": merges,
            "validation/ARE": are,
            "validation/ARP": arp,
            "validation/ARR": arr,
        }, step=self.global_counter)

        # do the lr sheduling
        self.optimizers.critic_shed.step(acc_reward)
        self.optimizers.actor_shed.step(acc_reward)

        if acc_reward > self.best_val_reward:
            self.best_val_reward = acc_reward
            wandb.run.summary["validation/acc_reward"] = acc_reward
            torch.save(
                self.model.state_dict(),
                os.path.join(wandb.run.dir, "best_checkpoint_agent.pth"))
        if self.cfg.verbose:
            print("\n###### finish validate ######\n", end='')

        label_cm = random_label_cmap(zeroth=1.0)
        label_cm.set_bad(alpha=0)
        # the non-dice reward functions get an extra column for reward/action overlays
        n_cols = 4 if self.cfg.reward_function == "sub_graph_dice" else 5

        for it, i in enumerate(stored_ids):
            fig, axs = plt.subplots(
                2,
                n_cols,
                sharex='col',
                figsize=(9, 5),
                sharey='row',
                gridspec_kw={
                    'hspace': 0,
                    'wspace': 0
                })
            axs[0, 0].imshow(ex_gts[it],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[0, 0].set_title('gt', y=1.05, size=10)
            axs[0, 0].axis('off')

            raw = ex_raws[it]
            if raw.ndim == 3:
                if raw.shape[-1] > 2:
                    axs[0, 1].imshow(raw[..., :3], cmap="gray")
                else:
                    axs[0, 1].imshow(raw[..., 0], cmap="gray")
            else:
                # 2-D raw belongs in the 'raw image' panel (was drawn into axs[1, 1])
                axs[0, 1].imshow(raw, cmap="gray")
            axs[0, 1].set_title('raw image', y=1.05, size=10)
            axs[0, 1].axis('off')

            if raw.ndim == 3:
                if raw.shape[-1] > 1:
                    axs[0, 2].imshow(raw[..., -1], cmap="gray")
                else:
                    axs[0, 2].imshow(raw[..., 0], cmap="gray")
            else:
                axs[0, 2].imshow(raw, cmap="gray")
            axs[0, 2].set_title('plantseg', y=1.05, size=10)
            axs[0, 2].axis('off')
            axs[0, 3].imshow(ex_sps[it],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[0, 3].set_title('superpixels', y=1.05, size=10)
            axs[0, 3].axis('off')

            axs[1, 0].imshow(ex_feats[it])
            axs[1, 0].set_title('features', y=-0.15, size=10)
            axs[1, 0].axis('off')
            axs[1, 1].imshow(ex_n_emb[it])
            axs[1, 1].set_title('node embeddings', y=-0.15, size=10)
            axs[1, 1].axis('off')
            axs[1, 2].imshow(ex_emb[it])
            axs[1, 2].set_title('embeddings', y=-0.15, size=10)
            axs[1, 2].axis('off')
            axs[1, 3].imshow(ex_rl[it],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[1, 3].set_title('prediction', y=-0.15, size=10)
            axs[1, 3].axis('off')

            if self.cfg.reward_function != "sub_graph_dice":
                frame_rew, scores_rew, bnd_mask = get_colored_edges_in_sseg(
                    ex_sps[it][None].float(), edge_ids[it].cpu(),
                    rewards[it].cpu())
                frame_act, scores_act, _ = get_colored_edges_in_sseg(
                    ex_sps[it][None].float(), edge_ids[it].cpu(),
                    1 - actions[it].cpu().squeeze())

                bnd_mask = torch.from_numpy(dilation(bnd_mask.cpu().numpy()))
                frame_rew = np.stack([
                    dilation(frame_rew.cpu().numpy()[..., c]) for c in range(3)
                ], -1)
                frame_act = np.stack([
                    dilation(frame_act.cpu().numpy()[..., c]) for c in range(3)
                ], -1)
                # np.float was removed in NumPy 1.24; float64 is what it aliased
                ex_rl[it] = ex_rl[it].squeeze().astype(np.float64)
                # NaN on the boundaries makes them transparent via label_cm.set_bad
                ex_rl[it][bnd_mask] = np.nan

                axs[1, 4].imshow(frame_rew, interpolation="none")
                axs[1, 4].imshow(ex_rl[it],
                                 cmap=label_cm,
                                 alpha=0.8,
                                 interpolation="none")
                axs[1, 4].set_title("rewards", y=-0.2)
                axs[1, 4].axis('off')

                axs[0, 4].imshow(frame_act, interpolation="none")
                axs[0, 4].imshow(ex_rl[it],
                                 cmap=label_cm,
                                 alpha=0.8,
                                 interpolation="none")
                axs[0, 4].set_title("actions", y=1.05)
                axs[0, 4].axis('off')

            wandb.log(
                {
                    "validation/sample_" + str(i):
                    [wandb.Image(fig, caption="sample images")]
                },
                step=self.global_counter)
            plt.close('all')
コード例 #3
0
    def validate(self):
        """Validate the greedy policy on the validation set and log matching metrics.

        Every validation example is run once with the action
        ``sigmoid(distr.loc)``. The prediction is scored with ``matching`` at IoU
        thresholds 0.1-0.9. The fp/tp/fn counts are summed over the examples and
        all other numeric fields are averaged over them. The per-threshold curves,
        a summary figure and the accumulated reward are logged to wandb. The
        model is checkpointed when the reward improves, and one sample figure is
        logged per validation example.

        Returns early without logging if the validation set is empty.
        """
        env = MulticutEmbeddingsEnv(self.fe_ext, self.cfg, self.device)
        if self.cfg.verbose:
            print("\n\n###### start validate ######", end='')
        self.model.eval()
        n_examples = len(self.val_dset)
        if n_examples == 0:
            # nothing to score; the averaging below would index an empty list
            return
        taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        rl_scores, keys = None, None
        ex_raws, ex_sps, ex_gts, ex_embeds, ex_rl = [], [], [], [], []
        dloader = iter(
            DataLoader(self.val_dset,
                       batch_size=1,
                       shuffle=True,
                       pin_memory=True,
                       num_workers=0))
        acc_reward = 0
        for it in range(n_examples):
            update_env_data(env,
                            dloader,
                            self.val_dset,
                            self.device,
                            with_gt_edges="sub_graph_dice"
                            in self.cfg.reward_function)
            env.reset()
            state = env.get_state()

            self.model_mtx.acquire()
            try:
                distr, _ = self.forwarder.forward(self.model,
                                                  state,
                                                  State,
                                                  self.device,
                                                  grad=False,
                                                  post_data=False)
            finally:
                self.model_mtx.release()
            action = torch.sigmoid(distr.loc)
            reward, state = env.execute_action(action,
                                               None,
                                               post_images=True,
                                               tau=0.0,
                                               train=False)
            acc_reward += reward[-1].item()
            if self.cfg.verbose:
                print(
                    f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(reward[-1].item(), 5)}",
                    end='')

            embeddings = env.embeddings[0].cpu().numpy()
            gt_seg = env.gt_seg[0].cpu().numpy()
            rl_labels = env.current_soln.cpu().numpy()[0]

            ex_embeds.append(pca_project(embeddings, n_comps=3))
            ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze())
            ex_sps.append(env.init_sp_seg[0].cpu())
            ex_gts.append(gt_seg)
            ex_rl.append(rl_labels)

            _rl_scores = matching(gt_seg,
                                  rl_labels,
                                  thresh=taus,
                                  criterion='iou',
                                  report_matches=False)
            # one result per threshold; the first field is skipped and the
            # remaining values are converted to float and summed over examples
            per_tau = [
                np.array([float(v) for v in list(res._asdict().values())[1:]])
                for res in _rl_scores
            ]
            if rl_scores is None:
                rl_scores = per_tau
                keys = list(_rl_scores[0]._asdict().keys())[1:]
            else:
                rl_scores = [acc + cur for acc, cur in zip(rl_scores, per_tau)]

        # Count fields stay totals; every other field becomes a mean over the
        # examples actually evaluated (previously a hard-coded 10, which was only
        # correct for a validation set of exactly 10 images).
        div = np.array([1.0 if key in ('fp', 'tp', 'fn') else float(n_examples) for key in keys])
        rl_scores = [dict(zip(keys, scores / div)) for scores in rl_scores]

        fig, axs = plt.subplots(1, 2, figsize=(10, 10))
        plt.subplots_adjust(hspace=.5)

        # (axis, metric names, y label, legend anchor) for the two summary panels
        panels = ((axs[0], ('precision', 'recall', 'accuracy', 'f1'), 'Metric value', (.8, 1.65)),
                  (axs[1], ('fp', 'tp', 'fn'), 'Number #', (.87, 1.6)))
        for ax, metric_names, ylabel, legend_anchor in panels:
            for m in metric_names:
                y = [s[m] for s in rl_scores]
                table = wandb.Table(data=[[x, v] for x, v in zip(taus, y)],
                                    columns=["IoU_threshold", m])
                wandb.log({
                    "validation/" + m:
                    wandb.plot.line(table,
                                    "IoU_threshold",
                                    m,
                                    stroke=None,
                                    title=m)
                })
                ax.plot(taus, y, '.-', lw=2, label=m)
            ax.set_ylabel(ylabel)
            ax.grid()
            ax.legend(bbox_to_anchor=legend_anchor,
                      loc='upper left',
                      fontsize='xx-small')
            ax.set_title('RL method')
            ax.set_xlabel(r'IoU threshold $\tau$')

        wandb.log(
            {"validation/metrics": [wandb.Image(fig, caption="metrics")]})
        wandb.log({"validation_reward": acc_reward})
        plt.close('all')
        if acc_reward > self.best_val_reward:
            self.best_val_reward = acc_reward
            wandb.run.summary["validation_reward"] = acc_reward
            torch.save(
                self.model.state_dict(),
                os.path.join(wandb.run.dir, "best_checkpoint_agent.pth"))
        if self.cfg.verbose:
            print("\n###### finish validate ######\n", end='')

        for i in range(len(ex_gts)):
            fig, axs = plt.subplots(2,
                                    3,
                                    sharex='col',
                                    sharey='row',
                                    gridspec_kw={
                                        'hspace': 0,
                                        'wspace': 0
                                    })
            axs[0, 0].imshow(ex_gts[i],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[0, 0].set_title('gt')
            axs[0, 0].axis('off')
            if ex_raws[i].ndim == 3:
                axs[0, 1].imshow(ex_raws[i][..., 0])
            else:
                axs[0, 1].imshow(ex_raws[i])
            axs[0, 1].set_title('raw image')
            axs[0, 1].axis('off')
            axs[0, 2].imshow(ex_sps[i],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[0, 2].set_title('superpixels')
            axs[0, 2].axis('off')
            axs[1, 0].imshow(ex_embeds[i])
            axs[1, 0].set_title('pc proj 1-3', y=-0.15)
            axs[1, 0].axis('off')
            if ex_raws[i].ndim == 3:
                if ex_raws[i].shape[-1] > 1:
                    axs[1, 1].imshow(ex_raws[i][..., 1])
                else:
                    axs[1, 1].imshow(ex_raws[i][..., 0])
            else:
                axs[1, 1].imshow(ex_raws[i])
            axs[1, 1].set_title('sp edge', y=-0.15)
            axs[1, 1].axis('off')
            axs[1, 2].imshow(ex_rl[i],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[1, 2].set_title('prediction', y=-0.15)
            axs[1, 2].axis('off')
            wandb.log({
                "validation/samples":
                [wandb.Image(fig, caption="sample images")]
            })
            plt.close('all')
コード例 #4
0
def validate_and_compare_to_clustering(model, env, distance, device, cfg):
    """Evaluate ``model`` on the validation data and print segmentation metrics.

    The validation ``SpgDset`` is built from ``cfg.val_data_dir``,
    ``cfg.patch_manager``, ``cfg.val_data_keys`` and ``max(cfg.s_subgraph)``.
    For each example the environment is reset, the model is run without
    gradients, and ``sigmoid(distr.loc)`` is executed as the action. Two
    segmentations are scored against the ground truth with ``SBD`` and
    ``ClusterMetrics``:

    - "ours": the environment's current solution;
    - "sp gt": ``project_overseg_to_seg`` of the superpixels onto the gt.

    Mean and standard deviation of every metric are printed.

    Args:
        model: policy model, switched to eval mode.
        env: environment updated via ``update_env_data`` for every example.
        distance: not used; kept for interface compatibility.
        device: device the data is moved to.
        cfg: config providing the dataset fields listed above.

    Returns:
        dict with ``"acc_reward"`` (sum of ``reward[-2]`` over all examples)
        and, for ``"ours"`` and ``"sp_gt"``, the mean and std of SBD, VI split,
        VI merge, ARE, ARP and ARR. The function used to terminate the
        interpreter with ``exit()`` here; it now returns so callers can continue.
    """
    model.eval()
    dset = SpgDset(cfg.val_data_dir, dict_to_attrdict(cfg.patch_manager), dict_to_attrdict(cfg.val_data_keys),
                   max(cfg.s_subgraph))
    dloader = iter(DataLoader(dset))
    forwarder = Forwarder()

    clst_metric_rl = ClusterMetrics()
    metric_sp_gt = ClusterMetrics()
    sbd = SBD()
    sbd_rl, sbd_sp_gt = [], []
    acc_reward = 0

    for _ in range(len(dset)):
        update_env_data(env, dloader, dset, device, with_gt_edges=False)
        env.reset()
        state = env.get_state()
        # only the action distribution is needed; the remaining outputs are discarded
        distr, *_ = forwarder.forward(model, state, State, device,
                                      grad=False, post_data=False,
                                      get_node_feats=True,
                                      get_embeddings=True)
        action = torch.sigmoid(distr.loc)
        reward = env.execute_action(action, tau=0.0, train=False)
        acc_reward += reward[-2].item()

        gt_seg = env.gt_seg[0].cpu().numpy()
        rl_labels = env.current_soln.cpu().numpy()[0]
        sp_gt_labels = project_overseg_to_seg(env.init_sp_seg[0],
                                              torch.from_numpy(gt_seg).to(device)).cpu().numpy()

        sbd_rl.append(sbd(gt_seg, rl_labels))
        clst_metric_rl(rl_labels, gt_seg)

        sbd_sp_gt.append(sbd(gt_seg, sp_gt_labels))
        metric_sp_gt(sp_gt_labels, gt_seg)

    sbd_rl = np.array(sbd_rl)
    sbd_sp_gt = np.array(sbd_sp_gt)
    print("\nSBD: ")
    print(f"sp gt       : {round(sbd_sp_gt.mean(), 4)}; {round(sbd_sp_gt.std(), 4)}")
    print(f"ours        : {round(sbd_rl.mean(), 4)}; {round(sbd_rl.std(), 4)}")

    # dump() / dump_std() order: (VI split, VI merge, ARE, ARP, ARR)
    rl_mean = clst_metric_rl.dump()
    spgt_mean = metric_sp_gt.dump()
    rl_std = clst_metric_rl.dump_std()
    spgt_std = metric_sp_gt.dump_std()

    # (printed title, position in the dump tuples), in the original print order
    for title, idx in (("VI merge", 1), ("VI split", 0), ("ARE", 2), ("ARP", 3), ("ARR", 4)):
        print(f"\n{title}: ")
        print(f"sp gt       : {round(spgt_mean[idx], 4)}; {round(spgt_std[idx], 4)}")
        print(f"ours        : {round(rl_mean[idx], 4)}; {round(rl_std[idx], 4)}")

    metric_names = ("VI_split", "VI_merge", "ARE", "ARP", "ARR")

    def _summary(sbd_values, means, stds):
        # pack SBD plus the five cluster metrics into a {name: (mean, std)} mapping
        summary = {"SBD": (sbd_values.mean(), sbd_values.std())}
        summary.update({name: (m, s) for name, m, s in zip(metric_names, means, stds)})
        return summary

    return {
        "acc_reward": acc_reward,
        "ours": _summary(sbd_rl, rl_mean, rl_std),
        "sp_gt": _summary(sbd_sp_gt, spgt_mean, spgt_std),
    }