def __init__(self, n_raw_channels, n_embedding_channels, n_edge_features_in, n_edge_classes,
             device, softmax=True, writer=None):
    super(GcnEdgeAngle1dPQV, self).__init__()
    self.writer = writer
    self.fe_ext = SpVecsUnet(n_raw_channels, n_embedding_channels, device)
    n_embedding_channels += 1
    self.softmax = softmax
    self.node_conv1 = NodeConv(n_embedding_channels, n_embedding_channels, n_hidden_layer=5)
    self.edge_conv1 = EdgeConv(n_embedding_channels, n_embedding_channels, 3 * n_embedding_channels,
                               n_hidden_layer=5)
    self.node_conv2 = NodeConv(n_embedding_channels, n_embedding_channels, n_hidden_layer=5)
    self.edge_conv2 = EdgeConv(n_embedding_channels, n_embedding_channels, 3 * n_embedding_channels,
                               use_init_edge_feats=True, n_init_edge_channels=3 * n_embedding_channels,
                               n_hidden_layer=5)
    # self.lstm = nn.LSTMCell(n_embedding_channels + n_edge_features_in + 1, hidden_size)

    # Policy (p) and Q-value (q) heads. Note: these are stacked nn.Linear layers with no
    # nonlinearities in between, so each head is effectively a single affine map.
    self.out_p = nn.Sequential(
        nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256),
        nn.Linear(256, 512),
        nn.Linear(512, 1024),
        nn.Linear(1024, 256),
        nn.Linear(256, n_edge_classes),
    )
    self.out_q = nn.Sequential(
        nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256),
        nn.Linear(256, 512),
        nn.Linear(512, 1024),
        nn.Linear(1024, 256),
        nn.Linear(256, n_edge_classes),
    )
    self.device = device
    self.writer_counter = 0
def __init__(self, n_raw_channels, n_embedding_channels, n_edge_features_in, n_edge_classes,
             exp_steps, p_sigma, device, density_eval_range, writer):
    super(GcnEdgeAngle1dPQA_dueling_1, self).__init__()
    self.writer = writer
    self.fe_ext = SpVecsUnet(n_raw_channels, n_embedding_channels, device)
    n_embedding_channels += 1
    self.p_sigma = p_sigma
    self.density_eval_range = density_eval_range
    self.exp_steps = exp_steps
    self.edge_conv1 = EdgeConv(n_embedding_channels, n_embedding_channels, n_embedding_channels,
                               n_hidden_layer=5)
    self.out_p1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
    self.out_p2 = nn.Linear(256, n_edge_classes)
    self.out_v1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
    self.out_v2 = nn.Linear(256, n_edge_classes)
    self.out_a1 = nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256)
    self.out_a2 = nn.Linear(256, n_edge_classes)
    self.device = device
    self.writer_counter = 0
def __init__(self, n_raw_channels, n_embedding_channels, n_edge_features_in, n_edge_classes,
             exp_steps, p_sigma, device, density_eval_range):
    super(GcnEdgeAngle1dPQA_dueling, self).__init__()
    self.fe_ext = SpVecsUnet(n_raw_channels, n_embedding_channels, device)
    n_embedding_channels += 1
    self.p_sigma = p_sigma
    self.density_eval_range = density_eval_range
    self.exp_steps = exp_steps
    self.node_conv1 = NodeConv(n_embedding_channels, n_embedding_channels, n_hidden_layer=5)
    self.edge_conv1 = EdgeConv(n_embedding_channels, n_embedding_channels, n_embedding_channels,
                               n_hidden_layer=5)
    self.node_conv2 = NodeConv(n_embedding_channels, n_embedding_channels, n_hidden_layer=5)
    self.edge_conv2 = EdgeConv(n_embedding_channels, n_embedding_channels, n_embedding_channels,
                               use_init_edge_feats=True, n_init_edge_channels=n_embedding_channels,
                               n_hidden_layer=5)
    # self.lstm = nn.LSTMCell(n_embedding_channels + n_edge_features_in + 1, hidden_size)
    self.out_p1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
    self.out_p2 = nn.Linear(256, n_edge_classes)
    self.out_v1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
    self.out_v2 = nn.Linear(256, n_edge_classes)
    self.out_a1 = nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256)
    self.out_a2 = nn.Linear(256, n_edge_classes)
    self.device = device
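# Hedged sketch (not part of the original code): both dueling networks above expose separate
# value (out_v*) and advantage (out_a*) heads. A standard dueling aggregation, as in Wang et
# al. 2016, combines them as Q = V + A - mean(A); the helper below is an illustrative
# assumption about how such heads could be merged, not the project's own forward pass.
import torch


def dueling_q(value, advantage):
    """Combine a state value and per-action advantages into Q-values."""
    return value + advantage - advantage.mean(dim=-1, keepdim=True)


# Example: one state, three candidate actions; the mean-subtraction keeps V identifiable.
v = torch.tensor([[0.5]])
a = torch.tensor([[0.1, -0.2, 0.4]])
print(dueling_q(v, a))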
def train(self, rank, start_time, return_dict):
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    self.setup(rank, self.args.num_processes)
    transition = namedtuple('Transition',
                            ('state', 'action', 'reward', 'state_', 'behav_policy_proba', 'time', 'terminal'))
    memory = TransitionData(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count)
    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1, shuffle=True,
                         pin_memory=True, num_workers=0)

    # Create shared network
    model = GcnEdgeAngle1dPQV(self.args.n_raw_channels, self.args.n_embedding_features,
                              self.args.n_edge_features, self.args.n_actions, device)
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device])

    # Create optimizer for shared network parameters with shared statistics
    optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
                         weight_decay=self.args.Adam_weight_decay)

    if self.args.fe_extr_warmup and rank == 0:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(fe_extr.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()

    if self.args.model_name != "":
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))
    self.shared_damped_model.load_state_dict(shared_model.state_dict())

    env.done = True  # Start new episode
    while self.global_count.value() <= self.args.T_max:
        if env.done:
            edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                next(iter(dloader))
            edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                edges.squeeze().to(device), \
                edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(device), \
                diff_to_gt.squeeze().to(device), \
                gt_edge_weights.squeeze().to(device), \
                node_labeling.squeeze().to(device), \
                raw.squeeze().to(device), \
                nodes.squeeze().to(device), \
                angles.squeeze().to(device), \
                affinities.squeeze().numpy(), \
                gt.squeeze()
            env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes,
                            angles, affinities, gt)
            env.reset()
            state = [env.state[0].clone(), env.state[1].clone()]

            episode_length = 0
            self.eps = self.eps_rule.apply(self.global_count.value())
            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value())
            if writer is not None:
                writer.add_scalar("step/epsilon", self.eps, env.writer_counter.value())

        while not env.done:
            # Calculate policy and values
            policy_proba, q, v = self.agent_forward(env, shared_model, grad=False)
            # average_policy_proba, _, _ = self.agent_forward(env, self.shared_average_model)
            # q_ret = v.detach()

            # Sample action
            # action = torch.multinomial(policy, 1)[0, 0]

            # Step
            action, behav_policy_proba = self.get_action(policy_proba, q, v, policy='off_uniform',
                                                         device=device)
            state_, reward = env.execute_action(action, self.global_count.value())

            memory.push(state, action, reward.to(shared_model.module.device), state_, behav_policy_proba,
                        episode_length, env.done)

            # Train the network
            self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

            # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
            # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
            episode_length += 1  # Increase episode counter
            state = state_

        # Break graph for last values calculated (used for targets, not directly as model outputs)
        self.global_count.increment()

        # Qret = 0 for terminal s
        while len(memory) > 0:
            self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
            memory.pop(0)

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp:
            if self.args.model_name != "":
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            test_score = 0
            env.writer = None
            for i in range(20):
                self.update_env_data(env, dloader, device)
                env.reset()
                self.eps = 0
                while not env.done:
                    # Calculate policy and values
                    policy_proba, q, v = self.agent_forward(env, shared_model, grad=False)
                    action, behav_policy_proba = self.get_action(policy_proba, q, v, policy='off_uniform',
                                                                 device=device)
                    _, _ = env.execute_action(action, self.global_count.value())
                if env.win:
                    test_score += 1
            return_dict['test_score'] = test_score
            writer.add_text("time_needed", str((time.time() - start_time)))
    self.cleanup()
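# Hedged sketch (illustrative, not the project's _step implementation): the commented
# "q_ret" / "Qret = 0 for terminal s" hints in the loop above refer to ACER-style retrace
# targets (Wang et al. 2017). Assuming per-step rewards, Q-values, state values and
# importance ratios for one episode, the target is built backwards from the terminal state.
import torch


def retrace_targets(rewards, q_vals, values, rhos, gamma=0.99):
    """Backward recursion for ACER retrace targets; all inputs are 1-D tensors over one episode."""
    q_ret = torch.zeros(())              # terminal state bootstraps with zero
    targets = torch.zeros_like(rewards)
    for t in reversed(range(len(rewards))):
        q_ret = rewards[t] + gamma * q_ret
        targets[t] = q_ret
        # truncated importance weight corrects the off-policy bias before the next step back
        q_ret = torch.clamp(rhos[t], max=1.0) * (q_ret - q_vals[t]) + values[t]
    return targets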
def train(self, rank, start_time, return_dict):
    device = torch.device("cuda:" + str(rank))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    self.setup(rank, self.args.num_processes)
    transition = namedtuple('Transition', ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
    memory = TransitionData(capacity=self.args.t_max, storage_object=transition)

    env = SpGcnEnv(self.args, device, writer=writer,
                   writer_counter=self.global_writer_quality_count,
                   win_event_counter=self.global_win_event_count,
                   discrete_action_space=False)

    # Create shared network
    model = GcnEdgeAngle1dPQA_dueling_1(self.args.n_raw_channels, self.args.n_embedding_features,
                                        self.args.n_edge_features, 1, self.args.exp_steps,
                                        self.args.p_sigma, device, self.args.density_eval_range,
                                        writer=writer)
    if self.args.no_fe_extr_optim:
        for param in model.fe_ext.parameters():
            param.requires_grad = False
    model.cuda(device)
    shared_model = DDP(model, device_ids=[model.device])

    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1, shuffle=True,
                         pin_memory=True, num_workers=0)

    # Create optimizer for shared network parameters with shared statistics
    optimizer = torch.optim.Adam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
                                 weight_decay=self.args.Adam_weight_decay)

    if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    dist.barrier()

    if self.args.model_name != "":
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        shared_model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))
    self.shared_average_model.load_state_dict(shared_model.state_dict())

    if not self.args.test_score_only:
        quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
        while self.global_count.value() <= self.args.T_max:
            if self.global_count.value() == 190:
                a = 1  # debugging hook (no-op)
            self.update_env_data(env, dloader, device)
            env.reset()
            state = [env.state[0].clone(), env.state[1].clone()]

            self.b_sigma = self.b_sigma_rule.apply(self.global_count.value(), quality)
            env.stop_quality = self.stop_qual_rule.apply(self.global_count.value(), quality)

            # Allow selected hyperparameters to be changed at runtime via the config file.
            with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'eps' in args_dict:
                        if self.args.eps != args_dict['eps']:
                            self.eps = args_dict['eps']
                    if 'safe_model' in args_dict:
                        self.args.safe_model = args_dict['safe_model']
                    if 'add_noise' in args_dict:
                        self.args.add_noise = args_dict['add_noise']

            if writer is not None:
                writer.add_scalar("step/b_variance", self.b_sigma, env.writer_counter.value())

            if self.args.safe_model:
                if rank == 0:
                    if self.args.model_name_dest != "":
                        torch.save(shared_model.state_dict(),
                                   os.path.join(self.save_dir, self.args.model_name_dest))
                    else:
                        torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

            while not env.done:
                post_input = True if self.global_count.value() % 50 and env.counter == 0 else False

                # Calculate policy and values
                policy_means, p_dis = self.agent_forward(env, shared_model, grad=False, stats_only=True,
                                                         post_input=post_input)

                # Step
                action, b_rvs = self.get_action(policy_means, p_dis, device)
                state_, reward, quality = env.execute_action(action)

                if self.args.add_noise:
                    if self.global_count.value() > 110 and self.global_count.value() % 5:
                        noise = torch.randn_like(reward) * 0.8
                        reward = reward + noise

                memory.push(state, action, reward, b_rvs, env.done)

                # Train the network
                # self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                state = state_

            # Break graph for last values calculated (used for targets, not directly as model outputs)
            self.global_count.increment()
            self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
            memory.clear()
            # while len(memory) > 0:
            #     self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
            #     memory.pop(0)

    dist.barrier()
    if rank == 0:
        if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
            # pass
            if self.args.model_name != "":
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, self.args.model_name))
                print('saved')
            else:
                torch.save(shared_model.state_dict(), os.path.join(self.save_dir, 'agent_model'))

        if self.args.cross_validate_hp or self.args.test_score_only:
            test_score = 0
            env.writer = None
            for i in range(20):
                self.update_env_data(env, dloader, device)
                env.reset()
                self.b_sigma = self.args.p_sigma
                env.stop_quality = 40
                while not env.done:
                    # Calculate policy and values
                    policy_means, p_dis = self.agent_forward(env, shared_model, grad=False,
                                                             stats_only=True)
                    action, b_rvs = self.get_action(policy_means, p_dis, device)
                    _, _ = env.execute_action(action, self.global_count.value())
                    # import matplotlib.pyplot as plt
                    # plt.imshow(env.get_current_soln())
                    # plt.show()
                if env.win:
                    test_score += 1
            return_dict['test_score'] = test_score
            writer.add_text("time_needed", str((time.time() - start_time)))
    self.cleanup()
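# Hedged sketch (standalone illustration of the pattern used above, not the project's code):
# the training loop re-reads config.yaml every episode so selected hyperparameters
# (eps, safe_model, add_noise) can be adjusted while training is running. A minimal version
# of that hot-reload could look like this; the function name and keys are assumptions.
import yaml


def reload_overrides(path, current):
    """Merge overrides from a YAML file into a dict of current hyperparameters."""
    with open(path) as f:
        overrides = yaml.full_load(f)
    if overrides:
        for key in ('eps', 'safe_model', 'add_noise'):
            if key in overrides:
                current[key] = overrides[key]
    return current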
class GcnEdgeAC(torch.nn.Module):
    def __init__(self, cfg, device, writer=None):
        super(GcnEdgeAC, self).__init__()
        self.writer = writer
        self.cfg = cfg
        self.log_std_bounds = self.cfg.sac.diag_gaussian_actor.log_std_bounds
        self.device = device
        self.writer_counter = 0

        self.fe_ext = SpVecsUnet(self.cfg.fe.n_raw_channels, self.cfg.fe.n_embedding_features, device,
                                 writer)
        self.actor = PolicyNet(self.cfg.fe.n_embedding_features, 2, cfg.model.n_hidden,
                               cfg.model.hl_factor, device, writer)
        self.critic = DoubleQValueNet(self.cfg.sac.s_subgraph, self.cfg.fe.n_embedding_features, 1,
                                      cfg.model.n_hidden, cfg.model.hl_factor, device, writer)
        self.critic_tgt = DoubleQValueNet(self.cfg.sac.s_subgraph, self.cfg.fe.n_embedding_features, 1,
                                          cfg.model.n_hidden, cfg.model.hl_factor, device, writer)

        # One learnable SAC entropy temperature per subgraph size.
        self.log_alpha = torch.tensor([np.log(self.cfg.sac.init_temperature)]
                                      * len(self.cfg.sac.s_subgraph)).to(device)
        self.log_alpha.requires_grad = True

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @alpha.setter
    def alpha(self, value):
        self.log_alpha = torch.tensor(np.log(value)).to(self.device)
        self.log_alpha.requires_grad = True

    def forward(self, raw, sp_seg, gt_edges=None, sp_indices=None, edge_index=None, angles=None,
                round_n=None, sub_graphs=None, sep_subgraphs=None, actions=None, post_input=False,
                policy_opt=False, embeddings_opt=False):
        if sp_indices is None:
            return self.fe_ext(raw, post_input)

        with torch.set_grad_enabled(embeddings_opt):
            embeddings = self.fe_ext(raw, post_input)

        node_feats = []
        for i, sp_ind in enumerate(sp_indices):
            n_f = self.fe_ext.get_node_features(embeddings[i], sp_ind)
            node_feats.append(n_f)
        node_features = torch.cat(node_feats, dim=0)

        # gcnn expects two directed edges for one undirected edge
        edge_index = torch.cat([edge_index, torch.stack([edge_index[1], edge_index[0]], dim=0)], dim=1)

        if actions is None:
            with torch.set_grad_enabled(policy_opt):
                out, side_loss = self.actor(node_features, edge_index, angles, gt_edges, post_input)
                mu, log_std = out.chunk(2, dim=-1)
                mu, log_std = mu.squeeze(), log_std.squeeze()

                if post_input and self.writer is not None:
                    self.writer.add_histogram("hist_logits/loc", mu.view(-1).detach().cpu().numpy(),
                                              self.writer_counter)
                    self.writer.add_histogram("hist_logits/scale", log_std.view(-1).detach().cpu().numpy(),
                                              self.writer_counter)
                    self.writer_counter += 1

                # constrain log_std inside [log_std_min, log_std_max]
                log_std = torch.tanh(log_std)
                log_std_min, log_std_max = self.log_std_bounds
                log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
                std = log_std.exp()

                dist = SigmNorm(mu, std)
                actions = dist.rsample()

            q1, q2, sl = self.critic_tgt(node_features, actions, edge_index, angles, sub_graphs,
                                         sep_subgraphs, gt_edges, post_input)
            side_loss = (side_loss + sl) / 2
            if policy_opt:
                return dist, q1, q2, actions, side_loss
            else:
                # this means either exploration, critic opt or embedding opt
                return dist, q1, q2, actions, embeddings, side_loss

        q1, q2, side_loss = self.critic(node_features, actions, edge_index, angles, sub_graphs,
                                        sep_subgraphs, gt_edges, post_input)
        return q1, q2, side_loss
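# Hedged sketch (not part of the original training code): the per-subgraph log_alpha above is
# the learnable SAC entropy temperature. In a standard SAC update it scales the policy's
# log-probability in the actor loss and is itself trained towards a target entropy; the helper
# below illustrates that pattern for scalar inputs and is an assumption, not the project's API.
import torch


def sac_alpha_and_actor_losses(log_alpha, log_prob, q_min, target_entropy):
    """Illustrative SAC losses; log_prob and q_min are per-sample tensors."""
    alpha = log_alpha.exp()
    # Actor maximises Q while keeping entropy high (alpha treated as a constant here).
    actor_loss = (alpha.detach() * log_prob - q_min).mean()
    # Temperature is pushed so that the policy entropy approaches the target entropy.
    alpha_loss = (alpha * (-log_prob - target_entropy).detach()).mean()
    return actor_loss, alpha_loss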
def train(self):
    step_counter = 0
    device = torch.device("cuda:" + str(0))
    print('Running on device: ', device)
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.FloatTensor)

    writer = None
    if not self.args.cross_validate_hp:
        writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
        # posting parameters
        param_string = ""
        for k, v in vars(self.args).items():
            param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
        writer.add_text("params", param_string)

    # Create shared network
    model = GcnEdgeAngle1dQ(self.args.n_raw_channels, self.args.n_embedding_features,
                            self.args.n_edge_features, 1, device, writer=writer)
    if self.args.no_fe_extr_optim:
        for param in model.fe_ext.parameters():
            param.requires_grad = False
    model.cuda(device)

    dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False), batch_size=1, shuffle=True,
                         pin_memory=True, num_workers=0)
    optimizer = Adam(model.parameters(), lr=self.args.lr)
    loss = GraphDiceLoss()

    if self.args.fe_extr_warmup and not self.args.test_score_only:
        fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
        fe_extr.cuda(device)
        self.fe_extr_warm_start(fe_extr, writer=writer)
        model.fe_ext.load_state_dict(fe_extr.state_dict())
        if self.args.model_name == "":
            torch.save(model.state_dict(), os.path.join(self.save_dir, 'agent_model'))
        else:
            torch.save(model.state_dict(), os.path.join(self.save_dir, self.args.model_name))

    if self.args.model_name != "":
        model.load_state_dict(torch.load(os.path.join(self.save_dir, self.args.model_name)))
    elif self.args.fe_extr_warmup:
        print('loaded fe extractor')
        model.load_state_dict(torch.load(os.path.join(self.save_dir, 'agent_model')))

    while step_counter <= self.args.T_max:
        if step_counter == 78:
            a = 1  # debugging hook (no-op)
        if (step_counter + 1) % 1000 == 0:
            post_input = True
        else:
            post_input = False

        # Allow the learning rate to be changed at runtime via the config file.
        with open(os.path.join(self.save_dir, 'config.yaml')) as info:
            args_dict = yaml.full_load(info)
            if args_dict is not None:
                if 'lr' in args_dict:
                    self.args.lr = args_dict['lr']
                    adjust_learning_rate(optimizer, self.args.lr)

        round_n = 0
        raw, gt, sp_seg, sp_indices, edge_ids, edge_weights, gt_edges, edge_features = \
            self._get_data(dloader, device)
        inp = [obj.float().to(model.device) for obj in [edge_weights, sp_seg, raw + gt, sp_seg]]

        pred, side_loss = model(inp, sp_indices=sp_indices, edge_index=edge_ids.to(model.device),
                                angles=None, edge_features_1d=edge_features.to(model.device),
                                round_n=round_n, post_input=post_input)
        pred = pred.squeeze()

        loss_val = loss(pred, gt_edges.to(device))
        ttl_loss = loss_val + side_loss
        quality = (pred - gt_edges.to(device)).abs().sum()

        optimizer.zero_grad()
        ttl_loss.backward()
        optimizer.step()

        if writer is not None:
            writer.add_scalar("step/lr", self.args.lr, step_counter)
            writer.add_scalar("step/dice_loss", loss_val.item(), step_counter)
            writer.add_scalar("step/side_loss", side_loss.item(), step_counter)
            writer.add_scalar("step/quality", quality.item(), step_counter)
        step_counter += 1
    a = 1  # debugging hook (no-op)
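# Hedged sketch (assumption, not the project's GraphDiceLoss): a soft Dice loss over edge
# predictions measures overlap between predicted edge weights and binary ground-truth edge
# labels, which is the quantity the supervised loop above optimises.
import torch


def soft_dice_loss(pred, target, eps=1e-6):
    """1 - Dice coefficient between predicted edge weights and ground-truth edge labels."""
    pred, target = pred.view(-1), target.view(-1)
    intersection = (pred * target).sum()
    return 1.0 - (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)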
class GcnEdgeAC(torch.nn.Module):
    def __init__(self, cfg, args, device, writer=None):
        super(GcnEdgeAC, self).__init__()
        self.writer = writer
        self.args = args
        self.cfg = cfg
        self.log_std_bounds = cfg.diag_gaussian_actor.log_std_bounds
        self.device = device
        self.writer_counter = 0

        n_q_vals = args.s_subgraph
        if "sg_rew" in args.algorithm:
            n_q_vals = 1

        self.fe_ext = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device, writer)
        self.actor = PolicyNet(self.args.n_embedding_features * self.args.s_subgraph,
                               self.args.s_subgraph * 2, args, device, writer)
        self.critic = DoubleQValueNet((1 + self.args.n_embedding_features) * self.args.s_subgraph,
                                      n_q_vals, args, device, writer)
        self.critic_tgt = DoubleQValueNet((1 + self.args.n_embedding_features) * self.args.s_subgraph,
                                          n_q_vals, args, device, writer)

    def forward(self, raw, gt_edges=None, sp_indices=None, edge_index=None, angles=None, round_n=None,
                sub_graphs=None, actions=None, post_input=False, policy_opt=False):
        if sp_indices is None:
            return self.fe_ext(raw)

        embeddings = self.fe_ext(raw)
        node_features = []
        for i, sp_ind in enumerate(sp_indices):
            post_inp = False
            if post_input and i == 0:
                post_inp = True
            node_features.append(self.fe_ext.get_node_features(raw[i].squeeze(), embeddings[i].squeeze(),
                                                               sp_ind, post_input=post_inp))

        # create one large unconnected graph where each connected component corresponds to one image
        node_features = torch.cat(node_features, dim=0)
        node_features = torch.cat(
            [node_features,
             torch.ones([node_features.shape[0], 1], device=node_features.device) * round_n], -1)

        # gcnn expects two directed edges for one undirected edge
        edge_index = torch.cat([edge_index, torch.stack([edge_index[1], edge_index[0]], dim=0)], dim=1)

        if actions is None:
            with torch.set_grad_enabled(policy_opt):
                out = self.actor(node_features, edge_index, angles, sub_graphs, gt_edges, post_input)
                mu, log_std = out.chunk(2, dim=-1)
                mu, log_std = mu.squeeze(), log_std.squeeze()

                if post_input and self.writer is not None:
                    self.writer.add_scalar("mean_logits/loc", mu.mean().item(), self.writer_counter)
                    self.writer.add_scalar("mean_logits/scale", log_std.mean().item(), self.writer_counter)
                    self.writer_counter += 1

                # constrain log_std inside [log_std_min, log_std_max]
                log_std = torch.tanh(log_std)
                log_std_min, log_std_max = self.log_std_bounds
                log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
                std = log_std.exp()

                # dist = TruncNorm(mu, std, 0, 1, 0.005)
                dist = SigmNorm(mu, std)
                actions = dist.rsample()

            q1, q2 = self.critic_tgt(node_features, actions, edge_index, angles, sub_graphs, gt_edges,
                                     post_input)
            return dist, q1, q2, actions

        q1, q2 = self.critic(node_features, actions, edge_index, angles, sub_graphs, gt_edges, post_input)
        return q1, q2
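# Hedged sketch (illustrative helper, not defined in the original code): both GcnEdgeAC variants
# squash the raw log-std output into the configured [log_std_min, log_std_max] interval before
# exponentiating, which keeps the policy's standard deviation bounded and differentiable. The
# default bounds below are example values, not the project's configuration.
import torch


def bounded_std(log_std_raw, log_std_min=-10.0, log_std_max=2.0):
    """Map an unbounded log-std prediction smoothly into [log_std_min, log_std_max], return std."""
    log_std = torch.tanh(log_std_raw)                                            # (-1, 1)
    log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)    # (min, max)
    return log_std.exp()


# Example: extreme raw predictions stay inside the configured range.
print(bounded_std(torch.tensor([-100.0, 0.0, 100.0])))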