class AgentA2CTrainer(object):

    def __init__(self, cfg, global_count):
        super(AgentA2CTrainer, self).__init__()
        # single-GPU setup
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        if self.cfg.distance == 'cosine':
            self.distance = CosineDistance()
        else:
            self.distance = L2Distance()

        self.model = Agent(self.cfg, State, self.distance, self.device, with_temp=False)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic'))
        Scalers = namedtuple('Scalers', ('critic', 'actor'))
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'actor_shed', 'critic_shed'))

        actor_optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr)
        critic_optimizer = torch.optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr)

        # weighted moving averages of the losses drive the ReduceLROnPlateau schedulers
        lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        bw = lr_sched_cfg.mov_avg_bandwidth
        off = lr_sched_cfg.mov_avg_offset
        weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw)
        weights = weights / weights.sum()  # make them sum up to one
        shed = lr_sched_cfg.torch_sched

        self.mov_sum_losses = MovSumLosses(RunningAverage(weights, band_width=bw, offset=off),
                                           RunningAverage(weights, band_width=bw, offset=off))
        self.optimizers = OptimizerContainer(
            actor_optimizer, critic_optimizer,
            *[ReduceLROnPlateau(opt, patience=shed.patience, threshold=shed.threshold,
                                min_lr=shed.min_lr, factor=shed.factor)
              for opt in (actor_optimizer, critic_optimizer)])
        self.scalers = Scalers(torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler())
        self.forwarder = Forwarder()

        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
        # finished with prepping

        self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys), max(self.cfg.s_subgraph))
        self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys), max(self.cfg.s_subgraph))
        self.segm_metric = AveragePrecision()
        self.clst_metric = ClusterMetrics()
        self.global_counter = 0
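
# --- Illustration only (not part of the original module) --------------------------
# The trainers in this file feed a weighted moving average of each loss into a
# ReduceLROnPlateau scheduler rather than the raw per-step loss. The repo's
# RunningAverage utility is defined elsewhere; the class below is only a hedged
# sketch of its assumed behaviour (skip `offset` warm-up values, then return the
# weighted mean of the last `band_width` values), reusing the module-level numpy
# import. It is not the project's actual implementation.
from collections import deque


class _RunningAverageSketch:

    def __init__(self, weights, band_width, offset=0):
        self.weights = np.asarray(weights)   # assumed normalized to sum to one
        self.band_width = band_width
        self.offset = offset                 # number of warm-up values to ignore
        self._buffer = deque(maxlen=band_width)
        self._n_seen = 0
        self.avg = None

    def apply(self, value):
        """Record one loss value and return the weighted average once the window is full."""
        self._n_seen += 1
        if self._n_seen <= self.offset:
            return None
        self._buffer.append(float(value))
        if len(self._buffer) < self.band_width:
            return None
        self.avg = float(np.dot(self.weights, np.asarray(self._buffer)))
        return self.avg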
class AgentSacTrainer(object): def __init__(self, cfg, global_count): super(AgentSacTrainer, self).__init__() assert torch.cuda.device_count() == 1 self.device = torch.device("cuda:0") torch.cuda.set_device(self.device) torch.set_default_tensor_type(torch.FloatTensor) self.cfg = cfg self.global_count = global_count self.memory = TransitionData_ts(capacity=self.cfg.mem_size) self.best_val_reward = -np.inf if self.cfg.distance == 'cosine': self.distance = CosineDistance() else: self.distance = L2Distance() self.model = Agent(self.cfg, State, self.distance, self.device) wandb.watch(self.model) self.model.cuda(self.device) self.model_mtx = Lock() MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic', 'temperature')) Scalers = namedtuple('Scalers', ('critic', 'actor')) OptimizerContainer = namedtuple( 'OptimizerContainer', ('actor', 'critic', 'temperature', 'actor_shed', 'critic_shed', 'temp_shed')) actor_optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr) if self.cfg.fe_optimization: critic_optimizer = torch.optim.Adam( list(self.model.critic.parameters()) + list(self.model.fe_ext.parameters()), lr=self.cfg.critic_lr) else: critic_optimizer = torch.optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr) # critic_optimizer = torch.optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr) temp_optimizer = torch.optim.Adam([self.model.log_alpha], lr=self.cfg.alpha_lr) lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched) bw = lr_sched_cfg.mov_avg_bandwidth off = lr_sched_cfg.mov_avg_offset weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw) weights = weights / weights.sum() # make them sum up to one shed = lr_sched_cfg.torch_sched self.mov_sum_losses = MovSumLosses( RunningAverage(weights, band_width=bw, offset=off), RunningAverage(weights, band_width=bw, offset=off), RunningAverage(weights, band_width=bw, offset=off)) self.optimizers = OptimizerContainer( actor_optimizer, critic_optimizer, temp_optimizer, *[ ReduceLROnPlateau(opt, patience=shed.patience, threshold=shed.threshold, min_lr=shed.min_lr, factor=shed.factor) for opt in (actor_optimizer, critic_optimizer, temp_optimizer) ]) self.scalers = Scalers(torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler()) self.forwarder = Forwarder() if self.cfg.agent_model_name != "": self.model.load_state_dict(torch.load(self.cfg.agent_model_name)) # finished with prepping self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager), dict_to_attrdict(self.cfg.train_data_keys), max(self.cfg.s_subgraph)) self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager), dict_to_attrdict(self.cfg.val_data_keys), max(self.cfg.s_subgraph)) self.segm_metric = AveragePrecision() self.clst_metric = ClusterMetrics() self.global_counter = 0 def validate(self): """validates the prediction against the method of clustering the embedding space""" env = MulticutEmbeddingsEnv(self.cfg, self.device) if self.cfg.verbose: print("\n\n###### start validate ######", end='') self.model.eval() n_examples = len(self.val_dset) # taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] # rl_scores, keys = [], None self.clst_metric.reset() map_scores = [] ex_raws, ex_sps, ex_gts, ex_mc_gts, ex_feats, ex_emb, ex_n_emb, ex_rl, edge_ids, rewards, actions = [ [] for _ in range(11) ] dloader = iter( DataLoader(self.val_dset, batch_size=1, shuffle=False, pin_memory=True, num_workers=0)) acc_reward = 0 for it in range(n_examples): update_env_data(env, 
dloader, self.val_dset, self.device, with_gt_edges="sub_graph_dice" in self.cfg.reward_function) env.reset() state = env.get_state() self.model_mtx.acquire() try: distr, _, _, _, _, node_features, embeddings = self.forwarder.forward( self.model, state, State, self.device, grad=False, post_data=False, get_node_feats=True, get_embeddings=True) finally: self.model_mtx.release() action = torch.sigmoid(distr.loc) reward = env.execute_action(action, tau=0.0, train=False) rew = reward[-1].item( ) if self.cfg.reward_function == "sub_graph_dice" else reward[ -2].item() acc_reward += rew rl_labels = env.current_soln.cpu().numpy()[0] gt_seg = env.gt_seg[0].cpu().numpy() if self.cfg.verbose: print( f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(rew, 5)}", end='') if it in self.cfg.store_indices: node_features = node_features[:env.n_offs[1]][ env.init_sp_seg[0].long()].permute(2, 0, 1).cpu() gt_mc = cm.prism( env.gt_soln[0].cpu() / env.gt_soln[0].max().item() ) if env.gt_edge_weights is not None else torch.zeros( env.raw.shape[-2:]) ex_feats.append(pca_project(node_features, n_comps=3)) ex_emb.append(pca_project(embeddings[0].cpu(), n_comps=3)) ex_n_emb.append( pca_project(node_features[:self.cfg.dim_embeddings], n_comps=3)) ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze()) ex_sps.append(env.init_sp_seg[0].cpu()) ex_mc_gts.append(gt_mc) ex_gts.append(gt_seg) ex_rl.append(rl_labels) edge_ids.append(env.edge_ids) rewards.append(reward[-1]) actions.append(action) map_scores.append(self.segm_metric(rl_labels, gt_seg)) self.clst_metric(rl_labels, gt_seg) ''' _rl_scores = matching(gt_seg, rl_labels, thresh=taus, criterion='iou', report_matches=False) if it == 0: for tau_it in range(len(_rl_scores)): rl_scores.append(np.array(list(map(float, list(_rl_scores[tau_it]._asdict().values())[1:])))) keys = list(_rl_scores[0]._asdict().keys())[1:] else: for tau_it in range(len(_rl_scores)): rl_scores[tau_it] += np.array(list(map(float, list(_rl_scores[tau_it]._asdict().values())[1:])) ''' ''' div = np.ones_like(rl_scores[0]) for i, key in enumerate(keys): if key not in ('fp', 'tp', 'fn'): div[i] = 10 for tau_it in range(len(rl_scores)): rl_scores[tau_it] = dict(zip(keys, rl_scores[tau_it] / div)) fig, axs = plt.subplots(1, 2, figsize=(10, 10)) plt.subplots_adjust(hspace=.5) for m in ('precision', 'recall', 'accuracy', 'f1'): y = [s[m] for s in rl_scores] data = [[x, y] for (x, y) in zip(taus, y)] table = wandb.Table(data=data, columns=["IoU_threshold", m]) wandb.log({"validation/" + m: wandb.plot.line(table, "IoU_threshold", m, stroke=None, title=m)}) axs[0].plot(taus, [s[m] for s in rl_scores], '.-', lw=2, label=m) axs[0].set_ylabel('Metric value') axs[0].grid() axs[0].legend(bbox_to_anchor=(.8, 1.65), loc='upper left', fontsize='xx-small') axs[0].set_title('RL method') axs[0].set_xlabel(r'IoU threshold $\tau$') for m in ('fp', 'tp', 'fn'): y = [s[m] for s in rl_scores] data = [[x, y] for (x, y) in zip(taus, y)] table = wandb.Table(data=data, columns=["IoU_threshold", m]) wandb.log({"validation/" + m: wandb.plot.line(table, "IoU_threshold", m, stroke=None, title=m)}) axs[1].plot(taus, [s[m] for s in rl_scores], '.-', lw=2, label=m) axs[1].set_ylabel('Number #') axs[1].grid() axs[1].legend(bbox_to_anchor=(.87, 1.6), loc='upper left', fontsize='xx-small'); axs[1].set_title('RL method') axs[1].set_xlabel(r'IoU threshold $\tau$') #wandb.log({"validation/metrics": [wandb.Image(fig, caption="metrics")]}) plt.close('all') ''' splits, merges, are, arp, arr = 
self.clst_metric.dump() wandb.log({"validation/acc_reward": acc_reward}) wandb.log({"validation/mAP": np.mean(map_scores)}, step=self.global_counter) wandb.log({"validation/UnderSegmVI": splits}, step=self.global_counter) wandb.log({"validation/OverSegmVI": merges}, step=self.global_counter) wandb.log({"validation/ARE": are}, step=self.global_counter) wandb.log({"validation/ARP": arp}, step=self.global_counter) wandb.log({"validation/ARR": arr}, step=self.global_counter) # do the lr sheduling self.optimizers.critic_shed.step(acc_reward) self.optimizers.actor_shed.step(acc_reward) if acc_reward > self.best_val_reward: self.best_val_reward = acc_reward wandb.run.summary["validation/acc_reward"] = acc_reward torch.save( self.model.state_dict(), os.path.join(wandb.run.dir, "best_checkpoint_agent.pth")) if self.cfg.verbose: print("\n###### finish validate ######\n", end='') label_cm = random_label_cmap(zeroth=1.0) label_cm.set_bad(alpha=0) for it, i in enumerate(self.cfg.store_indices): fig, axs = plt.subplots( 2, 4 if self.cfg.reward_function == "sub_graph_dice" else 5, sharex='col', figsize=(9, 5), sharey='row', gridspec_kw={ 'hspace': 0, 'wspace': 0 }) axs[0, 0].imshow(ex_gts[it], cmap=random_label_cmap(), interpolation="none") axs[0, 0].set_title('gt', y=1.05, size=10) axs[0, 0].axis('off') if ex_raws[it].ndim == 3: if ex_raws[it].shape[-1] > 2: axs[0, 1].imshow(ex_raws[it][..., :3], cmap="gray") else: axs[0, 1].imshow(ex_raws[it][..., 0], cmap="gray") else: axs[1, 1].imshow(ex_raws[it], cmap="gray") axs[0, 1].set_title('raw image', y=1.05, size=10) axs[0, 1].axis('off') if ex_raws[it].ndim == 3: if ex_raws[it].shape[-1] > 1: axs[0, 2].imshow(ex_raws[it][..., -1], cmap="gray") else: axs[0, 2].imshow(ex_raws[it][..., 0], cmap="gray") else: axs[0, 2].imshow(ex_raws[it], cmap="gray") axs[0, 2].set_title('plantseg', y=1.05, size=10) axs[0, 2].axis('off') axs[0, 3].imshow(ex_sps[it], cmap=random_label_cmap(), interpolation="none") axs[0, 3].set_title('superpixels', y=1.05, size=10) axs[0, 3].axis('off') axs[1, 0].imshow(ex_feats[it]) axs[1, 0].set_title('features', y=-0.15, size=10) axs[1, 0].axis('off') axs[1, 1].imshow(ex_n_emb[it]) axs[1, 1].set_title('node embeddings', y=-0.15, size=10) axs[1, 1].axis('off') axs[1, 2].imshow(ex_emb[it]) axs[1, 2].set_title('embeddings', y=-0.15, size=10) axs[1, 2].axis('off') axs[1, 3].imshow(ex_rl[it], cmap=random_label_cmap(), interpolation="none") axs[1, 3].set_title('prediction', y=-0.15, size=10) axs[1, 3].axis('off') if self.cfg.reward_function != "sub_graph_dice": frame_rew, scores_rew, bnd_mask = get_colored_edges_in_sseg( ex_sps[it][None].float(), edge_ids[it].cpu(), rewards[it].cpu()) frame_act, scores_act, _ = get_colored_edges_in_sseg( ex_sps[it][None].float(), edge_ids[it].cpu(), 1 - actions[it].cpu().squeeze()) bnd_mask = torch.from_numpy(dilation(bnd_mask.cpu().numpy())) frame_rew = np.stack([ dilation(frame_rew.cpu().numpy()[..., i]) for i in range(3) ], -1) frame_act = np.stack([ dilation(frame_act.cpu().numpy()[..., i]) for i in range(3) ], -1) ex_rl[it] = ex_rl[it].squeeze().astype(np.float) ex_rl[it][bnd_mask] = np.nan axs[1, 4].imshow(frame_rew, interpolation="none") axs[1, 4].imshow(ex_rl[it], cmap=label_cm, alpha=0.8, interpolation="none") axs[1, 4].set_title("rewards", y=-0.2) axs[1, 4].axis('off') axs[0, 4].imshow(frame_act, interpolation="none") axs[0, 4].imshow(ex_rl[it], cmap=label_cm, alpha=0.8, interpolation="none") axs[0, 4].set_title("actions", y=1.05) axs[0, 4].axis('off') wandb.log( { "validation/sample_" + str(i): 
[wandb.Image(fig, caption="sample images")] }, step=self.global_counter) plt.close('all') def update_critic(self, obs, action, reward): self.optimizers.critic.zero_grad() with torch.cuda.amp.autocast(enabled=True): current_Q, side_loss = self.forwarder.forward(self.model, obs, State, self.device, actions=action, grad=True) critic_loss = torch.tensor([0.0], device=current_Q[0].device) mean_reward = 0 for i, sz in enumerate(self.cfg.s_subgraph): target_Q = reward[i] target_Q = target_Q.detach() critic_loss = critic_loss + F.mse_loss(current_Q[i], target_Q) mean_reward += reward[i].mean() critic_loss = critic_loss / len( self.cfg.s_subgraph) + self.cfg.side_loss_weight * side_loss self.scalers.critic.scale(critic_loss).backward() self.scalers.critic.step(self.optimizers.critic) self.scalers.critic.update() return critic_loss.item(), mean_reward / len(self.cfg.s_subgraph) def update_actor_and_alpha(self, obs, reward, expl_action): self.optimizers.actor.zero_grad() self.optimizers.temperature.zero_grad() with torch.cuda.amp.autocast(enabled=True): expl_action = None distribution, actor_Q, action, side_loss = self.forwarder.forward( self.model, obs, State, self.device, expl_action=expl_action, policy_opt=True, grad=True) log_prob = distribution.log_prob(action) actor_loss = torch.tensor([0.0], device=actor_Q[0].device) alpha_loss = torch.tensor([0.0], device=actor_Q[0].device) _log_prob, sg_entropy = [], [] for i, sz in enumerate(self.cfg.s_subgraph): ret = get_joint_sg_logprobs_edges(log_prob, distribution.scale, obs, i, sz) _log_prob.append(ret[0]) sg_entropy.append(ret[1]) loss = (self.model.alpha[i].detach() * _log_prob[i] - actor_Q[i]).mean() actor_loss = actor_loss + loss actor_loss = actor_loss / len( self.cfg.s_subgraph) + self.cfg.side_loss_weight * side_loss min_entropy = (self.cfg.entropy_range[1] - self.cfg.entropy_range[0]) * ((1.5 - reward[-1]) / 1.5) + \ self.cfg.entropy_range[0] min_entropy = torch.ones_like(min_entropy) for i, sz in enumerate(self.cfg.s_subgraph): min_entropy = min_entropy.to( self.model.alpha[i].device).squeeze() entropy = sg_entropy[i].detach( ) if self.cfg.use_closed_form_entropy else -_log_prob[ i].detach() alpha_loss = alpha_loss + ( self.model.alpha[i] * (entropy - (self.cfg.s_subgraph[i] * min_entropy))).mean() alpha_loss = alpha_loss / len(self.cfg.s_subgraph) self.scalers.actor.scale(actor_loss).backward() self.scalers.actor.scale(alpha_loss).backward() self.scalers.actor.step(self.optimizers.actor) self.scalers.actor.step(self.optimizers.temperature) self.scalers.actor.update() return actor_loss.item(), alpha_loss.item( ), min_entropy, distribution.loc.mean().item() def _step(self, step): actor_loss, alpha_loss, min_entropy, loc_mean = None, None, None, None (obs, action, reward), sample_idx = self.memory.sample() critic_loss, mean_reward = self.update_critic(obs, action, reward) self.memory.report_sample_loss(critic_loss + mean_reward, sample_idx) avg = self.mov_sum_losses.critic.apply(critic_loss) if avg is not None: self.optimizers.critic_shed.step(avg) wandb.log({"loss/critic": critic_loss}, step=self.global_counter) if self.cfg.actor_update_after < step and step % self.cfg.actor_update_frequency == 0: actor_loss, alpha_loss, min_entropy, loc_mean = self.update_actor_and_alpha( obs, reward, action) avg = self.mov_sum_losses.actor.apply(actor_loss) if avg is not None: self.optimizers.actor_shed.step(avg) avg = self.mov_sum_losses.temperature.apply(alpha_loss) if avg is not None: self.optimizers.temp_shed.step(avg) wandb.log({"loss/actor": 
actor_loss}, step=self.global_counter) wandb.log({"loss/alpha": alpha_loss}, step=self.global_counter) if step % self.cfg.post_stats_frequency == 0: if min_entropy != "nl": wandb.log({"min_entropy": min_entropy}, step=self.global_counter) wandb.log({"mov_avg/critic": self.mov_sum_losses.critic.avg}, step=self.global_counter) wandb.log({"mov_avg/actor": self.mov_sum_losses.actor.avg}, step=self.global_counter) wandb.log( {"mov_avg/temperature": self.mov_sum_losses.temperature.avg}, step=self.global_counter) wandb.log( { "lr/critic": self.optimizers.critic_shed.optimizer.param_groups[0]['lr'] }, step=self.global_counter) wandb.log( { "lr/actor": self.optimizers.actor_shed.optimizer.param_groups[0]['lr'] }, step=self.global_counter) wandb.log( { "lr/temperature": self.optimizers.temp_shed.optimizer.param_groups[0]['lr'] }, step=self.global_counter) self.global_counter = self.global_counter + 1 if step % self.cfg.critic_target_update_frequency == 0: soft_update_params(self.model.critic, self.model.critic_tgt, self.cfg.critic_tau) soft_update_params(self.model.fe_ext, self.model.fe_ext_tgt, self.cfg.critic_tau) return critic_loss, actor_loss, alpha_loss, loc_mean def train_until_finished(self): while self.global_count.value() <= self.cfg.T_max + self.cfg.mem_size: self.model_mtx.acquire() try: stats = [[], [], [], []] for i in range(self.cfg.n_updates_per_step): _stats = self._step(self.global_count.value()) [s.append(_s) for s, _s in zip(stats, _stats)] for j in range(len(stats)): if any([_s is None for _s in stats[j]]): stats[j] = "nl" else: stats[j] = round( sum(stats[j]) / self.cfg.n_updates_per_step, 5) if self.cfg.verbose: print( f"step: {self.global_count.value()}; mean_loc: {stats[-1]}; n_explorer_steps {self.memory.push_count}", end="") print(f"; cl: {stats[0]}; acl: {stats[1]}; al: {stats[3]}") finally: self.model_mtx.release() self.global_count.increment() self.memory.reset_push_count() if self.global_count.value() % self.cfg.validatoin_freq == 0: self.validate() torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "last_checkpoint_agent.pth")) # Acts and trains model def train_and_explore(self, rn): self.global_count.reset() set_seed_everywhere(rn) wandb.config.random_seed = rn if self.cfg.verbose: print('###### start training ######') print('Running on device: ', self.device) print('found ', self.train_dset.length, " training data patches") print('found ', self.val_dset.length, "validation data patches") print('training with seed: ' + str(rn)) explorers = [] for i in range(self.cfg.n_explorers): explorers.append(threading.Thread(target=self.explore)) [explorer.start() for explorer in explorers] self.memory.is_full_event.wait() trainer = threading.Thread(target=self.train_until_finished) trainer.start() trainer.join() self.global_count.set(self.cfg.T_max + self.cfg.mem_size + 4) [explorer.join() for explorer in explorers] self.memory.clear() del self.memory # torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "last_checkpoint_agent.pth")) if self.cfg.verbose: print('\n\n###### training finished ######') return def explore(self): env = MulticutEmbeddingsEnv(self.cfg, self.device) tau = 1 while self.global_count.value() <= self.cfg.T_max + self.cfg.mem_size: dloader = iter( DataLoader(self.train_dset, batch_size=self.cfg.batch_size, shuffle=True, pin_memory=True, num_workers=0)) for iteration in range( (len(self.train_dset) // self.cfg.batch_size) * self.cfg.data_update_frequency): if iteration % self.cfg.data_update_frequency == 0: update_env_data(env, 
dloader, self.train_dset, self.device, with_gt_edges="sub_graph_dice" in self.cfg.reward_function) env.reset() state = env.get_state() if not self.memory.is_full(): action = torch.rand((env.edge_ids.shape[-1], 1), device=self.device) else: self.model_mtx.acquire() try: distr, _, action, _, _ = self.forwarder.forward( self.model, state, State, self.device, grad=False) finally: self.model_mtx.release() reward = env.execute_action(action, tau=max(0, tau)) self.memory.push(state_to_cpu(state, State), action, reward) if self.global_count.value( ) > self.cfg.T_max + self.cfg.mem_size: break return
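
# --- Illustration only (not part of the original module) --------------------------
# AgentSacTrainer._step periodically calls soft_update_params(...) on the critic and
# on the feature extractor. The helper below is a hedged sketch of the Polyak update
# such a utility typically performs (target <- (1 - tau) * target + tau * online);
# the repo's own soft_update_params may differ in detail.
def _soft_update_params_sketch(net, target_net, tau):
    """Move every target parameter a fraction `tau` toward the corresponding online parameter."""
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * param.data)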
class AgentSacTrainerObjLvlReward(object): def __init__(self, cfg, global_count): super(AgentSacTrainerObjLvlReward, self).__init__() assert torch.cuda.device_count() == 1 self.device = torch.device("cuda:0") torch.cuda.set_device(self.device) torch.set_default_tensor_type(torch.FloatTensor) self.cfg = cfg self.global_count = global_count self.memory = TransitionData_ts(capacity=self.cfg.mem_size) self.best_val_reward = -np.inf if self.cfg.distance == 'cosine': self.distance = CosineDistance() else: self.distance = L2Distance() self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance, cfg.fe_delta_dist, self.device) self.fe_ext.embed_model.load_state_dict( torch.load(self.cfg.fe_model_name)) self.fe_ext.cuda(self.device) self.model = Agent(self.cfg, State, self.distance, self.device) wandb.watch(self.model) self.model.cuda(self.device) self.model_mtx = Lock() MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic', 'temperature')) Scalers = namedtuple('Scalers', ('critic', 'actor')) OptimizerContainer = namedtuple( 'OptimizerContainer', ('actor', 'critic', 'temperature', 'actor_shed', 'critic_shed', 'temp_shed')) actor_optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr) critic_optimizer = torch.optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr) temp_optimizer = torch.optim.Adam([self.model.log_alpha], lr=self.cfg.alpha_lr) lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched) bw = lr_sched_cfg.mov_avg_bandwidth off = lr_sched_cfg.mov_avg_offset weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw) weights = weights / weights.sum() # make them sum up to one shed = lr_sched_cfg.torch_sched self.mov_sum_losses = MovSumLosses( RunningAverage(weights, band_width=bw, offset=off), RunningAverage(weights, band_width=bw, offset=off), RunningAverage(weights, band_width=bw, offset=off)) self.optimizers = OptimizerContainer( actor_optimizer, critic_optimizer, temp_optimizer, *[ ReduceLROnPlateau(opt, patience=shed.patience, threshold=shed.threshold, min_lr=shed.min_lr, factor=shed.factor) for opt in (actor_optimizer, critic_optimizer, temp_optimizer) ]) self.scalers = Scalers(torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler()) self.forwarder = Forwarder() if self.cfg.agent_model_name != "": self.model.load_state_dict(torch.load(self.cfg.agent_model_name)) # if "policy_warmup" in self.cfg and self.cfg.agent_model_name == "": # supervised_policy_pretraining(self.model, self.env, self.cfg, device=self.device) # torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "sv_pretrained_policy_agent.pth")) # finished with prepping for param in self.fe_ext.parameters(): param.requires_grad = False self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager), dict_to_attrdict(self.cfg.data_keys)) self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager), dict_to_attrdict(self.cfg.data_keys)) def validate(self): """validates the prediction against the method of clustering the embedding space""" env = MulticutEmbeddingsEnv(self.fe_ext, self.cfg, self.device) if self.cfg.verbose: print("\n\n###### start validate ######", end='') self.model.eval() n_examples = len(self.val_dset) taus = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] rl_scores, keys = [], None ex_raws, ex_sps, ex_gts, ex_mc_gts, ex_embeds, ex_rl = [], [], [], [], [], [] dloader = iter( DataLoader(self.val_dset, batch_size=1, shuffle=True, pin_memory=True, num_workers=0)) acc_reward = 0 
for it in range(len(self.val_dset)): update_env_data(env, dloader, self.val_dset, self.device, with_gt_edges="sub_graph_dice" in self.cfg.reward_function) env.reset() state = env.get_state() self.model_mtx.acquire() try: distr, _ = self.forwarder.forward(self.model, state, State, self.device, grad=False, post_data=False) finally: self.model_mtx.release() action = torch.sigmoid(distr.loc) reward, state = env.execute_action(action, None, post_images=True, tau=0.0, train=False) acc_reward += reward[-1].item() if self.cfg.verbose: print( f"\nstep: {it}; mean_loc: {round(distr.loc.mean().item(), 5)}; mean reward: {round(reward[-1].item(), 5)}", end='') embeddings = env.embeddings[0].cpu().numpy() gt_seg = env.gt_seg[0].cpu().numpy() gt_mc = cm.prism( env.gt_soln[0].cpu() / env.gt_soln[0].max().item() ) if env.gt_edge_weights is not None else torch.zeros( env.raw.shape[-2:]) rl_labels = env.current_soln.cpu().numpy()[0] if it < n_examples: ex_embeds.append(pca_project(embeddings, n_comps=3)) ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze()) ex_sps.append(env.init_sp_seg[0].cpu()) ex_mc_gts.append(gt_mc) ex_gts.append(gt_seg) ex_rl.append(rl_labels) _rl_scores = matching(gt_seg, rl_labels, thresh=taus, criterion='iou', report_matches=False) if it == 0: for tau_it in range(len(_rl_scores)): rl_scores.append( np.array( list( map( float, list(_rl_scores[tau_it]._asdict().values()) [1:])))) keys = list(_rl_scores[0]._asdict().keys())[1:] else: for tau_it in range(len(_rl_scores)): rl_scores[tau_it] += np.array( list( map( float, list(_rl_scores[tau_it]._asdict().values()) [1:]))) div = np.ones_like(rl_scores[0]) for i, key in enumerate(keys): if key not in ('fp', 'tp', 'fn'): div[i] = 10 for tau_it in range(len(rl_scores)): rl_scores[tau_it] = dict(zip(keys, rl_scores[tau_it] / div)) fig, axs = plt.subplots(1, 2, figsize=(10, 10)) plt.subplots_adjust(hspace=.5) for m in ('precision', 'recall', 'accuracy', 'f1'): y = [s[m] for s in rl_scores] data = [[x, y] for (x, y) in zip(taus, y)] table = wandb.Table(data=data, columns=["IoU_threshold", m]) wandb.log({ "validation/" + m: wandb.plot.line(table, "IoU_threshold", m, stroke=None, title=m) }) axs[0].plot(taus, [s[m] for s in rl_scores], '.-', lw=2, label=m) axs[0].set_ylabel('Metric value') axs[0].grid() axs[0].legend(bbox_to_anchor=(.8, 1.65), loc='upper left', fontsize='xx-small') axs[0].set_title('RL method') axs[0].set_xlabel(r'IoU threshold $\tau$') for m in ('fp', 'tp', 'fn'): y = [s[m] for s in rl_scores] data = [[x, y] for (x, y) in zip(taus, y)] table = wandb.Table(data=data, columns=["IoU_threshold", m]) wandb.log({ "validation/" + m: wandb.plot.line(table, "IoU_threshold", m, stroke=None, title=m) }) axs[1].plot(taus, [s[m] for s in rl_scores], '.-', lw=2, label=m) axs[1].set_ylabel('Number #') axs[1].grid() axs[1].legend(bbox_to_anchor=(.87, 1.6), loc='upper left', fontsize='xx-small') axs[1].set_title('RL method') axs[1].set_xlabel(r'IoU threshold $\tau$') wandb.log( {"validation/metrics": [wandb.Image(fig, caption="metrics")]}) wandb.log({"validation_reward": acc_reward}) plt.close('all') if acc_reward > self.best_val_reward: self.best_val_reward = acc_reward wandb.run.summary["validation_reward"] = acc_reward torch.save( self.model.state_dict(), os.path.join(wandb.run.dir, "best_checkpoint_agent.pth")) if self.cfg.verbose: print("\n###### finish validate ######\n", end='') for i in range(n_examples): fig, axs = plt.subplots(2, 3, sharex='col', sharey='row', gridspec_kw={ 'hspace': 0, 'wspace': 0 }) axs[0, 0].imshow(ex_gts[i], 
cmap=random_label_cmap(), interpolation="none") axs[0, 0].set_title('gt') axs[0, 0].axis('off') if ex_raws[i].ndim == 3: axs[0, 1].imshow(ex_raws[i][..., 0]) else: axs[0, 1].imshow(ex_raws[i]) axs[0, 1].set_title('raw image') axs[0, 1].axis('off') axs[0, 2].imshow(ex_sps[i], cmap=random_label_cmap(), interpolation="none") axs[0, 2].set_title('superpixels') axs[0, 2].axis('off') axs[1, 0].imshow(ex_embeds[i]) axs[1, 0].set_title('pc proj 1-3', y=-0.15) axs[1, 0].axis('off') if ex_raws[i].ndim == 3: if ex_raws[i].shape[-1] > 1: axs[1, 1].imshow(ex_raws[i][..., 1]) else: axs[1, 1].imshow(ex_raws[i][..., 0]) else: axs[1, 1].imshow(ex_raws[i]) axs[1, 1].set_title('sp edge', y=-0.15) axs[1, 1].axis('off') axs[1, 2].imshow(ex_rl[i], cmap=random_label_cmap(), interpolation="none") axs[1, 2].set_title('prediction', y=-0.15) axs[1, 2].axis('off') wandb.log({ "validation/samples": [wandb.Image(fig, caption="sample images")] }) plt.close('all') def update_critic(self, obs, action, reward): self.optimizers.critic.zero_grad() with torch.cuda.amp.autocast(enabled=True): current_Q1, current_Q2 = self.forwarder.forward(self.model, obs, State, self.device, actions=action) target_Q = reward[0] target_Q = target_Q.detach() critic_loss = F.mse_loss(current_Q1.squeeze(1), target_Q) + F.mse_loss( current_Q2.squeeze(1), target_Q) self.scalers.critic.scale(critic_loss).backward() self.scalers.critic.step(self.optimizers.critic) self.scalers.critic.update() return critic_loss.item(), reward[0].mean() def update_actor_and_alpha(self, obs, reward, expl_action): self.optimizers.actor.zero_grad() self.optimizers.temperature.zero_grad() obj_edge_mask_actor = obs.obj_edge_mask_actor.to(self.device) with torch.cuda.amp.autocast(enabled=True): distribution, actor_Q1, actor_Q2, action, side_loss = self.forwarder.forward( self.model, obs, State, self.device, expl_action=expl_action, policy_opt=True) obj_n_edges = obj_edge_mask_actor.sum(1) log_prob = distribution.log_prob(action) actor_loss = torch.tensor([0.0], device=actor_Q1[0].device) alpha_loss = torch.tensor([0.0], device=actor_Q1[0].device) actor_Q = torch.min(actor_Q1, actor_Q2) obj_log_prob = (log_prob[None] * obj_edge_mask_actor[..., None]).sum(1) obj_entropy = ( (1 / 2 * (1 + (2 * np.pi * distribution.scale**2).log()))[None] * obj_edge_mask_actor[..., None]).sum(1).squeeze(1) loss = (self.model.alpha.detach() * obj_log_prob - actor_Q).mean() actor_loss = actor_loss + loss actor_loss = actor_loss + self.cfg.side_loss_weight * side_loss min_entropy = ( self.cfg.entropy_range[1] - self.cfg.entropy_range[0]) * ( (1.5 - reward[0]) / 1.5) + self.cfg.entropy_range[0] min_entropy = min_entropy.to(self.model.alpha.device).squeeze() entropy = obj_entropy.detach( ) if self.cfg.use_closed_form_entropy else -obj_log_prob.detach() alpha_loss = alpha_loss + (self.model.alpha * (entropy - (obj_n_edges * min_entropy))).mean() self.scalers.actor.scale(actor_loss).backward() self.scalers.actor.scale(alpha_loss).backward() self.scalers.actor.step(self.optimizers.actor) self.scalers.actor.step(self.optimizers.temperature) self.scalers.actor.update() return actor_loss.item(), alpha_loss.item(), min_entropy.mean().item( ), distribution.loc.mean().item() def _step(self, step): actor_loss, alpha_loss, min_entropy, loc_mean = None, None, None, None (obs, action, reward), sample_idx = self.memory.sample() action = action.to(self.device) for i in range(len(reward)): reward[i] = reward[i].to(self.device) critic_loss, mean_reward = self.update_critic(obs, action, reward) 
self.memory.report_sample_loss(critic_loss + mean_reward, sample_idx) self.mov_sum_losses.critic.apply(critic_loss) # self.optimizers.critic_shed.step(self.mov_sum_losses.critic.avg) wandb.log({"loss/critic": critic_loss}) if self.cfg.actor_update_after < step and step % self.cfg.actor_update_frequency == 0: actor_loss, alpha_loss, min_entropy, loc_mean = self.update_actor_and_alpha( obs, reward, action) self.mov_sum_losses.actor.apply(actor_loss) self.mov_sum_losses.temperature.apply(alpha_loss) # self.optimizers.actor_shed.step(self.mov_sum_losses.actor.avg) # self.optimizers.temp_shed.step(self.mov_sum_losses.actor.avg) wandb.log({"loss/actor": actor_loss}) wandb.log({"loss/alpha": alpha_loss}) if step % self.cfg.post_stats_frequency == 0: if min_entropy != "nl": wandb.log({"min_entropy": min_entropy}) wandb.log({"mov_avg/critic": self.mov_sum_losses.critic.avg}) wandb.log({"mov_avg/actor": self.mov_sum_losses.actor.avg}) wandb.log( {"mov_avg/temperature": self.mov_sum_losses.temperature.avg}) wandb.log({ "lr/critic": self.optimizers.critic_shed.optimizer.param_groups[0]['lr'] }) wandb.log({ "lr/actor": self.optimizers.actor_shed.optimizer.param_groups[0]['lr'] }) wandb.log({ "lr/temperature": self.optimizers.temp_shed.optimizer.param_groups[0]['lr'] }) if step % self.cfg.critic_target_update_frequency == 0: soft_update_params(self.model.critic, self.model.critic_tgt, self.cfg.critic_tau) return [critic_loss, actor_loss, alpha_loss, loc_mean] def train_until_finished(self): while self.global_count.value() <= self.cfg.T_max + self.cfg.mem_size: self.model_mtx.acquire() try: stats = [[], [], [], []] for i in range(self.cfg.n_updates_per_step): _stats = self._step(self.global_count.value()) [s.append(_s) for s, _s in zip(stats, _stats)] for j in range(len(stats)): if any([_s is None for _s in stats[j]]): stats[j] = "nl" else: stats[j] = round( sum(stats[j]) / self.cfg.n_updates_per_step, 5) if self.cfg.verbose: print( f"step: {self.global_count.value()}; mean_loc: {stats[-1]}; n_explorer_steps {self.memory.push_count}", end="") print(f"; cl: {stats[0]}; acl: {stats[1]}; al: {stats[3]}") finally: self.model_mtx.release() self.global_count.increment() self.memory.reset_push_count() if self.global_count.value() % self.cfg.validatoin_freq == 0: self.validate() # Acts and trains model def train_and_explore(self, rn): self.global_count.reset() set_seed_everywhere(rn) wandb.config.random_seed = rn if self.cfg.verbose: print('###### start training ######') print('Running on device: ', self.device) print('found ', self.train_dset.length, " training data patches") print('found ', self.val_dset.length, "validation data patches") print('training with seed: ' + str(rn)) explorers = [] for i in range(self.cfg.n_explorers): explorers.append(threading.Thread(target=self.explore)) [explorer.start() for explorer in explorers] self.memory.is_full_event.wait() trainer = threading.Thread(target=self.train_until_finished) trainer.start() trainer.join() self.global_count.set(self.cfg.T_max + self.cfg.mem_size + 4) [explorer.join() for explorer in explorers] self.memory.clear() del self.memory torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "last_checkpoint_agent.pth")) if self.cfg.verbose: print('\n\n###### training finished ######') return def explore(self): env = MulticutEmbeddingsEnv(self.fe_ext, self.cfg, self.device) tau = 1 while self.global_count.value() <= self.cfg.T_max + self.cfg.mem_size: dloader = iter( DataLoader(self.train_dset, batch_size=self.cfg.batch_size, shuffle=True, 
pin_memory=True, num_workers=0)) for iteration in range( (len(self.train_dset) // self.cfg.batch_size) * self.cfg.data_update_frequency): if iteration % self.cfg.data_update_frequency == 0: update_env_data(env, dloader, self.train_dset, self.device, with_gt_edges="sub_graph_dice" in self.cfg.reward_function) env.reset() state = env.get_state() if not self.memory.is_full(): action = torch.rand((env.edge_ids.shape[-1], 1), device=self.device) else: self.model_mtx.acquire() try: distr, action = self.forwarder.forward(self.model, state, State, self.device, grad=False) finally: self.model_mtx.release() reward, state = env.execute_action(action, tau=max(0, tau)) for i in range(len(reward)): reward[i] = reward[i].cpu() self.memory.push(state_to_cpu(state, State), action.cpu(), reward) if self.global_count.value( ) > self.cfg.T_max + self.cfg.mem_size: break return
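
# --- Illustration only (not part of the original module) --------------------------
# Both SAC trainers derive their entropy target from the current reward: a reward at
# the hard-coded ceiling of 1.5 maps to the lower end of cfg.entropy_range, a reward
# of 0 to the upper end. The helper below merely restates that formula from
# update_actor_and_alpha for readability; it is not a function defined in this repo.
def _entropy_target_sketch(reward, entropy_range):
    """Map a reward in [0, 1.5] to a target entropy inside `entropy_range`."""
    low, high = entropy_range
    return (high - low) * ((1.5 - reward) / 1.5) + low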
def validate_and_compare_to_clustering(model, env, distance, device, cfg): """validates the prediction against the method of clustering the embedding space""" model.eval() offs = [[1, 0], [0, 1], [2, 0], [0, 2], [4, 0], [0, 4], [16, 0], [0, 16]] ex_raws, ex_sps, ex_gts, ex_mc_gts, ex_embeds, ex_clst, ex_clst_sp, ex_mcaff, ex_mc_embed, ex_rl, \ ex_clst_graph_agglo= [], [], [], [], [], [], [], [], [], [], [] dset = SpgDset(cfg.val_data_dir, dict_to_attrdict(cfg.patch_manager), dict_to_attrdict(cfg.val_data_keys), max(cfg.s_subgraph)) dloader = iter(DataLoader(dset)) acc_reward = 0 forwarder = Forwarder() delta_dist = 0.4 # segm_metric = AveragePrecision() clst_metric_rl = ClusterMetrics() # clst_metric = ClusterMetrics() metric_sp_gt = ClusterMetrics() # clst_metric_mcaff = ClusterMetrics() # clst_metric_mcembed = ClusterMetrics() # clst_metric_graphagglo = ClusterMetrics() sbd = SBD() # map_rl, map_embed, map_sp_gt, map_mcaff, map_mcembed, map_graphagglo = [], [], [], [], [], [] sbd_rl, sbd_embed, sbd_sp_gt, sbd_mcaff, sbd_mcembed, sbd_graphagglo = [], [], [], [], [], [] n_examples = len(dset) for it in range(n_examples): update_env_data(env, dloader, dset, device, with_gt_edges=False) env.reset() state = env.get_state() distr, _, _, _, _, node_features, embeddings = forwarder.forward(model, state, State, device, grad=False, post_data=False, get_node_feats=True, get_embeddings=True) action = torch.sigmoid(distr.loc) reward = env.execute_action(action, tau=0.0, train=False) acc_reward += reward[-2].item() embeds = embeddings[0].cpu() # node_features = node_features.cpu().numpy() rag = env.rags[0] edge_ids = rag.uvIds() gt_seg = env.gt_seg[0].cpu().numpy() # l2_embeddings = get_angles(embeds[None])[0] # l2_node_feats = get_angles(torch.from_numpy(node_features.T[None, ..., None])).squeeze().T.numpy() # clst_labels_kmeans = cluster_embeddings(l2_embeddings.permute((1, 2, 0)), len(np.unique(gt_seg))) # node_labels = cluster_embeddings(l2_node_feats, len(np.unique(gt_seg))) # clst_labels_sp_kmeans = elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels).squeeze() # clst_labels_sp_graph_agglo = get_soln_graph_clustering(env.init_sp_seg, torch.from_numpy(edge_ids.astype(np.int)), torch.from_numpy(l2_node_feats), len(np.unique(gt_seg)))[0][0].numpy() # mc_labels_aff = env.get_current_soln(edge_weights=env.edge_features[:, 0]).cpu().numpy()[0] # ew_embedaffs = 1 - get_edge_features_1d(env.init_sp_seg[0].cpu().numpy(), offs, get_affinities_from_embeddings_2d(embeddings, offs, delta_dist, distance)[0].cpu().numpy())[0][:, 0] # mc_labels_embedding_aff = env.get_current_soln(edge_weights=torch.from_numpy(ew_embedaffs).to(device)).cpu().numpy()[0] rl_labels = env.current_soln.cpu().numpy()[0] ex_embeds.append(pca_project(embeds, n_comps=3)) ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze()) # ex_sps.append(cm.prism(env.init_sp_seg[0].cpu() / env.init_sp_seg[0].max().item())) ex_sps.append(env.init_sp_seg[0].cpu()) ex_mc_gts.append(project_overseg_to_seg(env.init_sp_seg[0], torch.from_numpy(gt_seg).to(device)).cpu().numpy()) ex_gts.append(gt_seg) ex_rl.append(rl_labels) # ex_clst.append(clst_labels_kmeans) # ex_clst_sp.append(clst_labels_sp_kmeans) # ex_clst_graph_agglo.append(clst_labels_sp_graph_agglo) # ex_mcaff.append(mc_labels_aff) # ex_mc_embed.append(mc_labels_embedding_aff) # map_rl.append(segm_metric(rl_labels, gt_seg)) sbd_rl.append(sbd(gt_seg, rl_labels)) clst_metric_rl(rl_labels, gt_seg) # map_sp_gt.append(segm_metric(ex_mc_gts[-1], gt_seg)) 
sbd_sp_gt.append(sbd(gt_seg, ex_mc_gts[-1])) metric_sp_gt(ex_mc_gts[-1], gt_seg) # map_embed.append(segm_metric(clst_labels_kmeans, gt_seg)) # clst_metric(clst_labels_kmeans, gt_seg) # map_mcaff.append(segm_metric(mc_labels_aff, gt_seg)) # sbd_mcaff.append(sbd(gt_seg, mc_labels_aff)) # clst_metric_mcaff(mc_labels_aff, gt_seg) # # map_mcembed.append(segm_metric(mc_labels_embedding_aff, gt_seg)) # sbd_mcembed.append(sbd(gt_seg, mc_labels_embedding_aff)) # clst_metric_mcembed(mc_labels_embedding_aff, gt_seg) # # map_graphagglo.append(segm_metric(clst_labels_sp_graph_agglo, gt_seg)) # sbd_graphagglo.append(sbd(gt_seg, clst_labels_sp_graph_agglo.astype(np.int))) # clst_metric_graphagglo(clst_labels_sp_graph_agglo.astype(np.int), gt_seg) print("\nSBD: ") print(f"sp gt : {round(np.array(sbd_sp_gt).mean(), 4)}; {round(np.array(sbd_sp_gt).std(), 4)}") print(f"ours : {round(np.array(sbd_rl).mean(), 4)}; {round(np.array(sbd_rl).std(), 4)}") # print(f"mc node : {np.array(sbd_mcembed).mean()}") # print(f"mc embed : {np.array(sbd_mcaff).mean()}") # print(f"graph agglo : {np.array(sbd_graphagglo).mean()}") # print("\nmAP: ") # print(f"sp gt : {np.array(map_sp_gt).mean()}") # print(f"ours : {np.array(map_rl).mean()}") # print(f"mc node : {np.array(map_mcembed).mean()}") # print(f"mc embed : {np.array(map_mcaff).mean()}") # print(f"graph agglo : {np.array(map_graphagglo).mean()}") # vi_rl_s, vi_rl_m, are_rl, arp_rl, arr_rl = clst_metric_rl.dump() vi_spgt_s, vi_spgt_m, are_spgt, arp_spgt, arr_spgt = metric_sp_gt.dump() # vi_mcaff_s, vi_mcaff_m, are_mcaff, arp_mcaff, arr_mcaff = clst_metric_mcaff.dump() # vi_mcembed_s, vi_mcembed_m, are_mcembed, arp_embed, arr_mcembed = clst_metric_mcembed.dump() # vi_graphagglo_s, vi_graphagglo_m, are_graphagglo, arp_graphagglo, arr_graphagglo = clst_metric_graphagglo.dump() # vi_rl_s_std, vi_rl_m_std, are_rl_std, arp_rl_std, arr_rl_std = clst_metric_rl.dump_std() vi_spgt_s_std, vi_spgt_m_std, are_spgt_std, arp_spgt_std, arr_spgt_std = metric_sp_gt.dump_std() print("\nVI merge: ") print(f"sp gt : {round(vi_spgt_m, 4)}; {round(vi_spgt_m_std, 4)}") print(f"ours : {round(vi_rl_m, 4)}; {round(vi_rl_m_std, 4)}") # print(f"mc affnties : {vi_mcaff_m}") # print(f"mc embed : {vi_mcembed_m}") # print(f"graph agglo : {vi_graphagglo_m}") # print("\nVI split: ") print(f"sp gt : {round(vi_spgt_s, 4)}; {round(vi_spgt_s_std, 4)}") print(f"ours : {round(vi_rl_s, 4)}; {round(vi_rl_s_std, 4)}") # print(f"mc affnties : {vi_mcaff_s}") # print(f"mc embed : {vi_mcembed_s}") # print(f"graph agglo : {vi_graphagglo_s}") # print("\nARE: ") print(f"sp gt : {round(are_spgt, 4)}; {round(are_spgt_std, 4)}") print(f"ours : {round(are_rl, 4)}; {round(are_rl_std, 4)}") # print(f"mc affnties : {are_mcaff}") # print(f"mc embed : {are_mcembed}") # print(f"graph agglo : {are_graphagglo}") # print("\nARP: ") print(f"sp gt : {round(arp_spgt, 4)}; {round(arp_spgt_std, 4)}") print(f"ours : {round(arp_rl, 4)}; {round(arp_rl_std, 4)}") # print(f"mc affnties : {arp_mcaff}") # print(f"mc embed : {arp_embed}") # print(f"graph agglo : {arp_graphagglo}") # print("\nARR: ") print(f"sp gt : {round(arr_spgt, 4)}; {round(arr_spgt_std, 4)}") print(f"ours : {round(arr_rl, 4)}; {round(arr_rl_std, 4)}") # print(f"mc affnties : {arr_mcaff}") # print(f"mc embed : {arr_mcembed}") # print(f"graph agglo : {arr_graphagglo}") exit() for i in range(len(ex_gts)): fig, axs = plt.subplots(2, 4, figsize=(20, 13), sharex='col', sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0}) axs[0, 0].imshow(ex_gts[i], cmap=random_label_cmap(), 
interpolation="none") axs[0, 0].set_title('gt') axs[0, 0].axis('off') axs[0, 1].imshow(ex_embeds[i]) axs[0, 1].set_title('pc proj') axs[0, 1].axis('off') # axs[0, 2].imshow(ex_clst[i], cmap=random_label_cmap(), interpolation="none") # axs[0, 2].set_title('pix clst') # axs[0, 2].axis('off') axs[0, 2].imshow(ex_clst_graph_agglo[i], cmap=random_label_cmap(), interpolation="none") axs[0, 2].set_title('nagglo') axs[0, 2].axis('off') axs[0, 3].imshow(ex_mc_embed[i], cmap=random_label_cmap(), interpolation="none") axs[0, 3].set_title('mc embed') axs[0, 3].axis('off') axs[1, 0].imshow(ex_mc_gts[i], cmap=random_label_cmap(), interpolation="none") axs[1, 0].set_title('sp gt') axs[1, 0].axis('off') axs[1, 1].imshow(ex_sps[i], cmap=random_label_cmap(), interpolation="none") axs[1, 1].set_title('sp') axs[1, 1].axis('off') # axs[1, 2].imshow(ex_clst_sp[i], cmap=random_label_cmap(), interpolation="none") # axs[1, 2].set_title('sp clst') # axs[1, 2].axis('off') axs[1, 2].imshow(ex_rl[i], cmap=random_label_cmap(), interpolation="none") axs[1, 2].set_title('ours') axs[1, 2].axis('off') axs[1, 3].imshow(ex_mcaff[i], cmap=random_label_cmap(), interpolation="none") axs[1, 3].set_title('mc aff') axs[1, 3].axis('off') plt.show() # wandb.log({"validation/samples": [wandb.Image(fig, caption="sample images")]}) plt.close('all')
class AgentSaTrainerObjLvlReward(object):

    def __init__(self, cfg, global_count):
        super(AgentSaTrainerObjLvlReward, self).__init__()
        # single-GPU setup
        assert torch.cuda.device_count() == 1
        self.device = torch.device("cuda:0")
        torch.cuda.set_device(self.device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.cfg = cfg
        self.global_count = global_count
        self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
        self.best_val_reward = -np.inf
        if self.cfg.distance == 'cosine':
            self.distance = CosineDistance()
        else:
            self.distance = L2Distance()

        # pretrained feature extractor providing the pixel embeddings (frozen below)
        self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance,
                                  cfg.fe_delta_dist, self.device)
        self.fe_ext.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
        self.fe_ext.cuda(self.device)

        self.model = Agent(self.cfg, State, self.distance, self.device)
        wandb.watch(self.model)
        self.model.cuda(self.device)
        self.model_mtx = Lock()

        # actor-only optimization with a moving-average-driven LR schedule
        self.optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr)

        lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
        bw = lr_sched_cfg.mov_avg_bandwidth
        off = lr_sched_cfg.mov_avg_offset
        weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw)
        weights = weights / weights.sum()  # make them sum up to one
        shed = lr_sched_cfg.torch_sched
        self.shed = ReduceLROnPlateau(self.optimizer, patience=shed.patience,
                                      threshold=shed.threshold, min_lr=shed.min_lr,
                                      factor=shed.factor)
        self.mov_sum_loss = RunningAverage(weights, band_width=bw, offset=off)
        self.scaler = torch.cuda.amp.GradScaler()
        self.forwarder = Forwarder()

        if self.cfg.agent_model_name != "":
            self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
        # finished with prepping
        for param in self.fe_ext.parameters():
            param.requires_grad = False

        self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager),
                                  dict_to_attrdict(self.cfg.data_keys))
        self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager),
                                dict_to_attrdict(self.cfg.data_keys))
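
# --- Illustration only (not part of the original module) --------------------------
# Hedged sketch of how these trainers appear to be driven, inferred solely from the
# attributes they use: `cfg` is an attribute-style config with the fields referenced
# above, and `global_count` is a shared counter exposing value()/increment()/set()/
# reset(). The loader and counter names below are hypothetical; the repo's real
# launch script is not part of this file.
#
#     cfg = dict_to_attrdict(load_config("config.yaml"))   # hypothetical config loader
#     global_count = SharedCounter()                       # hypothetical shared counter
#     trainer = AgentSacTrainer(cfg, global_count)
#     trainer.train_and_explore(rn=42)                     # explorer threads + training loop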