# NOTE: the methods in this section are excerpts from larger trainer classes; they
# rely on `self` attributes (args, cfg, save_dir, counters, schedule rules, ...)
# and on imports made elsewhere in the project (torch, numpy as np, os, time, yaml,
# collections.namedtuple, torch.nn as nn, DataLoader, SummaryWriter,
# torch.distributed as dist, DistributedDataParallel as DDP, Adam, and the
# project's own modules such as MultiDiscSpGraphDset, ContrastiveLoss, SpGcnEnv).

def fe_extr_warm_start(self, sp_feature_ext, writer=None):
    # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(length=self.args.fe_warmup_iterations * 10), batch_size=10,
    #                      shuffle=True, pin_memory=True)
    dloader = DataLoader(MultiDiscSpGraphDset(length=self.args.fe_warmup_iterations * 10),
                         batch_size=10, shuffle=True, pin_memory=True)
    criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
    optimizer = torch.optim.Adam(sp_feature_ext.parameters(), lr=2e-3)

    for i, (data, gt) in enumerate(dloader):
        data, gt = data.to(sp_feature_ext.device), gt.to(sp_feature_ext.device)
        pred = sp_feature_ext(data)

        # optional L2 penalty over all feature-extractor parameters
        l2_reg = None
        if self.args.l2_reg_params_weight != 0:
            for W in list(sp_feature_ext.parameters()):
                if l2_reg is None:
                    l2_reg = W.norm(2)
                else:
                    l2_reg = l2_reg + W.norm(2)
        if l2_reg is None:
            l2_reg = 0

        loss = criterion(pred, gt) + l2_reg * self.args.l2_reg_params_weight
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if writer is not None:
            writer.add_scalar("loss/fe_warm_start", loss.item(), self.writer_idx_warmup_loss)
            self.writer_idx_warmup_loss += 1
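
# --- illustrative sketch, not part of the original trainer code ---------------
# The warm start above builds its L2 penalty by summing parameter norms by hand.
# This is a minimal, standalone version of that pattern on a toy model so the
# intent is easy to verify in isolation; the model, data, and the 1e-4 weight
# below are made up for illustration.
import torch
import torch.nn as nn


def l2_penalty(module: nn.Module) -> torch.Tensor:
    """Sum of L2 norms over all parameters of `module` (0 if it has none)."""
    norms = [w.norm(2) for w in module.parameters()]
    return torch.stack(norms).sum() if norms else torch.tensor(0.0)


if __name__ == "__main__":
    toy = nn.Linear(4, 2)
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = nn.functional.mse_loss(toy(x), y) + 1e-4 * l2_penalty(toy)
    loss.backward()  # gradients now include the weight-decay-like term
# ------------------------------------------------------------------------------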
def fe_extr_warm_start(self, sp_feature_ext, writer=None):
    dataloader = DataLoader(MultiDiscSpGraphDset(length=100), batch_size=10,
                            shuffle=True, pin_memory=True)
    criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
    optimizer = torch.optim.Adam(sp_feature_ext.parameters())

    for i, (data, gt) in enumerate(dataloader):
        data, gt = data.to(sp_feature_ext.device), gt.to(sp_feature_ext.device)
        pred = sp_feature_ext(data[:, 0, :, :].unsqueeze(1))
        loss = criterion(pred, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if writer is not None:
            writer.add_scalar("loss/fe_warm_start", loss.item(), self.writer_idx_warmup_loss)
            self.writer_idx_warmup_loss += 1
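
# --- illustrative sketch, not part of the original trainer code ---------------
# A small shape check for the slicing used above: `data[:, 0, :, :].unsqueeze(1)`
# keeps only the first input channel while preserving a singleton channel axis,
# so the feature extractor still receives an N x C x H x W tensor. The sizes
# below are made up for illustration.
import torch

batch = torch.randn(10, 3, 64, 64)                 # N x C x H x W
single_channel = batch[:, 0, :, :].unsqueeze(1)    # keep channel 0 only
assert single_channel.shape == (10, 1, 64, 64)
# ------------------------------------------------------------------------------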
def train(self, rank, start_time, return_dict):
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    self.setup(rank, self.args.num_processes)
    transition = namedtuple('Transition',
                            ('state', 'action', 'reward', 'state_', 'behav_policy_proba', 'time', 'terminal'))
    memory = TransitionData(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count)
    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1,
                         shuffle=True, pin_memory=True, num_workers=0)

    # Create shared network
    model = GcnEdgeAngle1dPQV(self.args.n_raw_channels, self.args.n_embedding_features,
                              self.args.n_edge_features, self.args.n_actions, device)
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device])

    # Create optimizer for shared network parameters with shared statistics
    optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr,
                         betas=self.args.Adam_betas, weight_decay=self.args.Adam_weight_decay)

    if self.args.fe_extr_warmup and rank == 0:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(fe_extr.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()
    if self.args.model_name != "":
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))

    self.shared_damped_model.load_state_dict(shared_model.state_dict())

    env.done = True  # Start new episode
    while self.global_count.value() <= self.args.T_max:
        if env.done:
            edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                next(iter(dloader))
            edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                edges.squeeze().to(device), \
                edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(device), \
                diff_to_gt.squeeze().to(device), \
                gt_edge_weights.squeeze().to(device), \
                node_labeling.squeeze().to(device), \
                raw.squeeze().to(device), \
                nodes.squeeze().to(device), \
                angles.squeeze().to(device), \
                affinities.squeeze().numpy(), \
                gt.squeeze()
            env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling,
                            raw, nodes, angles, affinities, gt)
            env.reset()
            state = [env.state[0].clone(), env.state[1].clone()]
            episode_length = 0

            self.eps = self.eps_rule.apply(self.global_count.value())
            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value())
            if writer is not None:
                writer.add_scalar("step/epsilon", self.eps, env.writer_counter.value())

        while not env.done:
            # Calculate policy and values
            policy_proba, q, v = self.agent_forward(env, shared_model, grad=False)
            # average_policy_proba, _, _ = self.agent_forward(env, self.shared_average_model)
            # q_ret = v.detach()

            # Sample action
            # action = torch.multinomial(policy, 1)[0, 0]

            # Step
            action, behav_policy_proba = self.get_action(policy_proba, q, v,
                                                         policy='off_uniform', device=device)
            state_, reward = env.execute_action(action, self.global_count.value())

            memory.push(state, action, reward.to(shared_model.module.device), state_,
                        behav_policy_proba, episode_length, env.done)

            # Train the network
            self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

            # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
            # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
            episode_length += 1  # Increase episode counter
            state = state_
            # Break graph for last values calculated (used for targets, not directly as model outputs)
            self.global_count.increment()

        # Qret = 0 for terminal s
        while len(memory) > 0:
            self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
            memory.pop(0)

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp:
            if self.args.model_name != "":
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            test_score = 0
            env.writer = None
            for i in range(20):
                self.update_env_data(env, dloader, device)
                env.reset()
                self.eps = 0
                while not env.done:
                    # Calculate policy and values
                    policy_proba, q, v = self.agent_forward(env, shared_model, grad=False)
                    action, behav_policy_proba = self.get_action(policy_proba, q, v,
                                                                 policy='off_uniform', device=device)
                    _, _ = env.execute_action(action, self.global_count.value())
                if env.win:
                    test_score += 1
            return_dict['test_score'] = test_score
            writer.add_text("time_needed", str((time.time() - start_time)))
    self.cleanup()
def train(self, rank, start_time, return_dict):
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    self.setup(rank, self.args.num_processes)
    transition = namedtuple('Transition',
                            ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
    memory = TransitionData(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count,
                   discrete_action_space=False)

    # Create shared network
    model = GcnEdgeAngle1dPQA_dueling_1(self.args.n_raw_channels, self.args.n_embedding_features,
                                        self.args.n_edge_features, 1, self.args.exp_steps,
                                        self.args.p_sigma, device, self.args.density_eval_range,
                                        writer=writer)
    if self.args.no_fe_extr_optim:
        for param in model.fe_ext.parameters():
            param.requires_grad = False
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device])

    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1,
                         shuffle=True, pin_memory=True, num_workers=0)
    # Create optimizer for shared network parameters with shared statistics
    optimizer = torch.optim.Adam(shared_model.parameters(), lr=self.args.lr,
                                 betas=self.args.Adam_betas, weight_decay=self.args.Adam_weight_decay)

    if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()
    if self.args.model_name != "":
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))

    self.shared_average_model.load_state_dict(shared_model.state_dict())

    if not self.args.test_score_only:
        quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
        while self.global_count.value() <= self.args.T_max:
            if self.global_count.value() == 190:
                a = 1  # no-op left in place (likely a debugging breakpoint anchor)
            self.update_env_data(env, dloader, device)
            env.reset()
            state = [env.state[0].clone(), env.state[1].clone()]

            self.b_sigma = self.b_sigma_rule.apply(self.global_count.value(), quality)
            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value(), quality)

            # hot-reload selected hyperparameters from the run's config file
            with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'eps' in args_dict:
                        if self.args.eps != args_dict['eps']:
                            self.eps = args_dict['eps']
                    if 'safe_model' in args_dict:
                        self.args.safe_model = args_dict['safe_model']
                    if 'add_noise' in args_dict:
                        self.args.add_noise = args_dict['add_noise']

            if writer is not None:
                writer.add_scalar("step/b_variance", self.b_sigma, env.writer_counter.value())

            if self.args.safe_model:
                if rank == 0:
                    if self.args.model_name_dest != "":
                        torch.save(shared_model.state_dict(),
                                   os.path.join(self.save_dir, self.args.model_name_dest))
                    else:
                        torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

            while not env.done:
                post_input = True if self.global_count.value() % 50 and env.counter == 0 else False
                # Calculate policy and values
                policy_means, p_dis = self.agent_forward(env, shared_model, grad=False,
                                                         stats_only=True, post_input=post_input)
                # Step
                action, b_rvs = self.get_action(policy_means, p_dis, device)
                state_, reward, quality = env.execute_action(action)

                if self.args.add_noise:
                    if self.global_count.value() > 110 and self.global_count.value() % 5:
                        noise = torch.randn_like(reward) * 0.8
                        reward = reward + noise

                memory.push(state, action, reward, b_rvs, env.done)

                # Train the network
                # self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                state = state_
                # Break graph for last values calculated (used for targets, not directly as model outputs)
                self.global_count.increment()

            self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
            memory.clear()
            # while len(memory) > 0:
            #     self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
            #     memory.pop(0)

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
            # pass
            if self.args.model_name != "":
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))
                print('saved')
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        if self.args.cross_validate_hp or self.args.test_score_only:
            test_score = 0
            env.writer = None
            for i in range(20):
                self.update_env_data(env, dloader, device)
                env.reset()
                self.b_sigma = self.args.p_sigma
                env.stop_quality = 40
                while not env.done:
                    # Calculate policy and values
                    policy_means, p_dis = self.agent_forward(env, shared_model, grad=False, stats_only=True)
                    action, b_rvs = self.get_action(policy_means, p_dis, device)
                    _, _ = env.execute_action(action, self.global_count.value())
                    # import matplotlib.pyplot as plt;
                    # plt.imshow(env.get_current_soln());
                    # plt.show()
                if env.win:
                    test_score += 1
            return_dict['test_score'] = test_score
            writer.add_text("time_needed", str((time.time() - start_time)))
    self.cleanup()
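
# --- illustrative sketch, not part of the original trainer code ---------------
# `get_action` is not shown in this section. For a continuous action space with
# behaviour standard deviation `b_sigma`, one plausible off-policy sampling
# scheme is to draw from a Gaussian centred on the policy means and also return
# the behaviour density of the drawn action (roughly what `b_rvs` appears to
# hold). This is an assumption about the helper, not its actual implementation.
import torch


def sample_behaviour_action_sketch(policy_means: torch.Tensor, b_sigma: float):
    """Sample an action from N(policy_means, b_sigma) and return its density."""
    dist = torch.distributions.Normal(policy_means, b_sigma)
    action = dist.sample()
    return action, dist.log_prob(action).exp()
# ------------------------------------------------------------------------------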
def train(self):
    step_counter = 0
    device = torch.device("cuda:" + str(0))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    # Create shared network
    model = GcnEdgeAngle1dQ(self.args.n_raw_channels, self.args.n_embedding_features,
                            self.args.n_edge_features, 1, device, writer=writer)
    if self.args.no_fe_extr_optim:
        for param in model.fe_ext.parameters():
            param.requires_grad = False
    model.cuda(device)

    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1,
                         shuffle=True, pin_memory=True, num_workers=0)
    optimizer = Adam(model.parameters(), lr=self.args.lr)
    loss = GraphDiceLoss()

    if self.args.fe_extr_warmup and not self.args.test_score_only:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        model.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    if self.args.model_name != "":
        model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))

    while step_counter <= self.args.T_max:
        if step_counter == 78:
            a = 1  # no-op left in place (likely a debugging breakpoint anchor)
        if (step_counter + 1) % 1000 == 0:
            post_input = True
        else:
            post_input = False

        # hot-reload the learning rate from the run's config file
        with open(os.path.join(self.save_dir, 'config.yaml')) as info:
            args_dict = yaml.full_load(info)
            if args_dict is not None:
                if 'lr' in args_dict:
                    self.args.lr = args_dict['lr']
                    adjust_learning_rate(optimizer, self.args.lr)

        round_n = 0
        raw, gt, sp_seg, sp_indices, edge_ids, edge_weights, gt_edges, edge_features = \
            self._get_data(dloader, device)
        inp = [obj.float().to(model.device) for obj in [edge_weights, sp_seg, raw + gt, sp_seg]]
        pred, side_loss = model(inp,
                                sp_indices=sp_indices,
                                edge_index=edge_ids.to(model.device),
                                angles=None,
                                edge_features_1d=edge_features.to(model.device),
                                round_n=round_n,
                                post_input=post_input)
        pred = pred.squeeze()

        loss_val = loss(pred, gt_edges.to(device))
        ttl_loss = loss_val + side_loss
        quality = (pred - gt_edges.to(device)).abs().sum()

        optimizer.zero_grad()
        ttl_loss.backward()
        optimizer.step()

        if writer is not None:
            writer.add_scalar("step/lr", self.args.lr, step_counter)
            writer.add_scalar("step/dice_loss", loss_val.item(), step_counter)
            writer.add_scalar("step/side_loss", side_loss.item(), step_counter)
            writer.add_scalar("step/quality", quality.item(), step_counter)
        step_counter += 1
    a = 1  # no-op left in place (likely a debugging breakpoint anchor)
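
# --- illustrative sketch, not part of the original trainer code ---------------
# `adjust_learning_rate` is called above (and in the SAC-style trainer below)
# but is not defined in this section. A common implementation simply rewrites
# the lr of every parameter group, as sketched here under a different name;
# the project's own helper may differ.
def adjust_learning_rate_sketch(optimizer, lr):
    """Set `lr` on all parameter groups of a torch.optim optimizer."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
# ------------------------------------------------------------------------------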
def train(self, rank, start_time, return_dict):
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    self.setup(rank, self.args.num_processes)
    if self.cfg.MC_DQL:
        transition = namedtuple('Transition', ('episode'))
    else:
        transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    memory = TransitionData_ts(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count)

    # Create shared network
    model = GcnEdgeAC_1(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1,
                        device, writer=writer)
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device], find_unused_parameters=True)

    # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(no_suppix=False, create=False), batch_size=1, shuffle=True, pin_memory=True,
    #                      num_workers=0)
    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1,
                         shuffle=True, pin_memory=True, num_workers=0)

    # Create optimizer for shared network parameters with shared statistics
    # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
    #                      weight_decay=self.args.Adam_weight_decay)

    ######################
    self.action_range = 1
    self.device = torch.device(device)
    self.discount = 0.5
    self.critic_tau = self.cfg.critic_tau
    self.actor_update_frequency = self.cfg.actor_update_frequency
    self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
    self.batch_size = self.cfg.batch_size

    self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(self.device)
    self.log_alpha.requires_grad = True
    # set target entropy to -|A|
    ######################

    # optimizers
    OptimizerContainer = namedtuple('OptimizerContainer', ('actor', 'critic', 'temperature'))
    actor_optimizer = torch.optim.Adam(shared_model.module.actor.parameters(),
                                       lr=self.cfg.actor_lr, betas=self.cfg.actor_betas)
    critic_optimizer = torch.optim.Adam(shared_model.module.critic.parameters(),
                                        lr=self.cfg.critic_lr, betas=self.cfg.critic_betas)
    temp_optimizer = torch.optim.Adam([self.log_alpha], lr=self.cfg.alpha_lr,
                                      betas=self.cfg.alpha_betas)
    optimizers = OptimizerContainer(actor_optimizer, critic_optimizer, temp_optimizer)

    if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
        fe_extr = shared_model.module.fe_ext
        fe_extr.cuda(device)
        self.fe_extr_warm_start_1(fe_extr, writer=writer)
        if self.args.model_name == "" and not self.args.no_save:
            torch.save(fe_extr.state_dict(), os.path.join(self.save_dir, 'agent_model_fe_extr'))
        elif not self.args.no_save:
            torch.save(fe_extr.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()
    for param in model.fe_ext.parameters():
        param.requires_grad = False

    if self.args.model_name != "":
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.model_fe_name != "":
        shared_model.module.fe_ext.load_state_dict(
            torch.load(os.path.join(self.save_dir, self.args.model_fe_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        shared_model.module.fe_ext.load_state_dict(
            torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

    if not self.args.test_score_only:
        quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
        while self.global_count.value() <= self.args.T_max:
            if self.global_count.value() == 78:
                a = 1  # no-op left in place (likely a debugging breakpoint anchor)
            self.update_env_data(env, dloader, device)
            # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
            # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
            waff_dis = torch.softmax(torch.ones_like(env.gt_edge_weights), dim=0)
            loss_weight = torch.softmax(env.gt_edge_weights + 1, dim=0)
            env.reset()
            self.target_entropy = -float(env.gt_edge_weights.shape[0])

            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value(), quality)
            if self.cfg.temperature_regulation == 'follow_quality':
                self.alpha = self.eps_rule.apply(self.global_count.value(), quality)
                print(self.alpha.item())

            # hot-reload selected hyperparameters from the runtime config file
            with open(os.path.join(self.save_dir, 'runtime_cfg.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'safe_model' in args_dict:
                        self.args.safe_model = args_dict['safe_model']
                    if 'critic_lr' in args_dict and args_dict['critic_lr'] != self.cfg.critic_lr:
                        self.cfg.critic_lr = args_dict['critic_lr']
                        adjust_learning_rate(critic_optimizer, self.cfg.critic_lr)
                    if 'actor_lr' in args_dict and args_dict['actor_lr'] != self.cfg.actor_lr:
                        self.cfg.actor_lr = args_dict['actor_lr']
                        adjust_learning_rate(actor_optimizer, self.cfg.actor_lr)
                    if 'alpha_lr' in args_dict and args_dict['alpha_lr'] != self.cfg.alpha_lr:
                        self.cfg.alpha_lr = args_dict['alpha_lr']
                        adjust_learning_rate(temp_optimizer, self.cfg.alpha_lr)

            if self.args.safe_model and not self.args.no_save:
                if rank == 0:
                    if self.args.model_name_dest != "":
                        torch.save(shared_model.state_dict(),
                                   os.path.join(self.save_dir, self.args.model_name_dest))
                    else:
                        torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

            if self.cfg.MC_DQL:
                state_pixels, edge_ids, sp_indices, edge_angles, counter = env.get_state()
                state_ep = [state_pixels, edge_ids, sp_indices, edge_angles]
                episode = [state_ep]
                state = counter
                while not env.done:
                    # Calculate policy and values
                    post_input = True if (self.global_count.value() + 1) % 15 == 0 and env.counter == 0 else False
                    round_n = env.counter
                    # sample action for data collection
                    if self.global_count.value() < self.cfg.num_seed_steps:
                        action = torch.rand_like(env.current_edge_weights)
                    else:
                        _, _, _, action = self.agent_forward(env, shared_model, state=state_ep + [state],
                                                             grad=False, post_input=post_input)
                        action = action.cpu()
                    (_, _, _, _, next_state), reward, quality = env.execute_action(action)
                    episode.append((state, action, reward, next_state, env.done))
                    state = next_state
                memory.push(episode)
                if self.global_count.value() >= self.cfg.num_seed_steps and memory.is_full():
                    self._step_episodic_mem(memory, optimizers, env, shared_model,
                                            self.global_count.value(), writer=writer)
                    self.global_writer_loss_count.increment()
            else:
                state = env.get_state()
                while not env.done:
                    # Calculate policy and values
                    post_input = True if (self.global_count.value() + 1) % 15 == 0 and env.counter == 0 else False
                    round_n = env.counter
                    # sample action for data collection
                    if self.global_count.value() < self.cfg.num_seed_steps:
                        action = torch.rand_like(env.current_edge_weights)
                    else:
                        _, _, _, action = self.agent_forward(env, shared_model, state=state,
                                                             grad=False, post_input=post_input)
                        action = action.cpu()

                    if self.global_count.value() >= self.cfg.num_seed_steps and memory.is_full():
                        self._step(memory, optimizers, env, shared_model,
                                   self.global_count.value(), writer=writer)
                        self.global_writer_loss_count.increment()

                    next_state, reward, quality = env.execute_action(action)
                    memory.push(state, action, reward, next_state, env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = next_state

            self.global_count.increment()
            if "self_reg" in self.args.eps_rule and quality <= 2:
                break

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
            # pass
            if self.args.model_name_dest != "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name_dest))
                print('saved')
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

    self.cleanup()
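
# --- illustrative sketch, not part of the original trainer code ---------------
# The SAC-style setup above stores `critic_tau` and a target-update frequency,
# but the actual update happens in `_step`, which is not part of this section.
# For reference, the conventional soft (Polyak) target update that such a tau
# usually drives looks like the sketch below; whether `_step` does exactly this
# is an assumption.
import torch


def soft_update_sketch(net: torch.nn.Module, target_net: torch.nn.Module, tau: float):
    """target <- tau * net + (1 - tau) * target, parameter-wise."""
    with torch.no_grad():
        for p, p_target in zip(net.parameters(), target_net.parameters()):
            p_target.mul_(1.0 - tau).add_(tau * p)
# ------------------------------------------------------------------------------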
def fe_extr_warm_start(self, sp_feature_ext, writer=None):
    # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(length=self.args.fe_warmup_iterations * 10), batch_size=10,
    #                      shuffle=True, pin_memory=True)
    dloader = DataLoader(MultiDiscSpGraphDset(length=self.args.fe_warmup_iterations * 10,
                                              less=True, no_suppix=False),
                         batch_size=1, shuffle=True, pin_memory=True)
    contrastive_l = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
    dice = GraphDiceLoss()

    small_lcf = nn.Sequential(
        nn.Linear(sp_feature_ext.n_embedding_channels, 256),
        nn.Linear(256, 512),
        nn.Linear(512, 1024),
        nn.Linear(1024, 256),
        nn.Linear(256, 1),
    )
    small_lcf.cuda(device=sp_feature_ext.device)
    # note: only the feature extractor's parameters are optimized here; the
    # small_lcf head is not added to the optimizer
    optimizer = torch.optim.Adam(sp_feature_ext.parameters(), lr=1e-3)

    for i, (data, node_labeling, gt_pix, gt_edges, edge_index) in enumerate(dloader):
        data, node_labeling, gt_pix, gt_edges, edge_index = \
            data.to(sp_feature_ext.device), \
            node_labeling.squeeze().to(sp_feature_ext.device), \
            gt_pix.to(sp_feature_ext.device), \
            gt_edges.squeeze().to(sp_feature_ext.device), \
            edge_index.squeeze().to(sp_feature_ext.device)
        node_labeling = node_labeling.squeeze()

        # per-superpixel boolean masks and the pixel coordinates they contain
        stacked_superpixels = [node_labeling == n for n in node_labeling.unique()]
        sp_indices = [sp.nonzero() for sp in stacked_superpixels]

        edge_features, pred_embeddings, side_loss = sp_feature_ext(
            data, edge_index, torch.zeros_like(gt_edges, dtype=torch.float), sp_indices)
        pred_edge_weights = small_lcf(edge_features)

        # optional L2 penalty over all feature-extractor parameters
        l2_reg = None
        if self.args.l2_reg_params_weight != 0:
            for W in list(sp_feature_ext.parameters()):
                if l2_reg is None:
                    l2_reg = W.norm(2)
                else:
                    l2_reg = l2_reg + W.norm(2)
        if l2_reg is None:
            l2_reg = 0

        loss_pix = contrastive_l(pred_embeddings.unsqueeze(0), gt_pix)
        loss_edge = dice(pred_edge_weights.squeeze(), gt_edges.squeeze())
        loss = loss_pix + self.args.weight_edge_loss * loss_edge + \
               self.args.weight_side_loss * side_loss + l2_reg * self.args.l2_reg_params_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if writer is not None:
            writer.add_scalar("loss/fe_warm_start/ttl", loss.item(), self.writer_idx_warmup_loss)
            writer.add_scalar("loss/fe_warm_start/pix_embeddings", loss_pix.item(), self.writer_idx_warmup_loss)
            writer.add_scalar("loss/fe_warm_start/edge_embeddings", loss_edge.item(), self.writer_idx_warmup_loss)
            writer.add_scalar("loss/fe_warm_start/gcn_sideloss", side_loss.item(), self.writer_idx_warmup_loss)
            self.writer_idx_warmup_loss += 1
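
# --- illustrative sketch, not part of the original trainer code ---------------
# Tiny worked example of the superpixel bookkeeping done in the warm start
# above: for each label in a node labeling, a boolean mask is built and
# `nonzero()` turns it into the pixel coordinates belonging to that superpixel.
# The 2x2 labeling below is made up for illustration.
import torch

node_labeling = torch.tensor([[0, 0],
                              [1, 2]])
stacked_superpixels = [node_labeling == n for n in node_labeling.unique()]
sp_indices = [sp.nonzero() for sp in stacked_superpixels]
# sp_indices[0] -> coordinates of label 0: tensor([[0, 0], [0, 1]])
# sp_indices[1] -> coordinates of label 1: tensor([[1, 0]])
# sp_indices[2] -> coordinates of label 2: tensor([[1, 1]])
# ------------------------------------------------------------------------------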