def update_rt_vars(self, critic_optimizer, actor_optimizer):
    with portalocker.Lock(os.path.join(self.log_dir, 'runtime_cfg.yaml'), 'rb+', timeout=60) as fh:
        with open(os.path.join(self.log_dir, 'runtime_cfg.yaml')) as info:
            args_dict = yaml.full_load(info)
        if args_dict is not None:
            if 'safe_model' in args_dict:
                self.cfg.rt_vars.safe_model = args_dict['safe_model']
                args_dict['safe_model'] = False
            if 'add_noise' in args_dict:
                self.cfg.rt_vars.add_noise = args_dict['add_noise']
            if 'critic_lr' in args_dict and args_dict['critic_lr'] != self.cfg.sac.critic_lr:
                self.cfg.sac.critic_lr = args_dict['critic_lr']
                adjust_learning_rate(critic_optimizer, self.cfg.sac.critic_lr)
            if 'actor_lr' in args_dict and args_dict['actor_lr'] != self.cfg.sac.actor_lr:
                self.cfg.sac.actor_lr = args_dict['actor_lr']
                adjust_learning_rate(actor_optimizer, self.cfg.sac.actor_lr)
            with open(os.path.join(self.log_dir, 'runtime_cfg.yaml'), "w") as info:
                yaml.dump(args_dict, info)
        # flush and sync to filesystem
        fh.flush()
        os.fsync(fh.fileno())
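# The `adjust_learning_rate` helper called above is not defined in this section. A minimal
# sketch follows, assuming it simply overwrites the learning rate of every parameter group
# of a torch optimizer; this is an illustrative assumption, not necessarily the project's
# actual implementation.
def adjust_learning_rate(optimizer, new_lr):
    """Set `new_lr` on all parameter groups of `optimizer`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr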
def _update_networks(self, loss, optimizer, shared_model, writer=None):
    # Zero shared and local grads
    optimizer.zero_grad()
    # Calculate gradients for gradient descent on loss functions.
    # Note that math comments follow the paper, which is formulated for gradient ascent.
    loss.backward()
    # Gradient L2 normalisation
    nn.utils.clip_grad_norm_(shared_model.parameters(), self.args.max_gradient_norm)
    optimizer.step()

    # Pick up learning-rate changes made to runtime_cfg.yaml while training runs.
    with open(os.path.join(self.save_dir, 'runtime_cfg.yaml')) as info:
        args_dict = yaml.full_load(info)
    if args_dict is not None:
        if 'lr' in args_dict:
            if self.args.lr != args_dict['lr']:
                print("lr changed from ", self.args.lr, " to ", args_dict['lr'],
                      " at loss step ", self.global_writer_loss_count.value())
            self.args.lr = args_dict['lr']

    new_lr = self.args.lr
    if self.args.min_lr != 0 and self.eps <= 0.6:
        # Linearly decay learning rate (earlier variant, kept for reference):
        # new_lr = self.args.lr - ((self.args.lr - self.args.min_lr) * (1 - (self.eps * 2)))
        # (1 - max((self.args.T_max - self.global_count.value()) / self.args.T_max, 1e-32)))
        # Active schedule: decay exponentially with the exploration epsilon.
        new_lr = self.args.lr * 10 ** (-(0.6 - self.eps))
        adjust_learning_rate(optimizer, new_lr)
    if writer is not None:
        writer.add_scalar("loss/learning_rate", new_lr, self.global_writer_loss_count.value())
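# Illustrative sketch of how `runtime_cfg.yaml` could be seeded before training, so that the
# polling code above and in `update_rt_vars`/`train_step` finds the keys it looks for
# ('lr', 'safe_model', 'add_noise', 'critic_lr', 'actor_lr', 'alpha_lr'). The helper name and
# concrete values are placeholders, not the project's defaults.
def seed_runtime_cfg(log_dir, lr=1e-4, critic_lr=1e-4, actor_lr=1e-4, alpha_lr=1e-4):
    cfg = {
        'lr': lr,
        'critic_lr': critic_lr,
        'actor_lr': actor_lr,
        'alpha_lr': alpha_lr,
        'safe_model': False,   # set to True from outside to trigger a checkpoint
        'add_noise': False,    # set to True from outside to add reward noise
    }
    with open(os.path.join(log_dir, 'runtime_cfg.yaml'), 'w') as fh:
        yaml.dump(cfg, fh)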
def _update_networks(self, loss, optimizer, shared_model, writer=None):
    # Zero shared and local grads
    optimizer.zero_grad()
    # Calculate gradients for gradient descent on loss functions.
    # Note that math comments follow the paper, which is formulated for gradient ascent.
    loss.backward()
    # Gradient L2 normalisation
    nn.utils.clip_grad_norm_(shared_model.parameters(), self.args.max_gradient_norm)
    optimizer.step()

    if self.args.min_lr != 0:
        # Linearly decay the learning rate from args.lr to args.min_lr over T_max steps.
        new_lr = self.args.lr - ((self.args.lr - self.args.min_lr) *
                                 (1 - max((self.args.T_max - self.global_count.value()) / self.args.T_max, 1e-32)))
        adjust_learning_rate(optimizer, new_lr)
        if writer is not None:
            writer.add_scalar("loss/learning_rate", new_lr, self.global_writer_loss_count.value())
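# Standalone sketch of the decay schedule above, to make the endpoints explicit (example
# values only, not the project's defaults): at t == 0 the decay factor is 0, so the rate
# stays at `lr`; at t == T_max the factor is ~1, so the rate has decayed to `min_lr`.
def linear_lr(lr, min_lr, T_max, t):
    return lr - (lr - min_lr) * (1 - max((T_max - t) / T_max, 1e-32))

# e.g. linear_lr(1e-4, 1e-6, 10000, 0) == 1e-4 and linear_lr(1e-4, 1e-6, 10000, 10000) ~= 1e-6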
def train_step(self, rank, start_time, return_dict, writer):
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)
    self.setup(rank, self.args.num_processes)

    if self.cfg.MC_DQL:
        transition = namedtuple('Transition', ('episode',))
    else:
        transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    memory = TransitionData_ts(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count)

    # Create shared network
    # model = GcnEdgeAC_1(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)
    model = GcnEdgeAC(self.cfg, self.args, device, writer=writer)
    # model = GcnEdgeAC(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device], find_unused_parameters=True)

    # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(no_suppix=False, create=False), batch_size=1, shuffle=True,
    #                      pin_memory=True, num_workers=0)
    dloader = DataLoader(SpgDset(), batch_size=self.cfg.batch_size, shuffle=True, pin_memory=True, num_workers=0)

    # Create optimizer for shared network parameters with shared statistics
    # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
    #                      weight_decay=self.args.Adam_weight_decay)

    ######################
    self.action_range = 1
    self.device = torch.device(device)
    self.discount = 0.5
    self.critic_tau = self.cfg.critic_tau
    self.actor_update_frequency = self.cfg.actor_update_frequency
    self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
    self.batch_size = self.cfg.batch_size

    self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(self.device)
    self.log_alpha.requires_grad = True
    # set target entropy to -|A|
    ######################

    # optimizers
    OptimizerContainer = namedtuple('OptimizerContainer', ('actor', 'critic', 'temperature'))
    actor_optimizer = torch.optim.Adam(shared_model.module.actor.parameters(),
                                       lr=self.cfg.actor_lr, betas=self.cfg.actor_betas)
    critic_optimizer = torch.optim.Adam(shared_model.module.critic.parameters(),
                                        lr=self.cfg.critic_lr, betas=self.cfg.critic_betas)
    temp_optimizer = torch.optim.Adam([self.log_alpha], lr=self.cfg.alpha_lr, betas=self.cfg.alpha_betas)
    optimizers = OptimizerContainer(actor_optimizer, critic_optimizer, temp_optimizer)

    if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
        fe_extr = shared_model.module.fe_ext
        fe_extr.cuda(device)
        self.fe_extr_warm_start_1(fe_extr, writer=writer)
        # self.fe_extr_warm_start(fe_extr, writer=writer)
        if self.args.model_name == "" and not self.args.no_save:
            torch.save(fe_extr.state_dict(), os.path.join(self.save_dir, 'agent_model_fe_extr'))
        elif not self.args.no_save:
            torch.save(fe_extr.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()

    for param in model.fe_ext.parameters():
        param.requires_grad = False

    if self.args.model_name != "":
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.model_fe_name != "":
        shared_model.module.fe_ext.load_state_dict(
            torch.load(os.path.join(self.save_dir, self.args.model_fe_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        shared_model.module.fe_ext.load_state_dict(
            torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

    if not self.args.test_score_only:
        quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
        best_quality = np.inf
        last_quals = []
        while self.global_count.value() <= self.args.T_max:
            self.update_env_data(env, dloader, device)
            # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
            # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
            waff_dis = torch.softmax(torch.ones_like(env.b_gt_edge_weights), dim=0)
            loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
            env.reset()
            # self.target_entropy = - float(env.gt_edge_weights.shape[0])
            self.target_entropy = -8.0

            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value(), quality)
            if self.cfg.temperature_regulation == 'follow_quality':
                self.alpha = self.eps_rule.apply(self.global_count.value(), quality)
                print(self.alpha.item())

            # Pick up runtime config changes (model saving, noise, learning rates).
            with open(os.path.join(self.save_dir, 'runtime_cfg.yaml')) as info:
                args_dict = yaml.full_load(info)
            if args_dict is not None:
                if 'safe_model' in args_dict:
                    self.args.safe_model = args_dict['safe_model']
                    args_dict['safe_model'] = False
                if 'add_noise' in args_dict:
                    self.args.add_noise = args_dict['add_noise']
                if 'critic_lr' in args_dict and args_dict['critic_lr'] != self.cfg.critic_lr:
                    self.cfg.critic_lr = args_dict['critic_lr']
                    adjust_learning_rate(critic_optimizer, self.cfg.critic_lr)
                if 'actor_lr' in args_dict and args_dict['actor_lr'] != self.cfg.actor_lr:
                    self.cfg.actor_lr = args_dict['actor_lr']
                    adjust_learning_rate(actor_optimizer, self.cfg.actor_lr)
                if 'alpha_lr' in args_dict and args_dict['alpha_lr'] != self.cfg.alpha_lr:
                    self.cfg.alpha_lr = args_dict['alpha_lr']
                    adjust_learning_rate(temp_optimizer, self.cfg.alpha_lr)
                with open(os.path.join(self.save_dir, 'runtime_cfg.yaml'), "w") as info:
                    yaml.dump(args_dict, info)

            if self.args.safe_model:
                best_quality = quality
                if rank == 0:
                    if self.args.model_name_dest != "":
                        torch.save(shared_model.state_dict(),
                                   os.path.join(self.save_dir, self.args.model_name_dest))
                    else:
                        torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

            state = env.get_state()
            while not env.done:
                # Calculate policy and values
                post_input = True if (self.global_count.value() + 1) % 15 == 0 and env.counter == 0 else False
                round_n = env.counter

                # sample action for data collection
                distr = None
                if self.global_count.value() < self.cfg.num_seed_steps:
                    action = torch.rand_like(env.b_current_edge_weights)
                else:
                    distr, _, _, action = self.agent_forward(env, shared_model, state=state,
                                                             grad=False, post_input=post_input)

                logg_dict = {'temperature': self.alpha.item()}
                if distr is not None:
                    logg_dict['mean_loc'] = distr.loc.mean().item()
                    logg_dict['mean_scale'] = distr.scale.mean().item()

                if self.global_count.value() >= self.cfg.num_seed_steps and memory.is_full():
                    self._step(memory, optimizers, env, shared_model, self.global_count.value(), writer=writer)
                    self.global_writer_loss_count.increment()

                next_state, reward, quality = env.execute_action(action, logg_dict)
                last_quals.append(quality)
                if len(last_quals) > 10:
                    last_quals.pop(0)

                if self.args.add_noise:
                    noise = torch.randn_like(reward) * self.alpha.item()
                    reward = reward + noise

                memory.push(self.state_to_cpu(state), action, reward, self.state_to_cpu(next_state), env.done)

                # Train the network
                # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                state = next_state

            self.global_count.increment()

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
            # pass
            if self.args.model_name_dest != "":
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name_dest))
                print('saved')
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

    self.cleanup()
    return sum(last_quals) / 10
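# `train_step` is written as one DDP worker per GPU: `rank` indexes the device and
# `self.setup(rank, self.args.num_processes)` initialises the process group. A minimal
# launch sketch is given below; the `launch` helper, the `start_time` handling and the
# shared `return_dict` are assumptions for illustration, since the actual launcher is not
# part of this section.
import time
import torch.multiprocessing as mp

def launch(trainer, num_processes):
    return_dict = mp.Manager().dict()
    # mp.spawn passes the worker index as the first argument, which becomes `rank`.
    mp.spawn(trainer.train_step, args=(time.time(), return_dict, None), nprocs=num_processes)
    return return_dict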
def train(self):
    step_counter = 0
    device = torch.device("cuda:" + str(0))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    # Create shared network
    model = GcnEdgeAngle1dQ(self.args.n_raw_channels, self.args.n_embedding_features,
                            self.args.n_edge_features, 1, device, writer=writer)
    if self.args.no_fe_extr_optim:
        for param in model.fe_ext.parameters():
            param.requires_grad = False
    model.cuda(device)

    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = Adam(model.parameters(), lr=self.args.lr)
    loss = GraphDiceLoss()

    if self.args.fe_extr_warmup and not self.args.test_score_only:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        model.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    if self.args.model_name != "":
        model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))

    while step_counter <= self.args.T_max:
        post_input = (step_counter + 1) % 1000 == 0

        # Pick up learning-rate changes from config.yaml while training runs.
        with open(os.path.join(self.save_dir, 'config.yaml')) as info:
            args_dict = yaml.full_load(info)
        if args_dict is not None:
            if 'lr' in args_dict:
                self.args.lr = args_dict['lr']
                adjust_learning_rate(optimizer, self.args.lr)

        round_n = 0
        raw, gt, sp_seg, sp_indices, edge_ids, edge_weights, gt_edges, edge_features = \
            self._get_data(dloader, device)
        inp = [obj.float().to(model.device) for obj in [edge_weights, sp_seg, raw + gt, sp_seg]]
        pred, side_loss = model(inp,
                                sp_indices=sp_indices,
                                edge_index=edge_ids.to(model.device),
                                angles=None,
                                edge_features_1d=edge_features.to(model.device),
                                round_n=round_n,
                                post_input=post_input)

        pred = pred.squeeze()
        loss_val = loss(pred, gt_edges.to(device))
        ttl_loss = loss_val + side_loss
        quality = (pred - gt_edges.to(device)).abs().sum()

        optimizer.zero_grad()
        ttl_loss.backward()
        optimizer.step()

        if writer is not None:
            writer.add_scalar("step/lr", self.args.lr, step_counter)
            writer.add_scalar("step/dice_loss", loss_val.item(), step_counter)
            writer.add_scalar("step/side_loss", side_loss.item(), step_counter)
            writer.add_scalar("step/quality", quality.item(), step_counter)
        step_counter += 1
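# `GraphDiceLoss` is not defined in this section. A minimal sketch follows, assuming it is a
# soft Dice loss over the per-edge predictions against the ground-truth edge weights; the
# real implementation may differ (e.g. smoothing constant, reduction).
import torch
import torch.nn as nn

class SoftDiceEdgeLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        # pred, target: 1D tensors of edge scores in [0, 1]
        intersection = (pred * target).sum()
        denom = pred.sum() + target.sum()
        return 1.0 - (2.0 * intersection + self.eps) / (denom + self.eps)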