Example #1
    def __init__(self, cfg, StateClass, distance, device, with_temp=True):
        super(Agent, self).__init__()
        self.cfg = cfg
        self.std_bounds = self.cfg.std_bounds
        self.mu_bounds = self.cfg.mu_bounds
        self.device = device
        self.StateClass = StateClass
        self.distance = distance
        self.offs = [[1, 0], [0, 1], [2, 0], [0, 2], [4, 0], [0, 4], [16, 0], [0, 16]]

        dim_embed = self.cfg.dim_embeddings + (3 * int(cfg.use_handcrafted_features))

        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance, cfg.fe_delta_dist, self.device)
        if "fe_model_name" in self.cfg:
            self.fe_ext.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)
        if self.cfg.fe_optimization:
            for param in self.fe_ext.parameters():
                param.requires_grad = False

        self.fe_ext_tgt = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance, cfg.fe_delta_dist, self.device)
        if "fe_model_name" in self.cfg:
            self.fe_ext_tgt.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
        self.fe_ext_tgt.cuda(self.device)
        for param in self.fe_ext_tgt.parameters():
            param.requires_grad = False

        self.actor = PolicyNet(dim_embed, 2, cfg.gnn_n_hl, cfg.gnn_size_hl, distance, device, False, cfg.gnn_act_depth,
                               cfg.gnn_act_norm_inp, cfg.n_init_edge_feat)
        self.critic = QValueNet(self.cfg.s_subgraph, dim_embed, 1, 1, cfg.gnn_n_hl, cfg.gnn_size_hl, distance, device,
                                False, cfg.gnn_crit_depth, cfg.gnn_crit_norm_inp, cfg.n_init_edge_feat)
        self.critic_tgt = QValueNet(self.cfg.s_subgraph, dim_embed, 1, 1, cfg.gnn_n_hl, cfg.gnn_size_hl, distance,
                                    device, False, cfg.gnn_crit_depth, cfg.gnn_crit_norm_inp, cfg.n_init_edge_feat)

        self.log_alpha = torch.tensor([np.log(self.cfg.init_temperature)] * len(self.cfg.s_subgraph)).to(device)
        if with_temp:
            self.log_alpha.requires_grad = True
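
The entropy temperature is stored as log_alpha (one value per subgraph size in cfg.s_subgraph), so the parameter that is optimized is unconstrained while the effective temperature alpha = exp(log_alpha) stays positive. A minimal, self-contained sketch of that parameterization; the initial value and the number of subgraph sizes below are placeholders, not taken from the config above:

import math
import torch

init_temperature = 0.1          # placeholder starting temperature
n_subgraph_sizes = 3            # placeholder for len(cfg.s_subgraph)

# optimize the unconstrained log of the temperature ...
log_alpha = torch.full((n_subgraph_sizes,), math.log(init_temperature), requires_grad=True)

# ... so the temperature itself is always positive
alpha = log_alpha.exp()
loss = (alpha * torch.randn(n_subgraph_sizes)).mean()
loss.backward()                 # gradients accumulate on log_alpha
print(log_alpha.grad)
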
Example #2
class Agent(torch.nn.Module):
    def __init__(self, cfg, StateClass, distance, device, with_temp=True):
        super(Agent, self).__init__()
        self.cfg = cfg
        self.std_bounds = self.cfg.std_bounds
        self.mu_bounds = self.cfg.mu_bounds
        self.device = device
        self.StateClass = StateClass
        self.distance = distance
        self.offs = [[1, 0], [0, 1], [2, 0], [0, 2], [4, 0], [0, 4], [16, 0], [0, 16]]

        dim_embed = self.cfg.dim_embeddings + (3 * int(cfg.use_handcrafted_features))

        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance, cfg.fe_delta_dist, self.device)
        if "fe_model_name" in self.cfg:
            self.fe_ext.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)
        if self.cfg.fe_optimization:
            for param in self.fe_ext.parameters():
                param.requires_grad = False

        self.fe_ext_tgt = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance, cfg.fe_delta_dist, self.device)
        if "fe_model_name" in self.cfg:
            self.fe_ext_tgt.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
        self.fe_ext_tgt.cuda(self.device)
        for param in self.fe_ext_tgt.parameters():
            param.requires_grad = False

        self.actor = PolicyNet(dim_embed, 2, cfg.gnn_n_hl, cfg.gnn_size_hl, distance, device, False, cfg.gnn_act_depth,
                               cfg.gnn_act_norm_inp, cfg.n_init_edge_feat)
        self.critic = QValueNet(self.cfg.s_subgraph, dim_embed, 1, 1, cfg.gnn_n_hl, cfg.gnn_size_hl, distance, device,
                                False, cfg.gnn_crit_depth, cfg.gnn_crit_norm_inp, cfg.n_init_edge_feat)
        self.critic_tgt = QValueNet(self.cfg.s_subgraph, dim_embed, 1, 1, cfg.gnn_n_hl, cfg.gnn_size_hl, distance,
                                    device, False, cfg.gnn_crit_depth, cfg.gnn_crit_norm_inp, cfg.n_init_edge_feat)

        self.log_alpha = torch.tensor([np.log(self.cfg.init_temperature)] * len(self.cfg.s_subgraph)).to(device)
        if with_temp:
            self.log_alpha.requires_grad = True

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @alpha.setter
    def alpha(self, value):
        self.log_alpha = torch.tensor(np.log(value)).to(self.device)
        self.log_alpha.requires_grad = True

    def get_features(self, model, state, grad):
        with torch.set_grad_enabled(grad):
            embeddings = model(state.raw)
        embed_affs = get_affinities_from_embeddings_2d(embeddings.detach(), self.offs, model.delta_dist, model.distance)
        embed_dists = [1 - get_edge_features_1d(sp.cpu().numpy(), self.offs, embed_aff.cpu().numpy())[0][:, 0] for sp, embed_aff in zip(state.sp_seg, embed_affs)]
        embed_dists = [torch.from_numpy(embed_dist).to(model.device) for embed_dist in embed_dists]
        embed_dists = torch.cat(embed_dists, 0)[:, None]
        node_features = model.get_mean_sp_embedding_sparse(embeddings[:, :, None], state.sp_seg[:, None]).T

        node_features = torch.cat((node_features, state.sp_feat), 1) if self.cfg.use_handcrafted_features else node_features
        edge_features = torch.cat((embed_dists, state.edge_feats), 1).float()

        return node_features, edge_features, embeddings

    def forward(self, state, actions, expl_action, post_data, policy_opt, return_node_features, get_embeddings):
        state = self.StateClass(*state)

        edge_index = torch.cat([state.edge_ids, torch.stack([state.edge_ids[1], state.edge_ids[0]], dim=0)], dim=1)  # gcnn expects two directed edges for one undirected edge

        if actions is None:
            node_features_tgt, edge_features_tgt, embeddings_tgt = self.get_features(self.fe_ext_tgt, state, grad=False)
            with torch.set_grad_enabled(policy_opt):
                out, side_loss = self.actor(node_features_tgt, edge_index, edge_features_tgt, state.gt_edge_weights, post_data)
                mu, std = out.chunk(2, dim=-1)
                mu, std = mu.contiguous(), std.contiguous()

                if post_data:
                    wandb.log({"logits/loc": wandb.Histogram(mu.view(-1).detach().cpu().numpy())})
                    wandb.log({"logits/scale": wandb.Histogram(std.view(-1).detach().cpu().numpy())})

                std = self.std_bounds[0] + 0.5 * (self.std_bounds[1] - self.std_bounds[0]) * (torch.tanh(std) + 1)
                mu = self.mu_bounds[0] + 0.5 * (self.mu_bounds[1] - self.mu_bounds[0]) * (torch.tanh(mu) + 1)

                dist = SigmNorm(mu, std)
                if expl_action is None:
                    actions = dist.rsample()
                else:
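                    # reuse a given exploration action but rebuild it as mu + z * std with z
                    # detached, so gradients can still flow through mu and std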
                    z = ((expl_action - mu) / std).detach()
                    actions = mu + z * std

            q, sl = self.critic_tgt(node_features_tgt, actions, edge_index, edge_features_tgt, state.subgraph_indices,
                                    state.sep_subgraphs, state.gt_edge_weights, post_data)
            side_loss = (side_loss + sl) / 2
            if policy_opt:
                return dist, q, actions, side_loss
            else:
                # this means either exploration or critic opt
                if return_node_features:
                    if get_embeddings:
                        return dist, q, actions, None, side_loss, node_features_tgt.detach(), embeddings_tgt.detach()
                    return dist, q, actions, None, side_loss, node_features_tgt.detach()
                if get_embeddings:
                    return dist, q, actions, None, side_loss, embeddings_tgt.detach()
                return dist, q, actions, None, side_loss

        node_features, edge_features, _ = self.get_features(self.fe_ext, state, grad=not policy_opt)
        q, side_loss = self.critic(node_features, actions, edge_index, edge_features, state.subgraph_indices,
                                   state.sep_subgraphs, state.gt_edge_weights, post_data)
        return q, side_loss
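
Before the action distribution is built, the raw policy outputs for mu and std are squashed into the configured ranges with a tanh mapping of the form lo + 0.5 * (hi - lo) * (tanh(x) + 1). A small self-contained sketch of that mapping; the bounds below are illustrative and not the values of std_bounds or mu_bounds:

import torch

def squash_to_bounds(x, lo, hi):
    # maps any real-valued tensor into the open interval (lo, hi)
    return lo + 0.5 * (hi - lo) * (torch.tanh(x) + 1)

raw_out = torch.linspace(-5.0, 5.0, 7)
print(squash_to_bounds(raw_out, 0.1, 1.0))   # all values lie strictly between 0.1 and 1.0
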
Example #3
    def __init__(self, cfg, global_count):
        super(AgentSacTrainerObjLvlReward, self).__init__()
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        if self.cfg.distance == 'cosine':
            self.distance = CosineDistance()
        else:
            self.distance = L2Distance()

        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone),
                                  self.distance, cfg.fe_delta_dist,
                                  self.device)
        self.fe_ext.embed_model.load_state_dict(
            torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)

        self.model = Agent(self.cfg, State, self.distance, self.device)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        MovSumLosses = namedtuple('mov_avg_losses',
                                  ('actor', 'critic', 'temperature'))
        Scalers = namedtuple('Scalers', ('critic', 'actor'))
        OptimizerContainer = namedtuple(
            'OptimizerContainer', ('actor', 'critic', 'temperature',
                                   'actor_shed', 'critic_shed', 'temp_shed'))
        actor_optimizer = torch.optim.Adam(self.model.actor.parameters(),
                                           lr=self.cfg.actor_lr)
        critic_optimizer = torch.optim.Adam(self.model.critic.parameters(),
                                            lr=self.cfg.critic_lr)
        temp_optimizer = torch.optim.Adam([self.model.log_alpha],
                                          lr=self.cfg.alpha_lr)

        lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        bw = lr_sched_cfg.mov_avg_bandwidth
        off = lr_sched_cfg.mov_avg_offset
        weights = np.linspace(lr_sched_cfg.weight_range[0],
                              lr_sched_cfg.weight_range[1], bw)
        weights = weights / weights.sum()  # make them sum up to one
        shed = lr_sched_cfg.torch_sched

        self.mov_sum_losses = MovSumLosses(
            RunningAverage(weights, band_width=bw, offset=off),
            RunningAverage(weights, band_width=bw, offset=off),
            RunningAverage(weights, band_width=bw, offset=off))
        self.optimizers = OptimizerContainer(
            actor_optimizer, critic_optimizer, temp_optimizer, *[
                ReduceLROnPlateau(opt,
                                  patience=shed.patience,
                                  threshold=shed.threshold,
                                  min_lr=shed.min_lr,
                                  factor=shed.factor)
                for opt in (actor_optimizer, critic_optimizer, temp_optimizer)
            ])
        self.scalers = Scalers(torch.cuda.amp.GradScaler(),
                               torch.cuda.amp.GradScaler())
        self.forwarder = Forwarder()

        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
        # if "policy_warmup" in self.cfg and self.cfg.agent_model_name == "":
        #     supervised_policy_pretraining(self.model, self.env, self.cfg, device=self.device)
        #     torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "sv_pretrained_policy_agent.pth"))

        # finished with prepping
        for param in self.fe_ext.parameters():
            param.requires_grad = False

        self.train_dset = SpgDset(self.cfg.data_dir,
                                  dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys))
        self.val_dset = SpgDset(self.cfg.val_data_dir,
                                dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys))
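
Each optimizer is paired with a ReduceLROnPlateau scheduler that is intended to be stepped with a smoothed loss (the RunningAverage objects) rather than the raw per-batch value. A minimal sketch of that pattern, using a plain exponential moving average in place of the project's RunningAverage:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3, min_lr=1e-6, factor=0.5)

ema = None
for step in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # smooth the loss before handing it to the plateau scheduler
    ema = loss.item() if ema is None else 0.9 * ema + 0.1 * loss.item()
    scheduler.step(ema)
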
Example #4
class AgentSacTrainerObjLvlReward(object):
    def __init__(self, cfg, global_count):
        super(AgentSacTrainerObjLvlReward, self).__init__()
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        if self.cfg.distance == 'cosine':
            self.distance = CosineDistance()
        else:
            self.distance = L2Distance()

        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone),
                                  self.distance, cfg.fe_delta_dist,
                                  self.device)
        self.fe_ext.embed_model.load_state_dict(
            torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)

        self.model = Agent(self.cfg, State, self.distance, self.device)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        MovSumLosses = namedtuple('mov_avg_losses',
                                  ('actor', 'critic', 'temperature'))
        Scalers = namedtuple('Scalers', ('critic', 'actor'))
        OptimizerContainer = namedtuple(
            'OptimizerContainer', ('actor', 'critic', 'temperature',
                                   'actor_shed', 'critic_shed', 'temp_shed'))
        actor_optimizer = torch.optim.Adam(self.model.actor.parameters(),
                                           lr=self.cfg.actor_lr)
        critic_optimizer = torch.optim.Adam(self.model.critic.parameters(),
                                            lr=self.cfg.critic_lr)
        temp_optimizer = torch.optim.Adam([self.model.log_alpha],
                                          lr=self.cfg.alpha_lr)

        lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        bw = lr_sched_cfg.mov_avg_bandwidth
        off = lr_sched_cfg.mov_avg_offset
        weights = np.linspace(lr_sched_cfg.weight_range[0],
                              lr_sched_cfg.weight_range[1], bw)
        weights = weights / weights.sum()  # make them sum up to one
        shed = lr_sched_cfg.torch_sched

        self.mov_sum_losses = MovSumLosses(
            RunningAverage(weights, band_width=bw, offset=off),
            RunningAverage(weights, band_width=bw, offset=off),
            RunningAverage(weights, band_width=bw, offset=off))
        self.optimizers = OptimizerContainer(
            actor_optimizer, critic_optimizer, temp_optimizer, *[
                ReduceLROnPlateau(opt,
                                  patience=shed.patience,
                                  threshold=shed.threshold,
                                  min_lr=shed.min_lr,
                                  factor=shed.factor)
                for opt in (actor_optimizer, critic_optimizer, temp_optimizer)
            ])
        self.scalers = Scalers(torch.cuda.amp.GradScaler(),
                               torch.cuda.amp.GradScaler())
        self.forwarder = Forwarder()

        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
        # if "policy_warmup" in self.cfg and self.cfg.agent_model_name == "":
        #     supervised_policy_pretraining(self.model, self.env, self.cfg, device=self.device)
        #     torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "sv_pretrained_policy_agent.pth"))

        # finished with prepping
        for param in self.fe_ext.parameters():
            param.requires_grad = False

        self.train_dset = SpgDset(self.cfg.data_dir,
                                  dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys))
        self.val_dset = SpgDset(self.cfg.val_data_dir,
                                dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys))

    def validate(self):
        """validates the prediction against the method of clustering the embedding space"""
        env = MulticutEmbeddingsEnv(self.fe_ext, self.cfg, self.device)
        if self.cfg.verbose:
            print("\n\n###### start validate ######", end='')
        self.model.eval()
        n_examples = len(self.val_dset)
        taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        rl_scores, keys = [], None
        ex_raws, ex_sps, ex_gts, ex_mc_gts, ex_embeds, ex_rl = [], [], [], [], [], []
        dloader = iter(
            DataLoader(self.val_dset,
                       batch_size=1,
                       shuffle=True,
                       pin_memory=True,
                       num_workers=0))
        acc_reward = 0
        for it in range(len(self.val_dset)):
            update_env_data(env,
                            dloader,
                            self.val_dset,
                            self.device,
                            with_gt_edges="sub_graph_dice"
                            in self.cfg.reward_function)
            env.reset()
            state = env.get_state()

            self.model_mtx.acquire()
            try:
                distr, _ = self.forwarder.forward(self.model,
                                                  state,
                                                  State,
                                                  self.device,
                                                  grad=False,
                                                  post_data=False)
            finally:
                self.model_mtx.release()
            action = torch.sigmoid(distr.loc)
            reward, state = env.execute_action(action,
                                               None,
                                               post_images=True,
                                               tau=0.0,
                                               train=False)
            acc_reward += reward[-1].item()
            if self.cfg.verbose:
                print(
                    f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(reward[-1].item(), 5)}",
                    end='')

            embeddings = env.embeddings[0].cpu().numpy()
            gt_seg = env.gt_seg[0].cpu().numpy()
            gt_mc = cm.prism(
                env.gt_soln[0].cpu() / env.gt_soln[0].max().item()
            ) if env.gt_edge_weights is not None else torch.zeros(
                env.raw.shape[-2:])
            rl_labels = env.current_soln.cpu().numpy()[0]

            if it < n_examples:
                ex_embeds.append(pca_project(embeddings, n_comps=3))
                ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze())
                ex_sps.append(env.init_sp_seg[0].cpu())
                ex_mc_gts.append(gt_mc)
                ex_gts.append(gt_seg)
                ex_rl.append(rl_labels)

            _rl_scores = matching(gt_seg,
                                  rl_labels,
                                  thresh=taus,
                                  criterion='iou',
                                  report_matches=False)

            if it == 0:
                for tau_it in range(len(_rl_scores)):
                    rl_scores.append(
                        np.array(
                            list(
                                map(
                                    float,
                                    list(_rl_scores[tau_it]._asdict().values())
                                    [1:]))))
                keys = list(_rl_scores[0]._asdict().keys())[1:]
            else:
                for tau_it in range(len(_rl_scores)):
                    rl_scores[tau_it] += np.array(
                        list(
                            map(
                                float,
                                list(_rl_scores[tau_it]._asdict().values())
                                [1:])))

        div = np.ones_like(rl_scores[0])
        for i, key in enumerate(keys):
            if key not in ('fp', 'tp', 'fn'):
                div[i] = 10

        for tau_it in range(len(rl_scores)):
            rl_scores[tau_it] = dict(zip(keys, rl_scores[tau_it] / div))

        fig, axs = plt.subplots(1, 2, figsize=(10, 10))
        plt.subplots_adjust(hspace=.5)

        for m in ('precision', 'recall', 'accuracy', 'f1'):
            y = [s[m] for s in rl_scores]
            data = [[x, y] for (x, y) in zip(taus, y)]
            table = wandb.Table(data=data, columns=["IoU_threshold", m])
            wandb.log({
                "validation/" + m:
                wandb.plot.line(table,
                                "IoU_threshold",
                                m,
                                stroke=None,
                                title=m)
            })
            axs[0].plot(taus, [s[m] for s in rl_scores], '.-', lw=2, label=m)
        axs[0].set_ylabel('Metric value')
        axs[0].grid()
        axs[0].legend(bbox_to_anchor=(.8, 1.65),
                      loc='upper left',
                      fontsize='xx-small')
        axs[0].set_title('RL method')
        axs[0].set_xlabel(r'IoU threshold $\tau$')

        for m in ('fp', 'tp', 'fn'):
            y = [s[m] for s in rl_scores]
            data = [[x, y] for (x, y) in zip(taus, y)]
            table = wandb.Table(data=data, columns=["IoU_threshold", m])
            wandb.log({
                "validation/" + m:
                wandb.plot.line(table,
                                "IoU_threshold",
                                m,
                                stroke=None,
                                title=m)
            })
            axs[1].plot(taus, [s[m] for s in rl_scores], '.-', lw=2, label=m)
        axs[1].set_ylabel('Number #')
        axs[1].grid()
        axs[1].legend(bbox_to_anchor=(.87, 1.6),
                      loc='upper left',
                      fontsize='xx-small')
        axs[1].set_title('RL method')
        axs[1].set_xlabel(r'IoU threshold $\tau$')

        wandb.log(
            {"validation/metrics": [wandb.Image(fig, caption="metrics")]})
        wandb.log({"validation_reward": acc_reward})
        plt.close('all')
        if acc_reward > self.best_val_reward:
            self.best_val_reward = acc_reward
            wandb.run.summary["validation_reward"] = acc_reward
            torch.save(
                self.model.state_dict(),
                os.path.join(wandb.run.dir, "best_checkpoint_agent.pth"))
        if self.cfg.verbose:
            print("\n###### finish validate ######\n", end='')

        for i in range(n_examples):
            fig, axs = plt.subplots(2,
                                    3,
                                    sharex='col',
                                    sharey='row',
                                    gridspec_kw={
                                        'hspace': 0,
                                        'wspace': 0
                                    })
            axs[0, 0].imshow(ex_gts[i],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[0, 0].set_title('gt')
            axs[0, 0].axis('off')
            if ex_raws[i].ndim == 3:
                axs[0, 1].imshow(ex_raws[i][..., 0])
            else:
                axs[0, 1].imshow(ex_raws[i])
            axs[0, 1].set_title('raw image')
            axs[0, 1].axis('off')
            axs[0, 2].imshow(ex_sps[i],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[0, 2].set_title('superpixels')
            axs[0, 2].axis('off')
            axs[1, 0].imshow(ex_embeds[i])
            axs[1, 0].set_title('pc proj 1-3', y=-0.15)
            axs[1, 0].axis('off')
            if ex_raws[i].ndim == 3:
                if ex_raws[i].shape[-1] > 1:
                    axs[1, 1].imshow(ex_raws[i][..., 1])
                else:
                    axs[1, 1].imshow(ex_raws[i][..., 0])
            else:
                axs[1, 1].imshow(ex_raws[i])
            axs[1, 1].set_title('sp edge', y=-0.15)
            axs[1, 1].axis('off')
            axs[1, 2].imshow(ex_rl[i],
                             cmap=random_label_cmap(),
                             interpolation="none")
            axs[1, 2].set_title('prediction', y=-0.15)
            axs[1, 2].axis('off')
            wandb.log({
                "validation/samples":
                [wandb.Image(fig, caption="sample images")]
            })
            plt.close('all')

    def update_critic(self, obs, action, reward):
        self.optimizers.critic.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            current_Q1, current_Q2 = self.forwarder.forward(self.model,
                                                            obs,
                                                            State,
                                                            self.device,
                                                            actions=action)

            target_Q = reward[0]
            target_Q = target_Q.detach()

            critic_loss = F.mse_loss(current_Q1.squeeze(1),
                                     target_Q) + F.mse_loss(
                                         current_Q2.squeeze(1), target_Q)

        self.scalers.critic.scale(critic_loss).backward()
        self.scalers.critic.step(self.optimizers.critic)
        self.scalers.critic.update()

        return critic_loss.item(), reward[0].mean()

    def update_actor_and_alpha(self, obs, reward, expl_action):
        self.optimizers.actor.zero_grad()
        self.optimizers.temperature.zero_grad()
        obj_edge_mask_actor = obs.obj_edge_mask_actor.to(self.device)
        with torch.cuda.amp.autocast(enabled=True):
            distribution, actor_Q1, actor_Q2, action, side_loss = self.forwarder.forward(
                self.model,
                obs,
                State,
                self.device,
                expl_action=expl_action,
                policy_opt=True)
            obj_n_edges = obj_edge_mask_actor.sum(1)
            log_prob = distribution.log_prob(action)
            actor_loss = torch.tensor([0.0], device=actor_Q1[0].device)
            alpha_loss = torch.tensor([0.0], device=actor_Q1[0].device)

            actor_Q = torch.min(actor_Q1, actor_Q2)
            obj_log_prob = (log_prob[None] *
                            obj_edge_mask_actor[..., None]).sum(1)
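            # closed-form differential entropy of a Gaussian, 0.5 * (1 + log(2 * pi * sigma^2)),
            # summed over the edges that belong to each object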
            obj_entropy = (
                (1 / 2 * (1 +
                          (2 * np.pi * distribution.scale**2).log()))[None] *
                obj_edge_mask_actor[..., None]).sum(1).squeeze(1)

            loss = (self.model.alpha.detach() * obj_log_prob - actor_Q).mean()
            actor_loss = actor_loss + loss

            actor_loss = actor_loss + self.cfg.side_loss_weight * side_loss

            min_entropy = (
                self.cfg.entropy_range[1] - self.cfg.entropy_range[0]) * (
                    (1.5 - reward[0]) / 1.5) + self.cfg.entropy_range[0]

            min_entropy = min_entropy.to(self.model.alpha.device).squeeze()
            entropy = (obj_entropy.detach() if self.cfg.use_closed_form_entropy
                       else -obj_log_prob.detach())
            alpha_loss = alpha_loss + (self.model.alpha *
                                       (entropy -
                                        (obj_n_edges * min_entropy))).mean()

        self.scalers.actor.scale(actor_loss).backward()
        self.scalers.actor.scale(alpha_loss).backward()
        self.scalers.actor.step(self.optimizers.actor)
        self.scalers.actor.step(self.optimizers.temperature)
        self.scalers.actor.update()

        return (actor_loss.item(), alpha_loss.item(),
                min_entropy.mean().item(), distribution.loc.mean().item())

    def _step(self, step):
        actor_loss, alpha_loss, min_entropy, loc_mean = None, None, None, None

        (obs, action, reward), sample_idx = self.memory.sample()
        action = action.to(self.device)
        for i in range(len(reward)):
            reward[i] = reward[i].to(self.device)
        critic_loss, mean_reward = self.update_critic(obs, action, reward)
        self.memory.report_sample_loss(critic_loss + mean_reward, sample_idx)
        self.mov_sum_losses.critic.apply(critic_loss)
        # self.optimizers.critic_shed.step(self.mov_sum_losses.critic.avg)
        wandb.log({"loss/critic": critic_loss})

        if self.cfg.actor_update_after < step and step % self.cfg.actor_update_frequency == 0:
            actor_loss, alpha_loss, min_entropy, loc_mean = self.update_actor_and_alpha(
                obs, reward, action)
            self.mov_sum_losses.actor.apply(actor_loss)
            self.mov_sum_losses.temperature.apply(alpha_loss)
            # self.optimizers.actor_shed.step(self.mov_sum_losses.actor.avg)
            # self.optimizers.temp_shed.step(self.mov_sum_losses.actor.avg)
            wandb.log({"loss/actor": actor_loss})
            wandb.log({"loss/alpha": alpha_loss})

        if step % self.cfg.post_stats_frequency == 0:
            if min_entropy is not None:
                wandb.log({"min_entropy": min_entropy})
            wandb.log({"mov_avg/critic": self.mov_sum_losses.critic.avg})
            wandb.log({"mov_avg/actor": self.mov_sum_losses.actor.avg})
            wandb.log(
                {"mov_avg/temperature": self.mov_sum_losses.temperature.avg})
            wandb.log({
                "lr/critic":
                self.optimizers.critic_shed.optimizer.param_groups[0]['lr']
            })
            wandb.log({
                "lr/actor":
                self.optimizers.actor_shed.optimizer.param_groups[0]['lr']
            })
            wandb.log({
                "lr/temperature":
                self.optimizers.temp_shed.optimizer.param_groups[0]['lr']
            })

        if step % self.cfg.critic_target_update_frequency == 0:
            soft_update_params(self.model.critic, self.model.critic_tgt,
                               self.cfg.critic_tau)

        return [critic_loss, actor_loss, alpha_loss, loc_mean]

    def train_until_finished(self):
        while self.global_count.value() <= self.cfg.T_max + self.cfg.mem_size:
            self.model_mtx.acquire()
            try:
                stats = [[], [], [], []]
                for i in range(self.cfg.n_updates_per_step):
                    _stats = self._step(self.global_count.value())
                    [s.append(_s) for s, _s in zip(stats, _stats)]
                for j in range(len(stats)):
                    if any([_s is None for _s in stats[j]]):
                        stats[j] = "nl"
                    else:
                        stats[j] = round(
                            sum(stats[j]) / self.cfg.n_updates_per_step, 5)

                if self.cfg.verbose:
                    print(
                        f"step: {self.global_count.value()}; mean_loc: {stats[-1]}; n_explorer_steps {self.memory.push_count}",
                        end="")
                    print(f"; cl: {stats[0]}; acl: {stats[1]}; al: {stats[3]}")
            finally:
                self.model_mtx.release()
                self.global_count.increment()
                self.memory.reset_push_count()
            if self.global_count.value() % self.cfg.validatoin_freq == 0:
                self.validate()

    # Acts and trains model
    def train_and_explore(self, rn):
        self.global_count.reset()

        set_seed_everywhere(rn)
        wandb.config.random_seed = rn
        if self.cfg.verbose:
            print('###### start training ######')
            print('Running on device: ', self.device)
            print('found ', self.train_dset.length, " training data patches")
            print('found ', self.val_dset.length, "validation data patches")
            print('training with seed: ' + str(rn))
        explorers = []
        for i in range(self.cfg.n_explorers):
            explorers.append(threading.Thread(target=self.explore))
        [explorer.start() for explorer in explorers]

        self.memory.is_full_event.wait()
        trainer = threading.Thread(target=self.train_until_finished)
        trainer.start()

        trainer.join()
        self.global_count.set(self.cfg.T_max + self.cfg.mem_size + 4)
        [explorer.join() for explorer in explorers]
        self.memory.clear()
        del self.memory
        torch.save(self.model.state_dict(),
                   os.path.join(wandb.run.dir, "last_checkpoint_agent.pth"))
        if self.cfg.verbose:
            print('\n\n###### training finished ######')
        return

    def explore(self):
        env = MulticutEmbeddingsEnv(self.fe_ext, self.cfg, self.device)
        tau = 1
        while self.global_count.value() <= self.cfg.T_max + self.cfg.mem_size:
            dloader = iter(
                DataLoader(self.train_dset,
                           batch_size=self.cfg.batch_size,
                           shuffle=True,
                           pin_memory=True,
                           num_workers=0))
            for iteration in range(
                (len(self.train_dset) // self.cfg.batch_size) *
                    self.cfg.data_update_frequency):
                if iteration % self.cfg.data_update_frequency == 0:
                    update_env_data(env,
                                    dloader,
                                    self.train_dset,
                                    self.device,
                                    with_gt_edges="sub_graph_dice"
                                    in self.cfg.reward_function)
                env.reset()
                state = env.get_state()

                if not self.memory.is_full():
                    action = torch.rand((env.edge_ids.shape[-1], 1),
                                        device=self.device)
                else:
                    self.model_mtx.acquire()
                    try:
                        distr, action = self.forwarder.forward(self.model,
                                                               state,
                                                               State,
                                                               self.device,
                                                               grad=False)
                    finally:
                        self.model_mtx.release()
                reward, state = env.execute_action(action, tau=max(0, tau))
                for i in range(len(reward)):
                    reward[i] = reward[i].cpu()

                self.memory.push(state_to_cpu(state, State), action.cpu(),
                                 reward)
                if self.global_count.value() > self.cfg.T_max + self.cfg.mem_size:
                    break
        return
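
soft_update_params, called in _step for the periodic critic target update, is imported from elsewhere in the project; the usual SAC-style choice is Polyak averaging of the target parameters, sketched below under that assumption:

import torch

def soft_update_params(net, target_net, tau):
    # Polyak averaging: target <- tau * net + (1 - tau) * target
    with torch.no_grad():
        for p, p_tgt in zip(net.parameters(), target_net.parameters()):
            p_tgt.mul_(1.0 - tau)
            p_tgt.add_(tau * p)

critic, critic_tgt = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
critic_tgt.load_state_dict(critic.state_dict())   # targets usually start as a copy
soft_update_params(critic, critic_tgt, tau=0.005)
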
Example #5
def preprocess_data():
    gauss_kernel = GaussianSmoothing(1, 5, 3, device="cpu")
    distance = CosineDistance()
    device = "cuda:0"
    fe_model_name = "/g/kreshuk/hilt/storage/leptin_data_nets/best_val_model.pth"

    model = FeExtractor(backbone, distance, device)
    model.embed_model.load_state_dict(torch.load(fe_model_name))
    model.cuda("cuda:0")

    offs = [[1, 0], [0, 1], [2, 0], [0, 2], [4, 0], [0, 4], [8, 0], [0, 8],
            [16, 0], [0, 16]]
    for j, dir in enumerate([tgtdir_val]):
        pix_dir = os.path.join(dir, 'pix_data')
        graph_dir = os.path.join(dir, 'graph_data')
        new_pix_dir = os.path.join(dir, "bg_masked_data", 'pix_data')
        new_graph_dir = os.path.join(dir, "bg_masked_data", 'graph_data')
        fnames = sorted(glob(os.path.join(pix_dir, '*.h5')))

        def process_file(i):
            fname = fnames[i]
            head, tail = os.path.split(fname)
            num = tail[4:-3]
            # os.rename(os.path.join(graph_dir, "graph_" + str(i) + ".h5"), os.path.join(graph_dir, "graph_" + num + ".h5"))
            # os.rename(os.path.join(pix_dir, "pix_" + str(i) + ".h5"), os.path.join(pix_dir, "pix_" + num + ".h5"))

            # raw = torch.from_numpy(h5py.File(fname, 'r')['raw'][:].astype(np.float))
            # gt = h5py.File(fname, 'r')['label'][:].astype(np.long)
            # affs = torch.from_numpy(h5py.File(os.path.join(dir, 'affinities', tail[:-3] + '_predictions' + '.h5'), 'r')['predictions'][:]).squeeze(1)

            # raw -= raw.min()
            # raw /= raw.max()
            #
            # node_labeling = run_watershed(gaussian_filter(affs[0] + affs[1] + affs[2] + affs[3], sigma=.2), min_size=4)
            #
            # # relabel to consecutive ints starting at 0
            # node_labeling = torch.from_numpy(node_labeling.astype(np.long))
            #
            #
            # gt = torch.from_numpy(gt.astype(np.long))
            # mask = node_labeling[None] == torch.unique(node_labeling)[:, None, None]
            # node_labeling = (mask * (torch.arange(len(torch.unique(node_labeling)), device=node_labeling.device)[:, None, None] + 1)).sum(
            #     0) - 1
            #
            # node_labeling += 2
            # # bgm
            # node_labeling[gt == 0] = 0
            # node_labeling[gt == 1] = 1
            # plt.imshow(node_labeling);plt.show()
            #
            # mask = gt[None] == torch.unique(gt)[:, None, None]
            # gt = (mask * (torch.arange(len(torch.unique(gt)), device=gt.device)[:, None, None] + 1)).sum(0) - 1
            #
            # edge_img = get_contour_from_2d_binary(node_labeling[None, None].float())
            # edge_img = gauss_kernel(edge_img.float())
            # raw = torch.cat([raw[None, None], edge_img], dim=1).squeeze(0).numpy()
            # affs = torch.sigmoid(affs).numpy()
            #
            # edge_feat, edges = get_edge_features_1d(node_labeling.numpy(), offs, affs[:4])
            #
            # gt_edge_weights = gt_edge_weights.numpy()
            # gt = gt.numpy()
            # node_labeling = node_labeling.numpy()
            # edges = edges.astype(np.long)
            #
            # affs = affs.astype(np.float32)
            # edge_feat = edge_feat.astype(np.float32)
            # node_labeling = node_labeling.astype(np.float32)
            # gt_edge_weights = gt_edge_weights.astype(np.float32)
            # diff_to_gt = np.abs((edge_feat[:, 0] - gt_edge_weights)).sum()
            # edges = np.sort(edges, axis=-1)
            # edges = edges.T
            # #
            # #
            graph_file = h5py.File(
                os.path.join(graph_dir, "graph_" + num + ".h5"), 'r')
            pix_file = h5py.File(os.path.join(pix_dir, "pix_" + num + ".h5"),
                                 'r')

            raw = pix_file["raw_2chnl"][:]
            gt = pix_file["gt"][:]
            sp_seg = graph_file["node_labeling"][:]
            edges = graph_file["edges"][:]
            affs = graph_file["affinities"][:]

            graph_file.close()
            pix_file.close()
            # plt.imshow(multicut_from_probas(sp_seg, edges.T, calculate_gt_edge_costs(torch.from_numpy(edges.T.astype(np.long)).to(dev), torch.from_numpy(sp_seg).to(dev), torch.from_numpy(gt.squeeze()).to(dev), 1.5).cpu()), cmap=random_label_cmap(), interpolation="none");plt.show()
            with torch.set_grad_enabled(False):
                embeddings = model(
                    torch.from_numpy(raw).to(device).float()[None])
            emb_affs = get_affinities_from_embeddings_2d(
                embeddings, offs, 0.4, distance)[0].cpu().numpy()
            ew_embedaffs = 1 - get_edge_features_1d(sp_seg, offs,
                                                    emb_affs)[0][:, 0]
            mc_soln = torch.from_numpy(
                multicut_from_probas(sp_seg, edges.T,
                                     ew_embedaffs).astype(np.long)).to(device)

            mask = mc_soln[None] == torch.unique(mc_soln)[:, None, None]
            mc_soln = (mask *
                       (torch.arange(len(torch.unique(mc_soln)), device=device)
                        [:, None, None] + 1)).sum(0) - 1

            masses = (
                mc_soln[None] == torch.unique(mc_soln)[:, None,
                                                       None]).sum(-1).sum(-1)
            bg1id = masses.argmax()
            masses[bg1id] = 0
            bg1_mask = mc_soln == bg1id
            bg2_mask = mc_soln == masses.argmax()

            sp_seg = torch.from_numpy(sp_seg.astype(np.long)).to(device)

            mask = sp_seg[None] == torch.unique(sp_seg)[:, None, None]
            sp_seg = (mask *
                      (torch.arange(len(torch.unique(sp_seg)),
                                    device=sp_seg.device)[:, None, None] +
                       1)).sum(0) - 1

            sp_seg += 2
            sp_seg *= (bg1_mask == 0)
            sp_seg *= (bg2_mask == 0)
            sp_seg += bg2_mask

            mask = sp_seg[None] == torch.unique(sp_seg)[:, None, None]
            sp_seg = (mask *
                      (torch.arange(len(torch.unique(sp_seg)),
                                    device=sp_seg.device)[:, None, None] +
                       1)).sum(0) - 1
            sp_seg = sp_seg.cpu()

            raw -= raw.min()
            raw /= raw.max()
            edge_feat, edges = get_edge_features_1d(sp_seg.numpy(), offs[:4],
                                                    affs[:4])
            edges = edges.astype(np.long)

            gt_edge_weights = calculate_gt_edge_costs(
                torch.from_numpy(edges).to(device), sp_seg.to(device),
                torch.from_numpy(gt).to(device), 0.4)
            node_labeling = sp_seg.numpy()

            affs = affs.astype(np.float32)
            edge_feat = edge_feat.astype(np.float32)
            node_labeling = node_labeling.astype(np.float32)
            gt_edge_weights = gt_edge_weights.cpu().numpy().astype(np.float32)
            # edges = np.sort(edges, axis=-1)
            edges = edges.T
            # plt.imshow(sp_seg.cpu(), cmap=random_label_cmap(), interpolation="none");plt.show()
            # plt.imshow(bg1_mask.cpu());plt.show()
            # plt.imshow(bg2_mask.cpu());plt.show()
            new_pix_file = h5py.File(
                os.path.join(new_pix_dir, "pix_" + num + ".h5"), 'w')
            new_graph_file = h5py.File(
                os.path.join(new_graph_dir, "graph_" + num + ".h5"), 'w')

            # plt.imshow(gt_sp_projection, cmap=random_label_cmap(), interpolation="none");plt.show()

            new_pix_file.create_dataset("raw_2chnl", data=raw, chunks=True)
            new_pix_file.create_dataset("gt", data=gt, chunks=True)
            #
            new_graph_file.create_dataset("edges", data=edges, chunks=True)
            new_graph_file.create_dataset("edge_feat",
                                          data=edge_feat,
                                          chunks=True)
            new_graph_file.create_dataset("gt_edge_weights",
                                          data=gt_edge_weights,
                                          chunks=True)
            new_graph_file.create_dataset("node_labeling",
                                          data=node_labeling,
                                          chunks=True)
            new_graph_file.create_dataset("affinities", data=affs, chunks=True)
            new_graph_file.create_dataset("offsets",
                                          data=np.array([[1, 0], [0,
                                                                  1], [2, 0],
                                                         [0, 2], [4,
                                                                  0], [0, 4],
                                                         [8, 0], [0, 8],
                                                         [16, 0], [0, 16]]),
                                          chunks=True)
            #
            new_graph_file.close()
            new_pix_file.close()

        workers = []
        for i in range(len(fnames)):
            worker = threading.Thread(target=process_file, args=(i, ))
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join()

    pass
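
The relabel-to-consecutive-ids pattern used repeatedly above (a one-hot mask against torch.unique, multiplied by shifted indices and summed) can also be written with torch.unique(..., return_inverse=True). A small self-contained sketch showing the two forms agree:

import torch

seg = torch.tensor([[7, 7, 3],
                    [3, 9, 9]])

# pattern used above: one-hot mask against the sorted unique labels
mask = seg[None] == torch.unique(seg)[:, None, None]
relabeled_a = (mask * (torch.arange(len(torch.unique(seg)))[:, None, None] + 1)).sum(0) - 1

# return_inverse already yields consecutive ids starting at 0
_, relabeled_b = torch.unique(seg, return_inverse=True)

assert torch.equal(relabeled_a, relabeled_b)
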
Example #6
    def __init__(self, cfg, global_count):
        super(AgentSaTrainerObjLvlReward, self).__init__()
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        if self.cfg.distance == 'cosine':
            self.distance = CosineDistance()
        else:
            self.distance = L2Distance()

        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone),
                                  self.distance, cfg.fe_delta_dist,
                                  self.device)
        self.fe_ext.embed_model.load_state_dict(
            torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)

        self.model = Agent(self.cfg, State, self.distance, self.device)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        self.optimizer = torch.optim.Adam(self.model.actor.parameters(),
                                          lr=self.cfg.actor_lr)

        lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        bw = lr_sched_cfg.mov_avg_bandwidth
        off = lr_sched_cfg.mov_avg_offset
        weights = np.linspace(lr_sched_cfg.weight_range[0],
                              lr_sched_cfg.weight_range[1], bw)
        weights = weights / weights.sum()  # make them sum up to one
        shed = lr_sched_cfg.torch_sched
        self.shed = ReduceLROnPlateau(self.optimizer,
                                      patience=shed.patience,
                                      threshold=shed.threshold,
                                      min_lr=shed.min_lr,
                                      factor=shed.factor)

        self.mov_sum_loss = RunningAverage(weights, band_width=bw, offset=off)
        self.scaler = torch.cuda.amp.GradScaler()
        self.forwarder = Forwarder()

        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))

        # finished with prepping
        for param in self.fe_ext.parameters():
            param.requires_grad = False

        self.train_dset = SpgDset(self.cfg.data_dir,
                                  dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys))
        self.val_dset = SpgDset(self.cfg.val_data_dir,
                                dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys))