def __init__(self, cfg, global_count):
    """Set up the A2C agent trainer: device, model, optimizers, LR schedulers and datasets.

    Args:
        cfg: configuration object (attrdict-like); supplies lrs, paths, memory size, etc.
        global_count: shared counter tracking global training progress across processes.
    """
    super(AgentA2CTrainer, self).__init__()
    # single-GPU trainer by design; everything below assumes cuda:0
    assert torch.cuda.device_count() == 1
    self.device = torch.device("cuda:0")
    torch.cuda.set_device(self.device)
    torch.set_default_tensor_type(torch.FloatTensor)
    self.cfg = cfg
    self.global_count = global_count
    # transition/replay storage
    self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
    self.best_val_reward = -np.inf
    if self.cfg.distance == 'cosine':
        self.distance = CosineDistance()
    else:
        self.distance = L2Distance()
    # A2C variant: agent carries no entropy-temperature parameter (with_temp=False)
    self.model = Agent(self.cfg, State, self.distance, self.device, with_temp=False)
    wandb.watch(self.model)
    self.model.cuda(self.device)
    self.model_mtx = Lock()  # guards model access across threads

    MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic'))
    Scalers = namedtuple('Scalers', ('critic', 'actor'))
    OptimizerContainer = namedtuple('OptimizerContainer',
                                    ('actor', 'critic', 'actor_shed', 'critic_shed'))
    actor_optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr)
    critic_optimizer = torch.optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr)
    lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
    bw = lr_sched_cfg.mov_avg_bandwidth
    off = lr_sched_cfg.mov_avg_offset
    # weighted moving average over the last `bw` losses feeds the plateau schedulers
    weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw)
    weights = weights / weights.sum()  # make them sum up to one
    shed = lr_sched_cfg.torch_sched
    self.mov_sum_losses = MovSumLosses(RunningAverage(weights, band_width=bw, offset=off),
                                       RunningAverage(weights, band_width=bw, offset=off))
    self.optimizers = OptimizerContainer(actor_optimizer, critic_optimizer,
                                         *[ReduceLROnPlateau(opt, patience=shed.patience,
                                                             threshold=shed.threshold,
                                                             min_lr=shed.min_lr,
                                                             factor=shed.factor)
                                           for opt in (actor_optimizer, critic_optimizer)])
    # one AMP grad scaler per loss branch
    self.scalers = Scalers(torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler())
    self.forwarder = Forwarder()
    if self.cfg.agent_model_name != "":
        self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
    # finished with prepping
    self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager),
                              dict_to_attrdict(self.cfg.data_keys), max(self.cfg.s_subgraph))
    self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager),
                            dict_to_attrdict(self.cfg.data_keys), max(self.cfg.s_subgraph))
    self.segm_metric = AveragePrecision()
    self.clst_metric = ClusterMetrics()
    self.global_counter = 0
def pretrain_embeddings_gt(self, model, device, writer=None):
    """Warm-start the feature extractor by contrastive training against ground truth.

    Runs `self.cfg.fe.warmup.n_iterations` optimizer steps; the input to the model is
    the raw image concatenated with a 2-channel superpixel boundary map derived from
    `sp_seg` via min/max pooling.

    Args:
        model: feature-extractor network to be trained in place.
        device: CUDA device tensors are moved to.
        writer: optional tensorboard writer for loss/lr scalars and example figures.
    """
    dset = SpgDset(root_dir=self.cfg.gen.data_dir)
    dloader = DataLoader(dset, batch_size=self.cfg.fe.warmup.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
    sheduler = ReduceLROnPlateau(optimizer)
    acc_loss = 0
    iteration = 0
    while iteration <= self.cfg.fe.warmup.n_iterations:
        for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
            raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
            # boundary map: pixels whose eroded/dilated label differs from their own
            sp_seg_edge = torch.cat([(-max_p(-sp_seg) != sp_seg).float(),
                                     (max_p(sp_seg) != sp_seg).float()], 1)
            embeddings = model(torch.cat([raw, sp_seg_edge], 1))
            # contrastive loss against the ground-truth labeling
            loss = self.contr_loss(embeddings, gt.long().squeeze(1))
            optimizer.zero_grad()
            loss.backward(retain_graph=False)
            optimizer.step()
            acc_loss += loss.item()
            if writer is not None:
                writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
                writer.add_scalar("fe_warm_start/lr", optimizer.param_groups[0]['lr'],
                                  iteration)
                if it % 50 == 0:
                    plt.clf()
                    fig = plt.figure(frameon=False)
                    plt.imshow(sp_seg[0].detach().squeeze().cpu().numpy())
                    plt.colorbar()
                    writer.add_figure("image/sp_seg", fig, iteration // 50)
            if it % 10 == 0:
                # plateau scheduler driven by the 10-step loss average
                sheduler.step(acc_loss / 10)
                acc_loss = 0
            iteration += 1
            if iteration > self.cfg.fe.warmup.n_iterations:
                break
    del loss
    del embeddings
    return
def pretrain_embeddings_sp(self, model, device, writer=None):
    """Warm-start the feature extractor by contrastive training against superpixels.

    Like `pretrain_embeddings_gt`, but the contrastive targets are the superpixel
    labels (`sp_seg`) rather than the ground truth, and the outer loop counts epochs
    over the dataloader instead of single steps.

    Args:
        model: feature-extractor network to be trained in place.
        device: CUDA device tensors are moved to.
        writer: optional tensorboard writer for loss/lr scalars and example figures.
    """
    dset = SpgDset(self.args.data_dir, self.cfg.fe.patch_manager,
                   self.cfg.fe.patch_stride, self.cfg.fe.patch_shape)
    dloader = DataLoader(dset, batch_size=self.cfg.fe.warmup.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    max_p = torch.nn.MaxPool2d(3, padding=1, stride=1)
    sheduler = ReduceLROnPlateau(optimizer)
    acc_loss = 0
    for i in range(self.cfg.fe.warmup.n_iterations):
        print(f"fe ext wu iter: {i}")
        for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
            raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
            # boundary map: pixels whose eroded/dilated label differs from their own
            sp_seg_edge = torch.cat([(-max_p(-sp_seg) != sp_seg).float(),
                                     (max_p(sp_seg) != sp_seg).float()], 1)
            # second arg toggles a periodic flag in the model every 500 steps
            # (presumably debug/plotting inside the model — confirm in model code)
            embeddings = model(torch.cat([raw, sp_seg_edge], 1),
                               True if it % 500 == 0 else False)
            loss = self.contr_loss(embeddings, sp_seg.long().squeeze(1))
            optimizer.zero_grad()
            loss.backward(retain_graph=False)
            optimizer.step()
            acc_loss += loss.item()
            if writer is not None:
                writer.add_scalar("fe_warm_start/loss", loss.item(),
                                  (len(dloader) * i) + it)
                writer.add_scalar("fe_warm_start/lr", optimizer.param_groups[0]['lr'],
                                  (len(dloader) * i) + it)
                if it % 500 == 0:
                    plt.clf()
                    fig = plt.figure(frameon=False)
                    plt.imshow(sp_seg[0].detach().squeeze().cpu().numpy())
                    plt.colorbar()
                    writer.add_figure("image/sp_seg", fig,
                                      ((len(dloader) * i) + it) // 500)
            if it % 10 == 0:
                # plateau scheduler driven by the 10-step loss average
                sheduler.step(acc_loss / 10)
                acc_loss = 0
def validate(self):
    """Evaluate a saved agent's predicted edge weights against ground-truth edges.

    Loads the model named by `self.args.model_name`, runs one frozen forward pass per
    dataset sample, and compares sigmoid(distr.loc) to the environment's ground-truth
    edge weights.

    Returns:
        Tuple of (abs_diffs, rel_diffs, mean_size, mean_n_larger_thresh):
        per-sample summed/mean absolute differences (lists), mean edge count, and
        mean count of edges whose error exceeds 0.5.
    """
    self.device = torch.device("cuda:0")
    model = GcnEdgeAC(self.cfg, self.args, self.device)
    thresh = 0.5
    assert self.args.model_name != ""
    model.load_state_dict(
        torch.load(os.path.join(self.save_dir, self.args.model_name)))
    model.cuda(self.device)
    # evaluation only: freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    dloader = DataLoader(SpgDset(root_dir=self.args.data_dir), batch_size=1,
                         shuffle=True, pin_memory=True, num_workers=0)
    env = SpGcnEnv(self.args, self.device)
    abs_diffs, rel_diffs, sizes, n_larger_thresh = [], [], [], []
    for i in range(len(dloader)):
        # load the next sample into the environment and reset it
        self.update_env_data(env, dloader, self.device)
        env.reset()
        state = env.get_state()
        distr, _, _, _ = self.agent_forward(env, model, state=state, grad=False)
        # map the action distribution's location to [0, 1] edge weights
        actions = torch.sigmoid(distr.loc)
        diff = (actions - env.b_gt_edge_weights).squeeze().abs()
        abs_diffs.append(diff.sum().item())
        rel_diffs.append(diff.mean().item())
        sizes.append(len(diff))
        n_larger_thresh.append((diff > thresh).float().sum().item())
    mean_size = sum(sizes) / len(sizes)
    mean_n_larger_thresh = sum(n_larger_thresh) / len(n_larger_thresh)
    return abs_diffs, rel_diffs, mean_size, mean_n_larger_thresh
def train(self):
    """Self-supervised feature-extractor training with a momentum network (BYOL-style).

    Two UNets: `model` is trained, `momentum_model` is a frozen copy updated by soft
    (EMA) updates every `momentum` iterations. Each step builds two augmented views
    of the input (independent intensity transforms + shared spatial transform) and
    minimizes an EntrInfoNCE loss between their unit-normalized embeddings. The
    neighborhood size `k` and the weighting `psi` are annealed over training.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    writer.add_text("conf", self.cfg.pretty())
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(self.cfg.fe.n_raw_channels, self.cfg.fe.n_embedding_features,
                   final_sigmoid=False, num_levels=5)
    momentum_model = UNet2D(self.cfg.fe.n_raw_channels, self.cfg.fe.n_embedding_features,
                            final_sigmoid=False, num_levels=5)
    if wu_cfg.identical_initialization:
        # tau=1 copies the online weights into the momentum network
        soft_update_params(model, momentum_model, 1)
    momentum_model.cuda(device)
    # momentum network is never trained directly, only via EMA updates
    for param in momentum_model.parameters():
        param.requires_grad = False
    model.cuda(device)
    dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager, wu_cfg.patch_stride,
                   wu_cfg.patch_shape, wu_cfg.reorder_sp)
    dloader = DataLoader(dset, batch_size=wu_cfg.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    sheduler = ReduceLROnPlateau(optimizer, patience=100, threshold=1e-3,
                                 min_lr=1e-6, factor=0.1)
    criterion = EntrInfoNCE(alpha=self.cfg.fe.alpha, beta=self.cfg.fe.beta,
                            lbd=self.cfg.fe.lbd, tau=self.cfg.fe.tau,
                            gamma=self.cfg.fe.gamma, num_neg=self.cfg.fe.num_neg,
                            subs_size=self.cfg.fe.subs_size)
    tfs = RndAugmentationTfs(wu_cfg.patch_shape)
    acc_loss = 0
    iteration = 0
    # anneal k from k_start down to k_stop in evenly spaced steps,
    # finishing n_k_stop_it iterations before the end
    k_step = math.ceil((wu_cfg.n_iterations - wu_cfg.n_k_stop_it) /
                       (wu_cfg.k_start - wu_cfg.k_stop))
    k = wu_cfg.k_start
    # psi is annealed linearly over the same horizon
    psi_step = (wu_cfg.psi_start - wu_cfg.psi_stop) / (
        wu_cfg.n_iterations - wu_cfg.n_k_stop_it)
    psi = wu_cfg.psi_start
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
            inp, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
            # validity mask; spatially transformed along with the data so we can
            # identify void regions introduced by the transform
            mask = torch.ones((inp.shape[0], 1, ) + inp.shape[2:], device=device).float()
            # get transforms
            spat_tf, int_tf = tfs.sample(1, 1)
            _, _int_tf = tfs.sample(1, 1)
            # add noise to intensity tf of input for momentum network
            mom_inp = add_sp_gauss_noise(_int_tf(inp), 0.2, 0.1, 0.3)
            # get momentum prediction
            embeddings_mom = momentum_model(mom_inp.unsqueeze(2)).squeeze(2)
            # do the same spatial tf for input, mask and momentum prediction
            paired = spat_tf(torch.cat((mask, inp, embeddings_mom), -3))
            embeddings_mom, mask = paired[..., inp.shape[1] + 1:, :, :], \
                paired[..., 0, :, :][:, None]
            # do intensity transform for spatial transformed input
            aug_inp = int_tf(paired[..., 1:inp.shape[1] + 1, :, :])
            # and add some noise
            aug_inp = add_sp_gauss_noise(aug_inp, 0.2, 0.1, 0.3)
            # get prediction of the augmented input
            embeddings = model(aug_inp.unsqueeze(2)).squeeze(2)
            # put embeddings on unit sphere so we can use cosine distance
            embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
            embeddings_mom = embeddings_mom + (
                mask == 0)  # set the void of the image to the 1-vector
            embeddings_mom = embeddings_mom / torch.norm(embeddings_mom, dim=1,
                                                         keepdim=True)
            loss = criterion(embeddings.squeeze(0), embeddings_mom.squeeze(0), k,
                             mask.squeeze(0), whiten=wu_cfg.whitened_embeddings,
                             warmup=iteration < wu_cfg.n_warmup_it, psi=psi)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc_loss += loss.item()
            print(loss.item())
            writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
            writer.add_scalar("fe_warm_start/lr", optimizer.param_groups[0]['lr'],
                              iteration)
            if (iteration) % 50 == 0:
                # NOTE(review): loss accumulated over 50 steps but divided by 10 —
                # confirm the intended averaging window
                sheduler.step(acc_loss / 10)
                acc_loss = 0
                fig, (a1, a2, a3, a4) = plt.subplots(1, 4, sharex='col', sharey='row',
                                                     gridspec_kw={'hspace': 0,
                                                                  'wspace': 0})
                a1.imshow(inp[0].cpu().permute(1, 2, 0).squeeze())
                a1.set_title('raw')
                a2.imshow(aug_inp[0].cpu().permute(1, 2, 0))
                a2.set_title('augment')
                a3.imshow(pca_project(get_angles(embeddings).squeeze(0).detach().cpu()))
                a3.set_title('embed')
                a4.imshow(pca_project(get_angles(embeddings_mom).squeeze(0).detach().cpu()))
                a4.set_title('mom_embed')
                writer.add_figure("examples", fig, iteration // 100)
            iteration += 1
            psi = max(psi - psi_step, wu_cfg.psi_stop)
            if iteration % k_step == 0:
                k = max(k - 1, wu_cfg.k_stop)
            if iteration > wu_cfg.n_iterations:
                break
            # EMA update of the momentum network every `momentum` iterations
            if iteration % wu_cfg.momentum == 0:
                soft_update_params(model, momentum_model, wu_cfg.momentum_tau)
    return
def train_step(self, rank, start_time, return_dict, writer):
    """Run one distributed SAC training worker on GPU `rank`.

    Sets up the environment, the DDP-wrapped agent and the SAC optimizers, optionally
    warm-starts the feature extractor (rank 0 only) and shares it via disk, then runs
    the acting/learning loop until the shared global counter exceeds `T_max`.
    `runtime_cfg.yaml` is re-read every episode so learning rates and flags can be
    adjusted while training is running.

    Args:
        rank: process rank; also selects the CUDA device.
        start_time: unused here; kept for interface compatibility with the spawner.
        return_dict: unused here; kept for interface compatibility with the spawner.
        writer: tensorboard summary writer used for shared logging.

    Returns:
        Mean of the most recent (up to 10) episode qualities; 0 if none were recorded
        (e.g. when running with ``test_score_only``).
    """
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)
    self.setup(rank, self.args.num_processes)  # init torch.distributed process group
    if self.cfg.MC_DQL:
        transition = namedtuple('Transition', ('episode',))
    else:
        transition = namedtuple(
            'Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    memory = TransitionData_ts(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count)
    # Create shared network
    model = GcnEdgeAC(self.cfg, self.args, device, writer=writer)
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device], find_unused_parameters=True)
    dloader = DataLoader(SpgDset(), batch_size=self.cfg.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)

    ######################
    # SAC bookkeeping on the trainer object
    self.action_range = 1
    self.device = torch.device(device)
    self.discount = 0.5
    self.critic_tau = self.cfg.critic_tau
    self.actor_update_frequency = self.cfg.actor_update_frequency
    self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
    self.batch_size = self.cfg.batch_size
    # learnable entropy temperature (log-parameterized)
    self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(self.device)
    self.log_alpha.requires_grad = True
    ######################

    # optimizers
    OptimizerContainer = namedtuple('OptimizerContainer',
                                    ('actor', 'critic', 'temperature'))
    actor_optimizer = torch.optim.Adam(shared_model.module.actor.parameters(),
                                       lr=self.cfg.actor_lr,
                                       betas=self.cfg.actor_betas)
    critic_optimizer = torch.optim.Adam(shared_model.module.critic.parameters(),
                                        lr=self.cfg.critic_lr,
                                        betas=self.cfg.critic_betas)
    temp_optimizer = torch.optim.Adam([self.log_alpha], lr=self.cfg.alpha_lr,
                                      betas=self.cfg.alpha_betas)
    optimizers = OptimizerContainer(actor_optimizer, critic_optimizer, temp_optimizer)

    if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
        # warm-start the feature extractor on rank 0 and persist it for other ranks
        fe_extr = shared_model.module.fe_ext
        fe_extr.cuda(device)
        self.fe_extr_warm_start_1(fe_extr, writer=writer)
        if self.args.model_name == "" and not self.args.no_save:
            torch.save(fe_extr.state_dict(),
                       os.path.join(self.save_dir, 'agent_model_fe_extr'))
        elif not self.args.no_save:
            torch.save(fe_extr.state_dict(),
                       os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()  # all ranks wait until warm-started weights are on disk
    # feature extractor stays frozen during RL training
    for param in model.fe_ext.parameters():
        param.requires_grad = False
    if self.args.model_name != "":
        shared_model.load_state_dict(
            torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.model_fe_name != "":
        shared_model.module.fe_ext.load_state_dict(
            torch.load(os.path.join(self.save_dir, self.args.model_fe_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        shared_model.module.fe_ext.load_state_dict(
            torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

    # always defined so the final return cannot raise NameError when
    # test_score_only skips the training loop entirely
    last_quals = []
    if not self.args.test_score_only:
        quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
        # NOTE(review): assigned below but never compared against — verify intent
        best_quality = np.inf
        while self.global_count.value() <= self.args.T_max:
            self.update_env_data(env, dloader, device)
            waff_dis = torch.softmax(torch.ones_like(env.b_gt_edge_weights), dim=0)
            loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
            env.reset()
            self.target_entropy = -8.0  # fixed target entropy (SAC default is -|A|)
            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value(),
                                                         quality)
            if self.cfg.temperature_regulation == 'follow_quality':
                self.alpha = self.eps_rule.apply(self.global_count.value(), quality)
                print(self.alpha.item())
            # re-read runtime config so hyperparameters can be tweaked live
            with open(os.path.join(self.save_dir, 'runtime_cfg.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'safe_model' in args_dict:
                        self.args.safe_model = args_dict['safe_model']
                        args_dict['safe_model'] = False  # one-shot save trigger
                    if 'add_noise' in args_dict:
                        self.args.add_noise = args_dict['add_noise']
                    if 'critic_lr' in args_dict and \
                            args_dict['critic_lr'] != self.cfg.critic_lr:
                        self.cfg.critic_lr = args_dict['critic_lr']
                        adjust_learning_rate(critic_optimizer, self.cfg.critic_lr)
                    if 'actor_lr' in args_dict and \
                            args_dict['actor_lr'] != self.cfg.actor_lr:
                        self.cfg.actor_lr = args_dict['actor_lr']
                        adjust_learning_rate(actor_optimizer, self.cfg.actor_lr)
                    if 'alpha_lr' in args_dict and \
                            args_dict['alpha_lr'] != self.cfg.alpha_lr:
                        self.cfg.alpha_lr = args_dict['alpha_lr']
                        adjust_learning_rate(temp_optimizer, self.cfg.alpha_lr)
            with open(os.path.join(self.save_dir, 'runtime_cfg.yaml'), "w") as info:
                yaml.dump(args_dict, info)

            if self.args.safe_model:
                # checkpoint on request (triggered via runtime_cfg.yaml)
                best_quality = quality
                if rank == 0:
                    if self.args.model_name_dest != "":
                        torch.save(shared_model.state_dict(),
                                   os.path.join(self.save_dir,
                                                self.args.model_name_dest))
                    else:
                        torch.save(shared_model.state_dict(),
                                   os.path.join(self.save_dir, 'agent_model'))

            state = env.get_state()
            while not env.done:
                # Calculate policy and values
                post_input = True if (self.global_count.value() + 1) % 15 == 0 \
                    and env.counter == 0 else False
                round_n = env.counter
                # sample action for data collection: random until enough seed steps
                distr = None
                if self.global_count.value() < self.cfg.num_seed_steps:
                    action = torch.rand_like(env.b_current_edge_weights)
                else:
                    distr, _, _, action = self.agent_forward(env, shared_model,
                                                             state=state, grad=False,
                                                             post_input=post_input)
                logg_dict = {'temperature': self.alpha.item()}
                if distr is not None:
                    logg_dict['mean_loc'] = distr.loc.mean().item()
                    logg_dict['mean_scale'] = distr.scale.mean().item()
                if self.global_count.value() >= self.cfg.num_seed_steps \
                        and memory.is_full():
                    self._step(memory, optimizers, env, shared_model,
                               self.global_count.value(), writer=writer)
                    self.global_writer_loss_count.increment()
                next_state, reward, quality = env.execute_action(action, logg_dict)
                last_quals.append(quality)
                if len(last_quals) > 10:
                    last_quals.pop(0)  # sliding window of the last 10 qualities
                if self.args.add_noise:
                    noise = torch.randn_like(reward) * self.alpha.item()
                    reward = reward + noise
                memory.push(self.state_to_cpu(state), action, reward,
                            self.state_to_cpu(next_state), env.done)
                state = next_state
            self.global_count.increment()

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp and not self.args.test_score_only \
                and not self.args.no_save:
            if self.args.model_name_dest != "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name_dest))
                print('saved')
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
    self.cleanup()
    # average over however many qualities were actually collected (at most 10);
    # the previous `/ 10` was wrong for short runs and raised NameError when
    # test_score_only was set
    return sum(last_quals) / max(len(last_quals), 1)
def train(self):
    """Warm-up training of the embedding UNet with a RAG-based InfoNCE loss.

    Each step embeds a batch, L2-normalizes the embeddings (cosine distance), and
    contrasts them along the region-adjacency-graph edges of the superpixel
    segmentation. Superpixel ids and edge ids are offset per batch element so all
    samples share one global id space before the edges are concatenated.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    writer.add_text("conf", self.cfg.pretty())
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(self.cfg.fe.n_raw_channels, self.cfg.fe.n_embedding_features,
                   final_sigmoid=False, num_levels=5)
    model.cuda(device)
    dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager, wu_cfg.patch_stride,
                   wu_cfg.patch_shape, wu_cfg.reorder_sp)
    dloader = DataLoader(dset, batch_size=wu_cfg.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    sheduler = ReduceLROnPlateau(optimizer, patience=100, threshold=1e-3,
                                 min_lr=1e-6, factor=0.1)
    criterion = RagInfoNCE(tau=self.cfg.fe.tau)
    acc_loss = 0
    iteration = 0
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
            inp, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
            edges = dloader.dataset.get_graphs(indices, sp_seg, device)[0]
            # shift superpixel ids (and the edges referring to them) per batch
            # element into one disjoint global id space
            off = 0
            for i in range(len(edges)):
                sp_seg[i] += off
                edges[i] += off
                off = sp_seg[i].max() + 1
            edges = torch.cat(edges, 1)
            embeddings = model(inp.unsqueeze(2)).squeeze(2)
            # put embeddings on unit sphere so we can use cosine distance
            embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
            loss = criterion(embeddings, sp_seg, edges)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc_loss += loss.item()
            print(loss.item())
            writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
            writer.add_scalar("fe_warm_start/lr", optimizer.param_groups[0]['lr'],
                              iteration)
            if (iteration) % 50 == 0:
                # NOTE(review): loss accumulated over 50 steps but divided by 10 —
                # confirm the intended averaging window
                sheduler.step(acc_loss / 10)
                acc_loss = 0
                fig, (a1, a2) = plt.subplots(1, 2, sharex='col', sharey='row',
                                             gridspec_kw={'hspace': 0, 'wspace': 0})
                a1.imshow(inp[0].cpu().permute(1, 2, 0).squeeze())
                a1.set_title('raw')
                a2.imshow(pca_project(get_angles(embeddings).squeeze(0).detach().cpu()))
                a2.set_title('embed')
                writer.add_figure("examples", fig, iteration // 100)
            iteration += 1
            if iteration > wu_cfg.n_iterations:
                break
    return
def __init__(self, cfg, global_count):
    """Set up the SAC trainer (object-level reward variant): device, pretrained and
    frozen feature extractor, agent, optimizers with plateau schedulers, datasets.

    Args:
        cfg: configuration object (attrdict-like).
        global_count: shared counter tracking global training progress.
    """
    super(AgentSacTrainerObjLvlReward, self).__init__()
    # single-GPU trainer by design; everything below assumes cuda:0
    assert torch.cuda.device_count() == 1
    self.device = torch.device("cuda:0")
    torch.cuda.set_device(self.device)
    torch.set_default_tensor_type(torch.FloatTensor)
    self.cfg = cfg
    self.global_count = global_count
    self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
    self.best_val_reward = -np.inf
    if self.cfg.distance == 'cosine':
        self.distance = CosineDistance()
    else:
        self.distance = L2Distance()
    # pretrained feature extractor; its parameters are frozen further below
    self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance,
                              cfg.fe_delta_dist, self.device)
    self.fe_ext.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
    self.fe_ext.cuda(self.device)
    self.model = Agent(self.cfg, State, self.distance, self.device)
    wandb.watch(self.model)
    self.model.cuda(self.device)
    self.model_mtx = Lock()  # guards model access across threads

    MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic', 'temperature'))
    Scalers = namedtuple('Scalers', ('critic', 'actor'))
    OptimizerContainer = namedtuple(
        'OptimizerContainer',
        ('actor', 'critic', 'temperature', 'actor_shed', 'critic_shed', 'temp_shed'))
    actor_optimizer = torch.optim.Adam(self.model.actor.parameters(),
                                       lr=self.cfg.actor_lr)
    critic_optimizer = torch.optim.Adam(self.model.critic.parameters(),
                                        lr=self.cfg.critic_lr)
    # SAC entropy-temperature optimizer
    temp_optimizer = torch.optim.Adam([self.model.log_alpha], lr=self.cfg.alpha_lr)
    lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
    bw = lr_sched_cfg.mov_avg_bandwidth
    off = lr_sched_cfg.mov_avg_offset
    # weighted moving average over the last `bw` losses feeds the plateau schedulers
    weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw)
    weights = weights / weights.sum()  # make them sum up to one
    shed = lr_sched_cfg.torch_sched
    self.mov_sum_losses = MovSumLosses(
        RunningAverage(weights, band_width=bw, offset=off),
        RunningAverage(weights, band_width=bw, offset=off),
        RunningAverage(weights, band_width=bw, offset=off))
    self.optimizers = OptimizerContainer(
        actor_optimizer, critic_optimizer, temp_optimizer,
        *[ReduceLROnPlateau(opt, patience=shed.patience,
                            threshold=shed.threshold, min_lr=shed.min_lr,
                            factor=shed.factor)
          for opt in (actor_optimizer, critic_optimizer, temp_optimizer)])
    self.scalers = Scalers(torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler())
    self.forwarder = Forwarder()
    if self.cfg.agent_model_name != "":
        self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
    # if "policy_warmup" in self.cfg and self.cfg.agent_model_name == "":
    #     supervised_policy_pretraining(self.model, self.env, self.cfg, device=self.device)
    #     torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "sv_pretrained_policy_agent.pth"))
    # finished with prepping
    # RL only trains the agent; keep the feature extractor frozen
    for param in self.fe_ext.parameters():
        param.requires_grad = False
    self.train_dset = SpgDset(self.cfg.data_dir,
                              dict_to_attrdict(self.cfg.patch_manager),
                              dict_to_attrdict(self.cfg.data_keys))
    self.val_dset = SpgDset(self.cfg.val_data_dir,
                            dict_to_attrdict(self.cfg.patch_manager),
                            dict_to_attrdict(self.cfg.data_keys))
def train(self):
    """Train the embedding UNet with a RAG contrastive loss on ground-truth regions.

    NOTE(review): dataset paths are hard-coded, and the training set points at the
    "true_val" directory while the validation set points at "train" — confirm this
    swap is intentional. The validation/plotting section is commented out, so
    `sheduler`, `acc_loss`, `valit` and `best_loss` are currently unused.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(**self.cfg.fe.backbone)
    model.cuda(device)
    train_set = SpgDset(
        "/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/true_val",
        reorder_sp=True)
    val_set = SpgDset(
        "/g/kreshuk/hilt/projects/data/leptin_fused_tp1_ch_0/train",
        reorder_sp=True)
    # pm = StridedPatches2D(wu_cfg.patch_stride, wu_cfg.patch_shape, train_set.image_shape)
    train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True,
                              pin_memory=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True,
                            pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    sheduler = ReduceLROnPlateau(optimizer, patience=40, threshold=1e-4,
                                 min_lr=1e-5, factor=0.1)
    criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.3)
    acc_loss = 0
    valit = 0
    iteration = 0
    best_loss = np.inf
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(train_loader):
            raw, gt = raw.to(device), gt.to(device)
            # embed and L2-normalize (cosine-style distance); eps avoids div-by-zero
            loss_embeds = model(raw[:, :, None]).squeeze(2)
            loss_embeds = loss_embeds / (
                torch.norm(loss_embeds, dim=1, keepdim=True) + 1e-9)
            # region-adjacency-graph edges of the ground-truth segmentation
            edges = [
                feats.compute_rag(seg.cpu().numpy()).uvIds() for seg in gt
            ]
            edges = [
                torch.from_numpy(e.astype(np.long)).to(device).T for e in edges
            ]
            loss = criterion(loss_embeds, gt.long(), edges, None, 30)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss.item())
            # writer.add_scalar("fe_train/lr", optimizer.param_groups[0]['lr'], iteration)
            # writer.add_scalar("fe_train/loss", loss.item(), iteration)
            # if (iteration) % 100 == 0:
            #
            #     fig, (a1, a2, a3) = plt.subplots(3, 1, sharex='col', sharey='row',
            #                                      gridspec_kw={'hspace': 0, 'wspace': 0})
            #     a1.imshow(raw[0, 0].cpu().squeeze())
            #     a1.set_title('train raw')
            #     a2.imshow(pca_project(loss_embeds[0].detach().cpu()))
            #     a2.set_title('train embed')
            #     a3.imshow(gt[0, 0].cpu().squeeze())
            #     a3.set_title('train gt')
            #     plt.show()
            #
            #     with torch.set_grad_enabled(False):
            #         for it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(val_loader):
            #             raw = raw.to(device)
            #             embeds = model(raw[:, :, None]).squeeze(2)
            #             embeds = embeds / (torch.norm(embeds, dim=1, keepdim=True) + 1e-9)
            #
            #             print(loss.item())
            #             writer.add_scalar("fe_train/lr", optimizer.param_groups[0]['lr'], iteration)
            #             writer.add_scalar("fe_train/loss", loss.item(), iteration)
            #             fig, (a1, a2) = plt.subplots(2, 1, sharex='col', sharey='row',
            #                                          gridspec_kw={'hspace': 0, 'wspace': 0})
            #             a1.imshow(raw[0, 0].cpu().squeeze())
            #             a1.set_title('raw')
            #             a2.imshow(pca_project(embeds[0].detach().cpu()))
            #             a2.set_title('embed')
            #             plt.show()
            #             if it > 2:
            #                 break
            iteration += 1
            print(iteration)
            if iteration > wu_cfg.n_iterations:
                print(self.save_dir)
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, "last_model.pth"))
                break
    return
def train(self):
    """Train the embedding UNet with a rotation-equivariant RAG contrastive loss.

    Each step: build a superpixel contour image, rotate (raw, gt, sp_seg, contours)
    by a random angle, subsample valid superpixels present in the rotated view,
    relabel both views into a common compact id space, compute per-edge affinity
    weights, and train on the concatenation of the original and the rotated view
    with `RegRagContrastiveWeights` (two embedding sub-spaces, split by `slcs`).
    Every 100 iterations the same pipeline runs over the validation set; the best
    validation loss checkpoints the model and drives the LR scheduler.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(**self.cfg.fe.backbone)
    model.cuda(device)
    # train_set = SpgDset(self.cfg.gen.data_dir_raw_train, patch_manager="no_cross", patch_stride=(10,10), patch_shape=(300,300), reorder_sp=True)
    # val_set = SpgDset(self.cfg.gen.data_dir_raw_val, patch_manager="no_cross", patch_stride=(10,10), patch_shape=(300,300), reorder_sp=True)
    train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=True)
    val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=True)
    # pm = StridedPatches2D(wu_cfg.patch_stride, wu_cfg.patch_shape, train_set.image_shape)
    pm = NoPatches2D()
    # dataset length is (number of graphs) x (patches per image)
    train_set.length = len(train_set.graph_file_names) * np.prod(pm.n_patch_per_dim)
    train_set.n_patch_per_dim = pm.n_patch_per_dim
    val_set.length = len(val_set.graph_file_names)
    gauss_kernel = GaussianSmoothing(1, 5, 3, device=device)
    # dset = LeptinDset(self.cfg.gen.data_dir_raw, self.cfg.gen.data_dir_affs, wu_cfg.patch_manager, wu_cfg.patch_stride, wu_cfg.patch_shape, wu_cfg.reorder_sp)
    train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True,
                              pin_memory=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True,
                            pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    sheduler = ReduceLROnPlateau(optimizer, patience=80, threshold=1e-4,
                                 min_lr=1e-8, factor=0.1)
    # the embedding channels are split into two sub-spaces at embeddings_separator
    slcs = [
        slice(None, self.cfg.fe.embeddings_separator),
        slice(self.cfg.fe.embeddings_separator, None)
    ]
    criterion = RegRagContrastiveWeights(delta_var=0.1, delta_dist=0.3, slices=slcs)
    acc_loss = 0
    valit = 0
    iteration = 0
    best_loss = np.inf
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(train_loader):
            raw, gt, sp_seg, affinities = raw.to(device), gt.to(device), \
                sp_seg.to(device), affinities.to(device)
            sp_seg = sp_seg + 1  # reserve label 0 for "void" after rotation
            # smoothed superpixel contour image as an extra input channel
            edge_img = F.pad(get_contour_from_2d_binary(sp_seg), (2, 2, 2, 2),
                             mode='constant')
            edge_img = gauss_kernel(edge_img.float())
            # rotate all aligned channels together (NOTE: `all` shadows the builtin)
            all = torch.cat([raw, gt, sp_seg, edge_img], dim=1)
            angle = float(torch.randint(-180, 180, (1, )).item())
            rot_all = tvF.rotate(all, angle, PIL.Image.NEAREST)
            rot_raw = rot_all[:, :1]
            rot_gt = rot_all[:, 1:2]
            rot_sp = rot_all[:, 2:3]
            rot_edge_img = rot_all[:, 3:]
            angle = abs(angle / 180)  # normalized |angle| handed to the criterion
            # keep only superpixels still visible after rotation; subsample if too many
            valid_sp = []
            for i in range(len(rot_sp)):
                _valid_sp = torch.unique(rot_sp[i], sorted=True)
                _valid_sp = _valid_sp[1:] if _valid_sp[0] == 0 else _valid_sp
                if len(_valid_sp) > self.cfg.gen.sp_samples_per_step:
                    inds = torch.multinomial(torch.ones_like(_valid_sp),
                                             self.cfg.gen.sp_samples_per_step,
                                             replacement=False)
                    _valid_sp = _valid_sp[inds]
                valid_sp.append(_valid_sp)
            # relabel both views to a compact 1..len(valid_sp) id space (0 = void)
            _rot_sp, _sp_seg = [], []
            for val_sp, rsp, sp in zip(valid_sp, rot_sp, sp_seg):
                mask = rsp == val_sp[:, None, None]
                _rot_sp.append((mask * (torch.arange(
                    len(val_sp), device=rsp.device)[:, None, None] + 1)).sum(0))
                mask = sp == val_sp[:, None, None]
                _sp_seg.append((mask * (torch.arange(
                    len(val_sp), device=sp.device)[:, None, None] + 1)).sum(0))
            rot_sp = torch.stack(_rot_sp)
            sp_seg = torch.stack(_sp_seg)
            valid_sp = [
                torch.unique(_rot_sp, sorted=True) for _rot_sp in rot_sp
            ]
            valid_sp = [
                _valid_sp[1:] if _valid_sp[0] == 0 else _valid_sp
                for _valid_sp in valid_sp
            ]
            # batch = [original view; rotated view]
            inp = torch.cat([
                torch.cat([raw, edge_img], 1),
                torch.cat([rot_raw, rot_edge_img], 1)
            ], 0)
            offs = offs.numpy().tolist()
            # per-sample RAG edges and affinity-derived edge weights
            edge_feat, edges = tuple(
                zip(*[
                    get_edge_features_1d(seg.squeeze().cpu().numpy(), os,
                                         affs.squeeze().cpu().numpy())
                    for seg, os, affs in zip(sp_seg, offs, affinities)
                ]))
            edges = [
                torch.from_numpy(e.astype(np.long)).to(device).T for e in edges
            ]
            edge_weights = [
                torch.from_numpy(ew.astype(np.float32)).to(device)[:, 0][None]
                for ew in edge_feat
            ]
            # drop edges touching superpixels that vanished in the rotated view
            valid_edges_masks = [
                (_edges[None] == _valid_sp[:, None, None]).sum(0).sum(0) == 2
                for _valid_sp, _edges in zip(valid_sp, edges)
            ]
            edges = [
                _edges[:, valid_edges_mask] - 1
                for _edges, valid_edges_mask in zip(edges, valid_edges_masks)
            ]
            edge_weights = [
                _edge_weights[:, valid_edges_mask]
                for _edge_weights, valid_edges_mask in zip(edge_weights,
                                                           valid_edges_masks)
            ]
            # put embeddings on unit sphere so we can use cosine distance
            loss_embeds = model(inp[:, :, None]).squeeze(2)
            loss_embeds = criterion.norm_each_space(loss_embeds, 1)
            loss = criterion(loss_embeds, sp_seg.long(), rot_sp.long(), edges,
                             edge_weights, valid_sp, angle,
                             chunks=int(sp_seg.max().item() //
                                        self.cfg.gen.train_chunk_size))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"step {iteration}: {loss.item()}")
            writer.add_scalar("fe_train/lr", optimizer.param_groups[0]['lr'],
                              iteration)
            writer.add_scalar("fe_train/loss", loss.item(), iteration)
            if (iteration) % 100 == 0:
                # validation: same pipeline as training, without gradients
                with torch.set_grad_enabled(False):
                    model.eval()
                    print("####start validation####")
                    for it, (raw, gt, sp_seg, affinities, offs,
                             indices) in enumerate(val_loader):
                        raw, gt, sp_seg, affinities = raw.to(device), gt.to(
                            device), sp_seg.to(device), affinities.to(device)
                        sp_seg = sp_seg + 1
                        edge_img = F.pad(get_contour_from_2d_binary(sp_seg),
                                         (2, 2, 2, 2), mode='constant')
                        edge_img = gauss_kernel(edge_img.float())
                        all = torch.cat([raw, gt, sp_seg, edge_img], dim=1)
                        angle = float(torch.randint(-180, 180, (1, )).item())
                        rot_all = tvF.rotate(all, angle, PIL.Image.NEAREST)
                        rot_raw = rot_all[:, :1]
                        rot_gt = rot_all[:, 1:2]
                        rot_sp = rot_all[:, 2:3]
                        rot_edge_img = rot_all[:, 3:]
                        angle = abs(angle / 180)
                        valid_sp = [
                            torch.unique(_rot_sp, sorted=True)
                            for _rot_sp in rot_sp
                        ]
                        valid_sp = [
                            _valid_sp[1:] if _valid_sp[0] == 0 else _valid_sp
                            for _valid_sp in valid_sp
                        ]
                        _rot_sp, _sp_seg = [], []
                        for val_sp, rsp, sp in zip(valid_sp, rot_sp, sp_seg):
                            mask = rsp == val_sp[:, None, None]
                            _rot_sp.append((mask * (torch.arange(
                                len(val_sp),
                                device=rsp.device)[:, None, None] + 1)).sum(0))
                            mask = sp == val_sp[:, None, None]
                            _sp_seg.append((mask * (torch.arange(
                                len(val_sp),
                                device=sp.device)[:, None, None] + 1)).sum(0))
                        rot_sp = torch.stack(_rot_sp)
                        sp_seg = torch.stack(_sp_seg)
                        valid_sp = [
                            torch.unique(_rot_sp, sorted=True)
                            for _rot_sp in rot_sp
                        ]
                        valid_sp = [
                            _valid_sp[1:] if _valid_sp[0] == 0 else _valid_sp
                            for _valid_sp in valid_sp
                        ]
                        inp = torch.cat([
                            torch.cat([raw, edge_img], 1),
                            torch.cat([rot_raw, rot_edge_img], 1)
                        ], 0)
                        offs = offs.numpy().tolist()
                        edge_feat, edges = tuple(
                            zip(*[
                                get_edge_features_1d(
                                    seg.squeeze().cpu().numpy(), os,
                                    affs.squeeze().cpu().numpy())
                                for seg, os, affs in zip(sp_seg, offs, affinities)
                            ]))
                        edges = [
                            torch.from_numpy(e.astype(np.long)).to(device).T
                            for e in edges
                        ]
                        edge_weights = [
                            torch.from_numpy(ew.astype(
                                np.float32)).to(device)[:, 0][None]
                            for ew in edge_feat
                        ]
                        valid_edges_masks = [
                            (_edges[None] == _valid_sp[:, None, None]
                             ).sum(0).sum(0) == 2
                            for _valid_sp, _edges in zip(valid_sp, edges)
                        ]
                        edges = [
                            _edges[:, valid_edges_mask] - 1
                            for _edges, valid_edges_mask in zip(
                                edges, valid_edges_masks)
                        ]
                        edge_weights = [
                            _edge_weights[:, valid_edges_mask]
                            for _edge_weights, valid_edges_mask in zip(
                                edge_weights, valid_edges_masks)
                        ]
                        # put embeddings on unit sphere so we can use cosine distance
                        embeds = model(inp[:, :, None]).squeeze(2)
                        embeds = criterion.norm_each_space(embeds, 1)
                        ls = criterion(embeds, sp_seg.long(), rot_sp.long(),
                                       edges, edge_weights, valid_sp, angle,
                                       chunks=int(sp_seg.max().item() //
                                                  self.cfg.gen.train_chunk_size))
                        acc_loss += ls
                        writer.add_scalar("fe_val/loss", ls, valit)
                        print(f"step {it}: {ls.item()}")
                        valit += 1
                # checkpoint on validation improvement; LR follows validation loss
                acc_loss = acc_loss / len(val_loader)
                if acc_loss < best_loss:
                    print(self.save_dir)
                    torch.save(model.state_dict(),
                               os.path.join(self.save_dir, "best_val_model.pth"))
                    best_loss = acc_loss
                sheduler.step(acc_loss)
                acc_loss = 0
                fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex='col',
                                                         sharey='row',
                                                         gridspec_kw={
                                                             'hspace': 0,
                                                             'wspace': 0
                                                         })
                a1.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                a1.set_title('raw')
                a2.imshow(cm.prism(sp_seg[0].cpu().squeeze() /
                                   sp_seg[0].cpu().squeeze().max()))
                a2.set_title('sp')
                a3.imshow(pca_project(embeds[0, slcs[0]].detach().cpu()))
                a3.set_title('embed', y=-0.01)
                a4.imshow(pca_project(embeds[0, slcs[1]].detach().cpu()))
                a4.set_title('embed rot', y=-0.01)
                plt.show()
                writer.add_figure("examples", fig, iteration // 100)
                # model.train()
                print("####end validation####")
            iteration += 1
            if iteration > wu_cfg.n_iterations:
                print(self.save_dir)
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, "last_model.pth"))
                break
    return
def train_step(self, rank, writer):
    """Per-process distributed training loop for the SAC-style edge agent.

    Runs on one process of a multi-process/multi-GPU setup: builds the env and the
    DDP-wrapped model, optionally (re)loads / warm-starts the feature extractor,
    then alternates env rollouts with replay-memory updates until the shared
    global step counter passes ``cfg.trainer.T_max``.

    Args:
        rank: process rank; also selects the GPU via
            ``rank // cfg.gen.n_processes_per_gpu``.
        writer: summary writer used for logging inside env/model.

    Returns:
        Mean of ``env.acc_reward`` accumulated over the run.
    """
    device = torch.device("cuda:" + str(rank // self.cfg.gen.n_processes_per_gpu))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)
    # initializes the torch.distributed process group for this rank
    self.setup(rank, self.cfg.gen.n_processes_per_gpu * self.cfg.gen.n_gpu)
    env = SpGcnEnv(self.cfg, device, writer=writer, writer_counter=self.global_writer_quality_count)
    # Create shared network
    model = GcnEdgeAC(self.cfg, device, writer=writer)
    model.cuda(device)
    shared_model = DDP(model, device_ids=[device], find_unused_parameters=True)
    # "extra" in cfg.fe.optim means the feature extractor gets its own optimizer,
    # so the named tuples grow an extra slot for it.
    if 'extra' in self.cfg.fe.optim:
        # optimizers
        MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'embeddings', 'critic', 'temperature'))
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'embeddings', 'critic', 'temperature',
                                         'actor_shed', 'embed_shed', 'critic_shed', 'temp_shed'))
    else:
        MovSumLosses = namedtuple('mov_avg_losses', ('actor', 'critic', 'temperature'))
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'temperature',
                                         'actor_shed', 'critic_shed', 'temp_shed'))
    # "rl_loss" mode trains the feature extractor jointly through the actor loss
    if "rl_loss" == self.cfg.fe.optim:
        actor_optimizer = torch.optim.Adam(
            list(shared_model.module.actor.parameters()) + list(shared_model.module.fe_ext.parameters()),
            lr=self.cfg.sac.actor_lr, betas=self.cfg.sac.actor_betas)
    else:
        actor_optimizer = torch.optim.Adam(shared_model.module.actor.parameters(),
                                           lr=self.cfg.sac.actor_lr, betas=self.cfg.sac.actor_betas)
    if "extra" in self.cfg.fe.optim:
        embeddings_optimizer = torch.optim.Adam(shared_model.module.fe_ext.parameters(),
                                                lr=self.cfg.fe.lr, betas=self.cfg.fe.betas)
    critic_optimizer = torch.optim.Adam(shared_model.module.critic.parameters(),
                                        lr=self.cfg.sac.critic_lr, betas=self.cfg.sac.critic_betas)
    # temperature (entropy coefficient) optimizer acts on the single log_alpha parameter
    temp_optimizer = torch.optim.Adam([shared_model.module.log_alpha],
                                      lr=self.cfg.sac.alpha_lr, betas=self.cfg.sac.alpha_betas)
    if "extra" in self.cfg.fe.optim:
        mov_sum_losses = MovSumLosses(RunningAverage(), RunningAverage(), RunningAverage(), RunningAverage())
        optimizers = OptimizerContainer(actor_optimizer, embeddings_optimizer, critic_optimizer, temp_optimizer,
                                        ReduceLROnPlateau(actor_optimizer), ReduceLROnPlateau(embeddings_optimizer),
                                        ReduceLROnPlateau(critic_optimizer), ReduceLROnPlateau(temp_optimizer))
    else:
        mov_sum_losses = MovSumLosses(RunningAverage(), RunningAverage(), RunningAverage())
        optimizers = OptimizerContainer(actor_optimizer, critic_optimizer, temp_optimizer,
                                        ReduceLROnPlateau(actor_optimizer), ReduceLROnPlateau(critic_optimizer),
                                        ReduceLROnPlateau(temp_optimizer))
    # sync all ranks before touching checkpoints on disk
    dist.barrier()
    if self.cfg.gen.resume:
        shared_model.module.load_state_dict(torch.load(os.path.join(self.log_dir, self.cfg.gen.model_name)))
    elif self.cfg.fe.load_pretrained:
        shared_model.module.fe_ext.load_state_dict(torch.load(os.path.join(self.save_dir, self.cfg.fe.model_name)))
    elif 'warmup' in self.cfg.fe and rank == 0:
        # only rank 0 pretrains; other ranks wait at the barrier below and then
        # continue with their own (un-warm-started) copy — NOTE(review): they do
        # not reload the saved warmup weights here; confirm that is intended.
        print('pretrain fe extractor')
        self.pretrain_embeddings_gt(shared_model.module.fe_ext, device, writer)
        torch.save(shared_model.module.fe_ext.state_dict(), os.path.join(self.save_dir, self.cfg.fe.model_name))
    dist.barrier()
    if "none" == self.cfg.fe.optim:
        # freeze the feature extractor entirely
        for param in shared_model.module.fe_ext.parameters():
            param.requires_grad = False
    dset = SpgDset(self.cfg.gen.data_dir)
    step = 0
    while self.global_count.value() <= self.cfg.trainer.T_max:
        dloader = DataLoader(dset, batch_size=self.cfg.trainer.batch_size, shuffle=True,
                             pin_memory=True, num_workers=0)
        for iteration in range(len(dset) * self.cfg.trainer.data_update_frequency):
            # if self.global_count.value() > self.args.T_max:
            #     a=1
            # refresh the env's data sample every data_update_frequency iterations
            if iteration % self.cfg.trainer.data_update_frequency == 0:
                self.update_env_data(env, dloader, device)
            # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
            # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
            # waff_dis = torch.softmax(torch.ones_like(env.b_gt_edge_weights), dim=0)
            # loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
            env.reset()
            # pick up possibly changed runtime hyper-params (lr etc.)
            self.update_rt_vars(critic_optimizer, actor_optimizer)
            if rank == 0 and self.cfg.rt_vars.safe_model:
                if self.cfg.gen.model_name != "":
                    torch.save(shared_model.module.state_dict(),
                               os.path.join(self.log_dir, self.cfg.gen.model_name))
                else:
                    torch.save(shared_model.module.state_dict(),
                               os.path.join(self.log_dir, 'agent_model'))
            state = env.get_state()
            while not env.done:
                # Calculate policy and values
                post_stats = True if (self.global_writer_count.value() + 1) % self.cfg.trainer.post_stats_frequency == 0 else False
                post_model = True if (self.global_writer_count.value() + 1) % self.cfg.trainer.post_model_frequency == 0 else False
                # only log once the replay memory is warm
                post_stats &= self.memory.is_full()
                post_model &= self.memory.is_full()
                distr = None
                if not self.memory.is_full():
                    # random exploration until the replay memory is filled
                    action = torch.rand_like(env.current_edge_weights)
                else:
                    distr, _, _, action, _, _ = self.agent_forward(env, shared_model, state=state, grad=False,
                                                                   post_input=post_stats, post_model=post_model)
                logg_dict = {}
                if post_stats:
                    for i in range(len(self.cfg.sac.s_subgraph)):
                        logg_dict['alpha_' + str(i)] = shared_model.module.alpha[i].item()
                    if distr is not None:
                        logg_dict['mean_loc'] = distr.loc.mean().item()
                        logg_dict['mean_scale'] = distr.scale.mean().item()
                if self.memory.is_full():
                    for i in range(self.cfg.trainer.n_updates_per_step):
                        self._step(self.memory, optimizers, mov_sum_losses, env, shared_model, step, writer=writer)
                        self.global_writer_loss_count.increment()
                next_state, reward = env.execute_action(action, logg_dict, post_stats=post_stats)
                # next_state, reward, quality = env.execute_action(torch.sigmoid(distr.loc), logg_dict, post_stats=post_stats)
                if self.cfg.rt_vars.add_noise:
                    # optional reward noise for robustness
                    noise = torch.randn_like(reward) * 0.2
                    reward = reward + noise
                self.memory.push(self.state_to_cpu(state), action, reward, self.state_to_cpu(next_state), env.done)
                state = next_state
            self.global_count.increment()
            step += 1
            if rank == 0:
                self.global_writer_count.increment()
            if step > self.cfg.trainer.T_max:
                break
    dist.barrier()
    if rank == 0:
        self.memory.clear()
        if not self.cfg.gen.cross_validate_hp and not self.cfg.gen.test_score_only and not self.cfg.gen.no_save:
            # pass
            if self.cfg.gen.model_name != "":
                torch.save(shared_model.state_dict(), os.path.join(self.log_dir, self.cfg.gen.model_name))
                print('saved')
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.log_dir, 'agent_model'))
    # tears down the distributed process group
    self.cleanup()
    return sum(env.acc_reward) / len(env.acc_reward)
def train(self):
    """Warm-start training of the UNet feature extractor with an affinity-contrastive loss.

    Trains on raw+affinity input, L2-normalizes the embeddings, and every 100
    iterations runs a full validation pass, checkpoints the best model and
    shows a raw/embedding figure. Runs for ``cfg.fe.trainer.n_iterations``.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(**self.cfg.fe.backbone)
    model.cuda(device)
    train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=False)
    val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=False)
    train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True,
                              pin_memory=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True,
                            pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    criterion = AffinityContrastive(delta_var=0.1, delta_dist=0.3)
    sheduler = ReduceLROnPlateau(optimizer, patience=5, threshold=1e-4, min_lr=1e-5, factor=0.1)
    valit = 0
    iteration = 0
    best_loss = np.inf
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(train_loader):
            # offs[0]: the offsets are identical per batch element, so take the first
            raw, gt, sp_seg, affinities, offs = raw.to(device), gt.to(device), sp_seg.to(device), \
                affinities.to(device), offs[0].to(device)
            # NOTE(review): `input` shadows the builtin of the same name
            input = torch.cat([raw, affinities], dim=1)
            # unsqueeze(2)/squeeze(2) adds and removes a singleton depth axis for the 2D UNet
            embeddings = model(input.unsqueeze(2)).squeeze(2)
            # project onto the unit sphere so cosine distance applies
            embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
            loss = criterion(embeddings, affinities, offs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            print(f"step {it}; lr({lr}); loss({loss.item()})")
            writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
            writer.add_scalar("fe_warm_start/lr", lr, iteration)
            # full validation pass every 100 iterations (incl. iteration 0)
            if (iteration) % 100 == 0:
                acc_loss = 0
                with torch.set_grad_enabled(False):
                    for val_it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(val_loader):
                        raw, gt, sp_seg, affinities, offs = raw.to(device), gt.to(device), sp_seg.to(device), \
                            affinities.to(device), offs[0].to(device)
                        input = torch.cat([raw, affinities], dim=1)
                        embeddings = model(input.unsqueeze(2)).squeeze(2)
                        embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
                        loss = criterion(embeddings, affinities, offs)
                        acc_loss += loss
                        writer.add_scalar("fe_val/loss", loss, valit)
                        valit += 1
                acc_loss = acc_loss / len(val_loader)
                # checkpoint on best mean validation loss
                if acc_loss < best_loss:
                    torch.save(model.state_dict(), os.path.join(self.save_dir, "best_val_model.pth"))
                    best_loss = acc_loss
                sheduler.step(acc_loss)
                # visualize the last validation batch
                fig, (a1, a2) = plt.subplots(1, 2, sharex='col', sharey='row',
                                             gridspec_kw={'hspace': 0, 'wspace': 0})
                a1.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                a1.set_title('raw')
                a2.imshow(pca_project(embeddings[0].detach().cpu()))
                a2.set_title('embed')
                plt.show()
                # writer.add_figure("examples", fig, iteration // 50)
            iteration += 1
            if iteration > wu_cfg.n_iterations:
                break
    return
def validate_and_compare_to_clustering(model, env, distance, device, cfg):
    """validates the prediction against the method of clustering the embedding space

    Runs the trained agent over the validation set, collects SBD / VI / ARE /
    ARP / ARR metrics for the RL solution and for the superpixel-projected
    ground truth, and prints them. The alternative clustering baselines
    (k-means, multicut on affinities/embeddings, graph agglomeration) are
    currently commented out.

    NOTE(review): this function calls ``exit()`` after printing the metrics,
    so the plotting loop at the end is unreachable dead code; if re-enabled it
    would fail because ``ex_clst_graph_agglo``, ``ex_mc_embed`` and
    ``ex_mcaff`` are never populated (their appends are commented out).
    """
    model.eval()
    offs = [[1, 0], [0, 1], [2, 0], [0, 2], [4, 0], [0, 4], [16, 0], [0, 16]]
    ex_raws, ex_sps, ex_gts, ex_mc_gts, ex_embeds, ex_clst, ex_clst_sp, ex_mcaff, ex_mc_embed, ex_rl, \
        ex_clst_graph_agglo = [], [], [], [], [], [], [], [], [], [], []
    dset = SpgDset(cfg.val_data_dir, dict_to_attrdict(cfg.patch_manager),
                   dict_to_attrdict(cfg.val_data_keys), max(cfg.s_subgraph))
    dloader = iter(DataLoader(dset))
    acc_reward = 0
    forwarder = Forwarder()
    delta_dist = 0.4
    # segm_metric = AveragePrecision()
    clst_metric_rl = ClusterMetrics()
    # clst_metric = ClusterMetrics()
    metric_sp_gt = ClusterMetrics()
    # clst_metric_mcaff = ClusterMetrics()
    # clst_metric_mcembed = ClusterMetrics()
    # clst_metric_graphagglo = ClusterMetrics()
    sbd = SBD()
    # map_rl, map_embed, map_sp_gt, map_mcaff, map_mcembed, map_graphagglo = [], [], [], [], [], []
    sbd_rl, sbd_embed, sbd_sp_gt, sbd_mcaff, sbd_mcembed, sbd_graphagglo = [], [], [], [], [], []
    n_examples = len(dset)
    for it in range(n_examples):
        update_env_data(env, dloader, dset, device, with_gt_edges=False)
        env.reset()
        state = env.get_state()
        # deterministic action: sigmoid of the policy mean
        distr, _, _, _, _, node_features, embeddings = forwarder.forward(model, state, State, device,
                                                                         grad=False, post_data=False,
                                                                         get_node_feats=True,
                                                                         get_embeddings=True)
        action = torch.sigmoid(distr.loc)
        reward = env.execute_action(action, tau=0.0, train=False)
        acc_reward += reward[-2].item()
        embeds = embeddings[0].cpu()
        # node_features = node_features.cpu().numpy()
        rag = env.rags[0]
        edge_ids = rag.uvIds()
        gt_seg = env.gt_seg[0].cpu().numpy()
        # l2_embeddings = get_angles(embeds[None])[0]
        # l2_node_feats = get_angles(torch.from_numpy(node_features.T[None, ..., None])).squeeze().T.numpy()
        # clst_labels_kmeans = cluster_embeddings(l2_embeddings.permute((1, 2, 0)), len(np.unique(gt_seg)))
        # node_labels = cluster_embeddings(l2_node_feats, len(np.unique(gt_seg)))
        # clst_labels_sp_kmeans = elf.segmentation.features.project_node_labels_to_pixels(rag, node_labels).squeeze()
        # clst_labels_sp_graph_agglo = get_soln_graph_clustering(env.init_sp_seg, torch.from_numpy(edge_ids.astype(np.int)), torch.from_numpy(l2_node_feats), len(np.unique(gt_seg)))[0][0].numpy()
        # mc_labels_aff = env.get_current_soln(edge_weights=env.edge_features[:, 0]).cpu().numpy()[0]
        # ew_embedaffs = 1 - get_edge_features_1d(env.init_sp_seg[0].cpu().numpy(), offs, get_affinities_from_embeddings_2d(embeddings, offs, delta_dist, distance)[0].cpu().numpy())[0][:, 0]
        # mc_labels_embedding_aff = env.get_current_soln(edge_weights=torch.from_numpy(ew_embedaffs).to(device)).cpu().numpy()[0]
        rl_labels = env.current_soln.cpu().numpy()[0]
        ex_embeds.append(pca_project(embeds, n_comps=3))
        ex_raws.append(env.raw[0].cpu().permute(1, 2, 0).squeeze())
        # ex_sps.append(cm.prism(env.init_sp_seg[0].cpu() / env.init_sp_seg[0].max().item()))
        ex_sps.append(env.init_sp_seg[0].cpu())
        # superpixel-projected gt: best segmentation achievable on this superpixel graph
        ex_mc_gts.append(project_overseg_to_seg(env.init_sp_seg[0],
                                                torch.from_numpy(gt_seg).to(device)).cpu().numpy())
        ex_gts.append(gt_seg)
        ex_rl.append(rl_labels)
        # ex_clst.append(clst_labels_kmeans)
        # ex_clst_sp.append(clst_labels_sp_kmeans)
        # ex_clst_graph_agglo.append(clst_labels_sp_graph_agglo)
        # ex_mcaff.append(mc_labels_aff)
        # ex_mc_embed.append(mc_labels_embedding_aff)
        # map_rl.append(segm_metric(rl_labels, gt_seg))
        sbd_rl.append(sbd(gt_seg, rl_labels))
        clst_metric_rl(rl_labels, gt_seg)
        # map_sp_gt.append(segm_metric(ex_mc_gts[-1], gt_seg))
        sbd_sp_gt.append(sbd(gt_seg, ex_mc_gts[-1]))
        metric_sp_gt(ex_mc_gts[-1], gt_seg)
        # map_embed.append(segm_metric(clst_labels_kmeans, gt_seg))
        # clst_metric(clst_labels_kmeans, gt_seg)
        # map_mcaff.append(segm_metric(mc_labels_aff, gt_seg))
        # sbd_mcaff.append(sbd(gt_seg, mc_labels_aff))
        # clst_metric_mcaff(mc_labels_aff, gt_seg)
        #
        # map_mcembed.append(segm_metric(mc_labels_embedding_aff, gt_seg))
        # sbd_mcembed.append(sbd(gt_seg, mc_labels_embedding_aff))
        # clst_metric_mcembed(mc_labels_embedding_aff, gt_seg)
        #
        # map_graphagglo.append(segm_metric(clst_labels_sp_graph_agglo, gt_seg))
        # sbd_graphagglo.append(sbd(gt_seg, clst_labels_sp_graph_agglo.astype(np.int)))
        # clst_metric_graphagglo(clst_labels_sp_graph_agglo.astype(np.int), gt_seg)
    # each metric is reported as "mean; std"
    print("\nSBD: ")
    print(f"sp gt : {round(np.array(sbd_sp_gt).mean(), 4)}; {round(np.array(sbd_sp_gt).std(), 4)}")
    print(f"ours : {round(np.array(sbd_rl).mean(), 4)}; {round(np.array(sbd_rl).std(), 4)}")
    # print(f"mc node : {np.array(sbd_mcembed).mean()}")
    # print(f"mc embed : {np.array(sbd_mcaff).mean()}")
    # print(f"graph agglo : {np.array(sbd_graphagglo).mean()}")
    # print("\nmAP: ")
    # print(f"sp gt : {np.array(map_sp_gt).mean()}")
    # print(f"ours : {np.array(map_rl).mean()}")
    # print(f"mc node : {np.array(map_mcembed).mean()}")
    # print(f"mc embed : {np.array(map_mcaff).mean()}")
    # print(f"graph agglo : {np.array(map_graphagglo).mean()}")
    #
    vi_rl_s, vi_rl_m, are_rl, arp_rl, arr_rl = clst_metric_rl.dump()
    vi_spgt_s, vi_spgt_m, are_spgt, arp_spgt, arr_spgt = metric_sp_gt.dump()
    # vi_mcaff_s, vi_mcaff_m, are_mcaff, arp_mcaff, arr_mcaff = clst_metric_mcaff.dump()
    # vi_mcembed_s, vi_mcembed_m, are_mcembed, arp_embed, arr_mcembed = clst_metric_mcembed.dump()
    # vi_graphagglo_s, vi_graphagglo_m, are_graphagglo, arp_graphagglo, arr_graphagglo = clst_metric_graphagglo.dump()
    #
    vi_rl_s_std, vi_rl_m_std, are_rl_std, arp_rl_std, arr_rl_std = clst_metric_rl.dump_std()
    vi_spgt_s_std, vi_spgt_m_std, are_spgt_std, arp_spgt_std, arr_spgt_std = metric_sp_gt.dump_std()
    print("\nVI merge: ")
    print(f"sp gt : {round(vi_spgt_m, 4)}; {round(vi_spgt_m_std, 4)}")
    print(f"ours : {round(vi_rl_m, 4)}; {round(vi_rl_m_std, 4)}")
    # print(f"mc affnties : {vi_mcaff_m}")
    # print(f"mc embed : {vi_mcembed_m}")
    # print(f"graph agglo : {vi_graphagglo_m}")
    #
    print("\nVI split: ")
    print(f"sp gt : {round(vi_spgt_s, 4)}; {round(vi_spgt_s_std, 4)}")
    print(f"ours : {round(vi_rl_s, 4)}; {round(vi_rl_s_std, 4)}")
    # print(f"mc affnties : {vi_mcaff_s}")
    # print(f"mc embed : {vi_mcembed_s}")
    # print(f"graph agglo : {vi_graphagglo_s}")
    #
    print("\nARE: ")
    print(f"sp gt : {round(are_spgt, 4)}; {round(are_spgt_std, 4)}")
    print(f"ours : {round(are_rl, 4)}; {round(are_rl_std, 4)}")
    # print(f"mc affnties : {are_mcaff}")
    # print(f"mc embed : {are_mcembed}")
    # print(f"graph agglo : {are_graphagglo}")
    #
    print("\nARP: ")
    print(f"sp gt : {round(arp_spgt, 4)}; {round(arp_spgt_std, 4)}")
    print(f"ours : {round(arp_rl, 4)}; {round(arp_rl_std, 4)}")
    # print(f"mc affnties : {arp_mcaff}")
    # print(f"mc embed : {arp_embed}")
    # print(f"graph agglo : {arp_graphagglo}")
    #
    print("\nARR: ")
    print(f"sp gt : {round(arr_spgt, 4)}; {round(arr_spgt_std, 4)}")
    print(f"ours : {round(arr_rl, 4)}; {round(arr_rl_std, 4)}")
    # print(f"mc affnties : {arr_mcaff}")
    # print(f"mc embed : {arr_mcembed}")
    # print(f"graph agglo : {arr_graphagglo}")
    # everything below this exit() is unreachable (see docstring)
    exit()
    for i in range(len(ex_gts)):
        fig, axs = plt.subplots(2, 4, figsize=(20, 13), sharex='col', sharey='row',
                                gridspec_kw={'hspace': 0, 'wspace': 0})
        axs[0, 0].imshow(ex_gts[i], cmap=random_label_cmap(), interpolation="none")
        axs[0, 0].set_title('gt')
        axs[0, 0].axis('off')
        axs[0, 1].imshow(ex_embeds[i])
        axs[0, 1].set_title('pc proj')
        axs[0, 1].axis('off')
        # axs[0, 2].imshow(ex_clst[i], cmap=random_label_cmap(), interpolation="none")
        # axs[0, 2].set_title('pix clst')
        # axs[0, 2].axis('off')
        axs[0, 2].imshow(ex_clst_graph_agglo[i], cmap=random_label_cmap(), interpolation="none")
        axs[0, 2].set_title('nagglo')
        axs[0, 2].axis('off')
        axs[0, 3].imshow(ex_mc_embed[i], cmap=random_label_cmap(), interpolation="none")
        axs[0, 3].set_title('mc embed')
        axs[0, 3].axis('off')
        axs[1, 0].imshow(ex_mc_gts[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 0].set_title('sp gt')
        axs[1, 0].axis('off')
        axs[1, 1].imshow(ex_sps[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 1].set_title('sp')
        axs[1, 1].axis('off')
        # axs[1, 2].imshow(ex_clst_sp[i], cmap=random_label_cmap(), interpolation="none")
        # axs[1, 2].set_title('sp clst')
        # axs[1, 2].axis('off')
        axs[1, 2].imshow(ex_rl[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 2].set_title('ours')
        axs[1, 2].axis('off')
        axs[1, 3].imshow(ex_mcaff[i], cmap=random_label_cmap(), interpolation="none")
        axs[1, 3].set_title('mc aff')
        axs[1, 3].axis('off')
        plt.show()
        # wandb.log({"validation/samples": [wandb.Image(fig, caption="sample images")]})
        plt.close('all')
def supervised_policy_pretraining(model, env, cfg, device="cuda:0", fe_opt=False):
    """Supervised warm-up of the actor: regress its action against ground-truth edge weights.

    Trains with BCE between the policy's deterministic action and
    ``env.gt_edge_weights``; every 10 iterations evaluates the reward, steps
    the LR scheduler, and snapshots the best-scoring weights, which are loaded
    back into ``model`` at the end.

    Args:
        model: agent whose ``actor`` (and optionally feature extractor) is trained.
        env: one-step episode environment providing data and gt edge weights.
        cfg: config with a ``policy_warmup`` section.
        device: unused here beyond the default — data placement is handled by the env.
        fe_opt: if True, also optimize ``env.embedding_net`` jointly with the actor.
    """
    wu_cfg = AttrDict()
    add_dict(cfg.policy_warmup, wu_cfg)
    dset = SpgDset(cfg.data_dir, wu_cfg.patch_manager, max(cfg.trn.s_subgraph))
    dloader = DataLoader(dset, batch_size=wu_cfg.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)
    if fe_opt:
        actor_fe_opt = torch.optim.Adam(list(model.actor.parameters()) + list(env.embedding_net.parameters()),
                                        lr=wu_cfg.lr)
    else:
        actor_fe_opt = torch.optim.Adam(model.actor.parameters(), lr=wu_cfg.lr)
    # separate optimizer for log_alpha, driven only by the zero-valued dummy loss below
    dummy_opt = torch.optim.Adam([model.log_alpha], lr=wu_cfg.lr)
    sheduler = ReduceLROnPlateau(actor_fe_opt, threshold=0.001, min_lr=1e-6)
    criterion = torch.nn.BCELoss()
    acc_loss = 0
    iteration = 0
    best_score = -np.inf
    # be careful with this, it assumes a one step episode environment
    while iteration <= wu_cfg.n_iterations:
        update_env_data(env, dloader, device, with_gt_edges=True, fe_grad=fe_opt)
        state = env.get_state()
        # Calculate policy and values
        distr, q1, q2, _, _ = agent_forward(env, model, state, policy_opt=True)
        # deterministic action: apply the distribution's transform to its mean
        action = distr.transforms[0](distr.loc)
        loss = criterion(action.squeeze(1), env.gt_edge_weights)
        dummy_loss = (model.alpha * 0).sum()  # not using all parameters in backprop gives error, so add dummy loss
        # same trick for the critic outputs: multiply by 0 so grads flow but values don't change
        for sq1, sq2 in zip(q1, q2):
            loss = loss + (sq1.sum() * sq2.sum() * 0)
        actor_fe_opt.zero_grad()
        loss.backward(retain_graph=False)
        actor_fe_opt.step()
        dummy_opt.zero_grad()
        dummy_loss.backward(retain_graph=False)
        dummy_opt.step()
        acc_loss += loss.item()
        # every 10 iterations: evaluate reward, step scheduler, track best snapshot
        if iteration % 10 == 0:
            _, reward = env.execute_action(action.detach(), None, post_images=True, tau=0.0)
            sheduler.step(acc_loss / 10)
            total_reward = 0
            for _rew in reward:
                total_reward += _rew.mean().item()
            total_reward /= len(reward)
            wandb.log({"policy_warm_start/acc_loss": acc_loss})
            wandb.log({"policy_warm_start/rewards": total_reward})
            acc_loss = 0
            if total_reward > best_score:
                best_model = copy.deepcopy(model.state_dict())
                best_score = total_reward
        wandb.log({"policy_warm_start/loss": loss.item()})
        wandb.log({"policy_warm_start/lr": actor_fe_opt.param_groups[0]['lr']})
        iteration += 1
    # restore the best-performing snapshot (always assigned at iteration 0)
    model.load_state_dict(best_model)
    return
def __init__(self, cfg, global_count):
    """Set up the single-GPU agent trainer with a frozen, pretrained feature extractor.

    Loads the feature-extractor weights from ``cfg.fe_model_name``, freezes
    them, builds the agent (optionally restored from ``cfg.agent_model_name``),
    and prepares optimizer, LR scheduler, replay memory and datasets.

    Args:
        cfg: configuration (attrdict-style) with model/optimizer/dataset settings.
        global_count: shared counter of global training steps.
    """
    super(AgentSaTrainerObjLvlReward, self).__init__()
    # this trainer is single-GPU by design
    assert torch.cuda.device_count() == 1
    self.device = torch.device("cuda:0")
    torch.cuda.set_device(self.device)
    torch.set_default_tensor_type(torch.FloatTensor)
    self.cfg = cfg
    self.global_count = global_count
    # replay memory (thread-safe transition store)
    self.memory = TransitionData_ts(capacity=self.cfg.mem_size)
    self.best_val_reward = -np.inf
    if self.cfg.distance == 'cosine':
        self.distance = CosineDistance()
    else:
        self.distance = L2Distance()
    self.fe_ext = FeExtractor(dict_to_attrdict(self.cfg.backbone), self.distance,
                              cfg.fe_delta_dist, self.device)
    # restore pretrained embedding weights
    self.fe_ext.embed_model.load_state_dict(torch.load(self.cfg.fe_model_name))
    self.fe_ext.cuda(self.device)
    self.model = Agent(self.cfg, State, self.distance, self.device)
    wandb.watch(self.model)
    self.model.cuda(self.device)
    # guards the model against concurrent access
    self.model_mtx = Lock()
    self.optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr)
    lr_sched_cfg = dict_to_attrdict(self.cfg.lr_sched)
    bw = lr_sched_cfg.mov_avg_bandwidth
    off = lr_sched_cfg.mov_avg_offset
    weights = np.linspace(lr_sched_cfg.weight_range[0], lr_sched_cfg.weight_range[1], bw)
    weights = weights / weights.sum()  # make them sum up to one
    shed = lr_sched_cfg.torch_sched
    self.shed = ReduceLROnPlateau(self.optimizer, patience=shed.patience,
                                  threshold=shed.threshold, min_lr=shed.min_lr,
                                  factor=shed.factor)
    # weighted moving average of the loss feeds the plateau scheduler
    self.mov_sum_loss = RunningAverage(weights, band_width=bw, offset=off)
    self.scaler = torch.cuda.amp.GradScaler()
    self.forwarder = Forwarder()
    if self.cfg.agent_model_name != "":
        self.model.load_state_dict(torch.load(self.cfg.agent_model_name))
    # finished with prepping
    # the feature extractor stays frozen during agent training
    for param in self.fe_ext.parameters():
        param.requires_grad = False
    self.train_dset = SpgDset(self.cfg.data_dir, dict_to_attrdict(self.cfg.patch_manager),
                              dict_to_attrdict(self.cfg.data_keys))
    self.val_dset = SpgDset(self.cfg.val_data_dir, dict_to_attrdict(self.cfg.patch_manager),
                            dict_to_attrdict(self.cfg.data_keys))
# # graph_file.create_dataset("edges", data=edges, chunks=True)
# graph_file.create_dataset("edge_feat", data=edge_feat, chunks=True)
# graph_file.create_dataset("diff_to_gt", data=diff_to_gt)
# graph_file.create_dataset("gt_edge_weights", data=gt_edge_weights, chunks=True)
# graph_file.create_dataset("node_labeling", data=node_labeling, chunks=True)
# graph_file.create_dataset("affinities", data=affinities, chunks=True)
#
# graph_file.close()
# pix_file.close()


if __name__ == "__main__":
    # Smoke test: load one sample, solve the multicut induced by the ground-truth
    # edge weights, and plot gt / superpixels / multicut side by side.
    # Renamed from `dir` to avoid shadowing the builtin.
    data_dir = "/g/kreshuk/hilt/projects/fewShotLearning/mutexWtsd/data/storage/sqrs_crclspn/pix_and_graphs"
    # store_all(data_dir)
    dset = SpgDset(data_dir)
    raw, gt, sp_seg, idx = dset[20]
    edges, edge_feat, diff_to_gt, gt_edge_weights = dset.get_graphs(idx)
    # multicut segmentation obtained from the ground-truth edge weights
    gt_seg = get_current_soln(gt_edge_weights[0].numpy().astype(np.float64),
                              sp_seg[0].numpy().astype(np.uint64),
                              edges[0].numpy().transpose().astype(np.int64))
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    ax1.imshow(gt[0])
    ax1.set_title('gt')
    ax2.imshow(cm.prism(sp_seg[0] / sp_seg[0].max()))
    ax2.set_title('sp')
    ax3.imshow(gt_seg)
    ax3.set_title('mc')
    plt.show()
    a = 1  # debugger breakpoint anchor
def train(self):
    """Train the UNet feature extractor with a RAG-contrastive loss on superpixel graphs.

    For each batch, extracts per-offset edge features from the superpixel
    segmentation, builds edge lists, normalizes the embeddings onto the unit
    sphere, and optimizes ``RagContrastiveWeights``. Every 100 iterations a
    full validation pass checkpoints the best model; the final model is saved
    when ``cfg.fe.trainer.n_iterations`` is reached.

    Fixes vs. original: ``np.long`` (deprecated in NumPy 1.20, removed in 1.24)
    replaced with ``np.int64``; the validation loop variable no longer shadows
    the training loop's ``it``.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(**self.cfg.fe.backbone)
    model.cuda(device)
    train_set = SpgDset(self.cfg.gen.data_dir_raw_train, reorder_sp=False)
    val_set = SpgDset(self.cfg.gen.data_dir_raw_val, reorder_sp=False)
    # pm = StridedPatches2D(wu_cfg.patch_stride, wu_cfg.patch_shape, train_set.image_shape)
    pm = NoPatches2D()
    train_set.length = len(train_set.graph_file_names) * np.prod(pm.n_patch_per_dim)
    train_set.n_patch_per_dim = pm.n_patch_per_dim
    val_set.length = len(val_set.graph_file_names)
    # dset = LeptinDset(self.cfg.gen.data_dir_raw, self.cfg.gen.data_dir_affs, wu_cfg.patch_manager, wu_cfg.patch_stride, wu_cfg.patch_shape, wu_cfg.reorder_sp)
    train_loader = DataLoader(train_set, batch_size=wu_cfg.batch_size, shuffle=True,
                              pin_memory=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=wu_cfg.batch_size, shuffle=True,
                            pin_memory=True, num_workers=0)
    gauss_kernel = GaussianSmoothing(1, 5, 3, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    sheduler = ReduceLROnPlateau(optimizer, patience=20, threshold=1e-4, min_lr=1e-5, factor=0.1)
    criterion = RagContrastiveWeights(delta_var=0.1, delta_dist=0.4)
    acc_loss = 0
    valit = 0
    iteration = 0
    best_loss = np.inf
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(train_loader):
            raw, gt, sp_seg, affinities = raw.to(device), gt.to(device), sp_seg.to(device), affinities.to(device)
            # edge_img = F.pad(get_contour_from_2d_binary(sp_seg), (2, 2, 2, 2), mode='constant')
            # edge_img = gauss_kernel(edge_img.float())
            # input = torch.cat([raw, edge_img], dim=1)
            offs = offs.numpy().tolist()
            # singleton depth axis in/out for the 2D UNet
            loss_embeds = model(raw[:, :, None]).squeeze(2)
            # per-sample edge features and edge lists from the superpixel graphs
            edge_feat, edges = tuple(zip(*[get_edge_features_1d(seg.squeeze().cpu().numpy(), os,
                                                                affs.squeeze().cpu().numpy())
                                           for seg, os, affs in zip(sp_seg, offs, affinities)]))
            # np.int64 instead of the removed np.long alias
            edges = [torch.from_numpy(e.astype(np.int64)).to(device).T for e in edges]
            edge_weights = [torch.from_numpy(ew.astype(np.float32)).to(device)[:, 0][None]
                            for ew in edge_feat]
            # put embeddings on unit sphere so we can use cosine distance
            loss_embeds = loss_embeds / (torch.norm(loss_embeds, dim=1, keepdim=True) + 1e-9)
            loss = criterion(loss_embeds, sp_seg.long(), edges, edge_weights,
                             chunks=int(sp_seg.max().item() // self.cfg.gen.train_chunk_size),
                             sigm_factor=self.cfg.gen.sigm_factor,
                             pull_factor=self.cfg.gen.pull_factor)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(loss.item())
            writer.add_scalar("fe_train/lr", optimizer.param_groups[0]['lr'], iteration)
            writer.add_scalar("fe_train/loss", loss.item(), iteration)
            # full validation pass every 100 iterations (incl. iteration 0)
            if (iteration) % 100 == 0:
                with torch.set_grad_enabled(False):
                    for val_it, (raw, gt, sp_seg, affinities, offs, indices) in enumerate(val_loader):
                        raw, gt, sp_seg, affinities = raw.to(device), gt.to(device), sp_seg.to(device), affinities.to(device)
                        offs = offs.numpy().tolist()
                        embeddings = model(raw[:, :, None]).squeeze(2)
                        # relabel to consecutive ints starting at 0
                        edge_feat, edges = tuple(zip(*[get_edge_features_1d(seg.squeeze().cpu().numpy(), os,
                                                                            affs.squeeze().cpu().numpy())
                                                       for seg, os, affs in zip(sp_seg, offs, affinities)]))
                        edges = [torch.from_numpy(e.astype(np.int64)).to(device).T for e in edges]
                        edge_weights = [torch.from_numpy(ew.astype(np.float32)).to(device)[:, 0][None]
                                        for ew in edge_feat]
                        # put embeddings on unit sphere so we can use cosine distance
                        embeddings = embeddings / (torch.norm(embeddings, dim=1, keepdim=True) + 1e-9)
                        ls = criterion(embeddings, sp_seg.long(), edges, edge_weights,
                                       chunks=int(sp_seg.max().item() // self.cfg.gen.train_chunk_size),
                                       sigm_factor=self.cfg.gen.sigm_factor,
                                       pull_factor=self.cfg.gen.pull_factor)
                        # ls = 0
                        acc_loss += ls
                        writer.add_scalar("fe_val/loss", ls, valit)
                        valit += 1
                acc_loss = acc_loss / len(val_loader)
                # checkpoint on best mean validation loss
                if acc_loss < best_loss:
                    print(self.save_dir)
                    torch.save(model.state_dict(), os.path.join(self.save_dir, "best_val_model.pth"))
                    best_loss = acc_loss
                sheduler.step(acc_loss)
                acc_loss = 0
                # visualize the last validation batch
                fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex='col', sharey='row',
                                                         gridspec_kw={'hspace': 0, 'wspace': 0})
                a1.imshow(raw[0].cpu().permute(1, 2, 0)[..., 0].squeeze())
                a1.set_title('raw')
                a2.imshow(cm.prism(sp_seg[0, 0].cpu().squeeze() / sp_seg[0, 0].cpu().squeeze().max()))
                a2.set_title('sp')
                a3.imshow(pca_project(get_angles(embeddings)[0].detach().cpu()))
                a3.set_title('angle_embed')
                a4.imshow(pca_project(embeddings[0].detach().cpu()))
                a4.set_title('embed')
                # plt.show()
                writer.add_figure("examples", fig, iteration // 100)
            iteration += 1
            print(iteration)
            if iteration > wu_cfg.n_iterations:
                print(self.save_dir)
                torch.save(model.state_dict(), os.path.join(self.save_dir, "last_model.pth"))
                break
    return
def train(self):
    """Self-supervised warm-up of the UNet via augmentation-consistency contrastive learning.

    Two views of each batch are produced — one intensity-transformed input, and
    one spatially+intensity transformed input — and the criterion pulls the two
    embeddings together through the spatial transform. Runs for
    ``cfg.fe.trainer.n_iterations`` and plots examples every 50 iterations.
    """
    writer = SummaryWriter(logdir=self.log_dir)
    writer.add_text("conf", self.cfg.pretty())
    device = "cuda:0"
    wu_cfg = self.cfg.fe.trainer
    model = UNet2D(self.cfg.fe.n_raw_channels, self.cfg.fe.n_embedding_features,
                   final_sigmoid=False, num_levels=5)
    model.cuda(device)
    dset = SpgDset(self.cfg.gen.data_dir, wu_cfg.patch_manager, wu_cfg.patch_stride,
                   wu_cfg.patch_shape, wu_cfg.reorder_sp)
    dloader = DataLoader(dset, batch_size=wu_cfg.batch_size, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.fe.lr)
    tfs = RndAugmentationTfs(wu_cfg.patch_shape)
    criterion = AugmentedAffinityContrastive(delta_var=0.1, delta_dist=0.3)
    acc_loss = 0
    iteration = 0
    while iteration <= wu_cfg.n_iterations:
        for it, (raw, gt, sp_seg, indices) in enumerate(dloader):
            raw, gt, sp_seg = raw.to(device), gt.to(device), sp_seg.to(device)
            # this is still not the correct mask calculation as the affinity offsets go in no tf offset direction
            mask = torch.from_numpy(get_valid_edges([len(criterion.offs)] + list(raw.shape[-2:]),
                                                    criterion.offs)).to(device)[None]
            # _, _, _, _, affs = dset.get_graphs(indices, sp_seg, device)
            # two independent random transform pairs: one for each view
            spat_tf, int_tf = tfs.sample(1, 1)
            _, _int_tf = tfs.sample(1, 1)
            inp = add_sp_gauss_noise(_int_tf(raw), 0.2, 0.1, 0.3)
            # singleton depth axis in/out for the 2D UNet
            embeddings = model(inp.unsqueeze(2)).squeeze(2)
            # spatially transform mask, raw and embeddings together so they stay aligned
            paired = spat_tf(torch.cat((mask, raw, embeddings), -3))
            # slice the concatenated channels back apart: [mask | raw | embeddings]
            embeddings_0, mask = paired[..., inp.shape[1] + len(criterion.offs):, :, :], \
                paired[..., :len(criterion.offs), :, :].detach()
            # do intensity transform for spatial transformed input
            aug_inp = int_tf(paired[..., len(criterion.offs):inp.shape[1] + len(criterion.offs), :, :]).detach()
            # get prediction of the augmented input
            embeddings_1 = model(add_sp_gauss_noise(aug_inp, 0.2, 0.1, 0.3).unsqueeze(2)).squeeze(2)
            # put embeddings on unit sphere so we can use cosine distance
            embeddings_0 = embeddings_0 / (torch.norm(embeddings_0, dim=1, keepdim=True) + 1e-6)
            embeddings_1 = embeddings_1 / (torch.norm(embeddings_1, dim=1, keepdim=True) + 1e-6)
            loss = criterion(embeddings_0, embeddings_1, aug_inp, mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc_loss += loss.item()
            print(loss.item())
            writer.add_scalar("fe_warm_start/loss", loss.item(), iteration)
            writer.add_scalar("fe_warm_start/lr", optimizer.param_groups[0]['lr'], iteration)
            # show example projections every 50 iterations
            if (iteration) % 50 == 0:
                acc_loss = 0
                fig, ((a1, a2), (a3, a4)) = plt.subplots(2, 2)
                a1.imshow(aug_inp[0].cpu().permute(1, 2, 0).squeeze())
                a1.set_title('tf_raw')
                a3.imshow(pca_project(get_angles(embeddings_0).squeeze(0).detach().cpu()))
                a3.set_title('tf_embed')
                a4.imshow(pca_project(get_angles(embeddings_1).squeeze(0).detach().cpu()))
                a4.set_title('embed')
                a2.imshow(raw[0].cpu().permute(1, 2, 0).squeeze())
                a2.set_title('raw')
                plt.show()
                # writer.add_figure("examples", fig, iteration//100)
            iteration += 1
            if iteration > wu_cfg.n_iterations:
                break
    return