Example #1
    def __init__(self, cfg, args, device, writer=None):
        super(GcnEdgeAC, self).__init__()
        self.writer = writer
        self.args = args
        self.cfg = cfg
        self.log_std_bounds = cfg.diag_gaussian_actor.log_std_bounds
        self.device = device
        self.writer_counter = 0
        n_q_vals = args.s_subgraph
        if "sg_rew" in args.algorithm:
            n_q_vals = 1

        self.fe_ext = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device,
                                 writer)

        self.actor = PolicyNet(
            self.args.n_embedding_features * self.args.s_subgraph,
            self.args.s_subgraph * 2, args, device, writer)
        self.critic = DoubleQValueNet(
            (1 + self.args.n_embedding_features) * self.args.s_subgraph,
            n_q_vals, args, device, writer)
        self.critic_tgt = DoubleQValueNet(
            (1 + self.args.n_embedding_features) * self.args.s_subgraph,
            n_q_vals, args, device, writer)
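
A minimal instantiation sketch for this constructor; the argument values below are hypothetical placeholders for what the project's CLI parser and config normally provide:

    from types import SimpleNamespace

    import torch

    args = SimpleNamespace(s_subgraph=16, algorithm="sg_rew_sac",
                           n_raw_channels=1, n_embedding_features=16)
    cfg = SimpleNamespace(
        diag_gaussian_actor=SimpleNamespace(log_std_bounds=(-10, 2)))
    model = GcnEdgeAC(cfg, args, device=torch.device("cuda:0"))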
Example #2
    def __init__(self, cfg, device, writer=None):
        super(GcnEdgeAC, self).__init__()
        self.writer = writer
        self.cfg = cfg
        self.log_std_bounds = self.cfg.sac.diag_gaussian_actor.log_std_bounds
        self.device = device
        self.writer_counter = 0

        self.fe_ext = SpVecsUnet(self.cfg.fe.n_raw_channels,
                                 self.cfg.fe.n_embedding_features, device,
                                 writer)

        self.actor = PolicyNet(self.cfg.fe.n_embedding_features, 2,
                               cfg.model.n_hidden, cfg.model.hl_factor, device,
                               writer)
        self.critic = DoubleQValueNet(self.cfg.sac.s_subgraph,
                                      self.cfg.fe.n_embedding_features, 1,
                                      cfg.model.n_hidden, cfg.model.hl_factor,
                                      device, writer)
        self.critic_tgt = DoubleQValueNet(self.cfg.sac.s_subgraph,
                                          self.cfg.fe.n_embedding_features, 1,
                                          cfg.model.n_hidden,
                                          cfg.model.hl_factor, device, writer)

        self.log_alpha = torch.tensor([np.log(self.cfg.sac.init_temperature)] *
                                      len(self.cfg.sac.s_subgraph)).to(device)
        self.log_alpha.requires_grad = True
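
The temperature is kept as a learnable log-parameter with one entry per subgraph size (Example #8 below exposes it as alpha = log_alpha.exp()). A sketch of the standard SAC temperature update such a parameter typically feeds; alpha_optim, log_prob and target_entropy are assumed names, not from this snippet:

    alpha = model.log_alpha.exp()
    alpha_loss = (alpha * (-log_prob - target_entropy).detach()).mean()
    alpha_optim.zero_grad()
    alpha_loss.backward()
    alpha_optim.step()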
Example #3
    def __init__(self, n_raw_channels, n_embedding_channels, n_edge_features_in, n_edge_classes, device, softmax=True,
                 writer=None):
        super(GcnEdgeAngle1dPQV, self).__init__()
        self.writer = writer
        self.fe_ext = SpVecsUnet(n_raw_channels, n_embedding_channels, device)
        n_embedding_channels += 1
        self.softmax = softmax
        self.node_conv1 = NodeConv(n_embedding_channels, n_embedding_channels, n_hidden_layer=5)
        self.edge_conv1 = EdgeConv(n_embedding_channels, n_embedding_channels, 3 * n_embedding_channels, n_hidden_layer=5)
        self.node_conv2 = NodeConv(n_embedding_channels, n_embedding_channels, n_hidden_layer=5)
        self.edge_conv2 = EdgeConv(n_embedding_channels, n_embedding_channels, 3 * n_embedding_channels,
                                   use_init_edge_feats=True, n_init_edge_channels=3 * n_embedding_channels,
                                   n_hidden_layer=5)

        # self.lstm = nn.LSTMCell(n_embedding_channels + n_edge_features_in + 1, hidden_size)

        self.out_p = nn.Sequential(
            nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256),
            nn.Linear(256, 512),
            nn.Linear(512, 1024),
            nn.Linear(1024, 256),
            nn.Linear(256, n_edge_classes),
        )

        self.out_q = nn.Sequential(
            nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256),
            nn.Linear(256, 512),
            nn.Linear(512, 1024),
            nn.Linear(1024, 256),
            nn.Linear(256, n_edge_classes),
        )

        self.device = device
        self.writer_counter = 0
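
Note that out_p and out_q stack nn.Linear layers with no activations between them, so each head collapses to a single affine map. If a nonlinear MLP head is intended, a variant (an assumption, not this repository's code) would interleave activations:

    self.out_p = nn.Sequential(
        nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 1024),
        nn.ReLU(),
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Linear(256, n_edge_classes),
    )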
Example #4
    def __init__(self, n_raw_channels, n_embedding_channels, n_edge_features_in, n_edge_classes, exp_steps, p_sigma,
                 device, density_eval_range, writer):
        super(GcnEdgeAngle1dPQA_dueling_1, self).__init__()
        self.writer = writer
        self.fe_ext = SpVecsUnet(n_raw_channels, n_embedding_channels, device)
        n_embedding_channels += 1
        self.p_sigma = p_sigma
        self.density_eval_range = density_eval_range
        self.exp_steps = exp_steps
        self.edge_conv1 = EdgeConv(n_embedding_channels, n_embedding_channels, n_embedding_channels, n_hidden_layer=5)

        self.out_p1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
        self.out_p2 = nn.Linear(256, n_edge_classes)
        self.out_v1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
        self.out_v2 = nn.Linear(256, n_edge_classes)
        self.out_a1 = nn.Linear(n_embedding_channels + n_edge_features_in + 1, 256)
        self.out_a2 = nn.Linear(256, n_edge_classes)
        self.device = device
        self.writer_counter = 0
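
The class name indicates a dueling head with separate value (out_v*) and advantage (out_a*) streams. The usual way such streams are combined, per standard dueling-network practice (the combination itself is not shown in this snippet):

    # q(s, a) = v(s) + a(s, a) - mean over actions a' of a(s, a')
    q = v + a - a.mean(dim=-1, keepdim=True)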
Example #5
    def __init__(self, n_raw_channels, n_embedding_channels,
                 n_edge_features_in, n_edge_classes, exp_steps, p_sigma,
                 device, density_eval_range):
        super(GcnEdgeAngle1dPQA_dueling, self).__init__()
        self.fe_ext = SpVecsUnet(n_raw_channels, n_embedding_channels, device)
        n_embedding_channels += 1
        self.p_sigma = p_sigma
        self.density_eval_range = density_eval_range
        self.exp_steps = exp_steps
        self.node_conv1 = NodeConv(n_embedding_channels,
                                   n_embedding_channels,
                                   n_hidden_layer=5)
        self.edge_conv1 = EdgeConv(n_embedding_channels,
                                   n_embedding_channels,
                                   n_embedding_channels,
                                   n_hidden_layer=5)
        self.node_conv2 = NodeConv(n_embedding_channels,
                                   n_embedding_channels,
                                   n_hidden_layer=5)
        self.edge_conv2 = EdgeConv(n_embedding_channels,
                                   n_embedding_channels,
                                   n_embedding_channels,
                                   use_init_edge_feats=True,
                                   n_init_edge_channels=n_embedding_channels,
                                   n_hidden_layer=5)

        # self.lstm = nn.LSTMCell(n_embedding_channels + n_edge_features_in + 1, hidden_size)

        self.out_p1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
        self.out_p2 = nn.Linear(256, n_edge_classes)
        self.out_v1 = nn.Linear(n_embedding_channels + n_edge_features_in, 256)
        self.out_v2 = nn.Linear(256, n_edge_classes)
        self.out_a1 = nn.Linear(n_embedding_channels + n_edge_features_in + 1,
                                256)
        self.out_a2 = nn.Linear(256, n_edge_classes)
        self.device = device
Example #6
    def train(self, rank, start_time, return_dict):
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)

        transition = namedtuple('Transition',
                                ('state', 'action', 'reward', 'state_',
                                 'behav_policy_proba', 'time', 'terminal'))
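        # One replay entry stores the states before and after the step, the
        # action and its reward, the behaviour policy's probability (used for
        # the off-policy correction in _step), the in-episode time index and
        # the terminal flag.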
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create shared network
        model = GcnEdgeAngle1dPQV(self.args.n_raw_channels,
                                  self.args.n_embedding_features,
                                  self.args.n_edge_features,
                                  self.args.n_actions, device)
        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        # Create optimizer for shared network parameters with shared statistics
        optimizer = CstmAdam(shared_model.parameters(),
                             lr=self.args.lr,
                             betas=self.args.Adam_betas,
                             weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))
        dist.barrier()
        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_damped_model.load_state_dict(shared_model.state_dict())
        env.done = True  # Start new episode
        while self.global_count.value() <= self.args.T_max:
            if env.done:
                (edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling,
                 raw, nodes, angles, affinities, gt) = next(iter(dloader))
                edges = edges.squeeze().to(device)
                edge_feat = edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(device)
                diff_to_gt = diff_to_gt.squeeze().to(device)
                gt_edge_weights = gt_edge_weights.squeeze().to(device)
                node_labeling = node_labeling.squeeze().to(device)
                raw = raw.squeeze().to(device)
                nodes = nodes.squeeze().to(device)
                angles = angles.squeeze().to(device)
                affinities = affinities.squeeze().numpy()
                gt = gt.squeeze()
                env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights,
                                node_labeling, raw, nodes, angles, affinities,
                                gt)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]
                episode_length = 0

                self.eps = self.eps_rule.apply(self.global_count.value())
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value())
                if writer is not None:
                    writer.add_scalar("step/epsilon", self.eps,
                                      env.writer_counter.value())

            while not env.done:
                # Calculate policy and values
                policy_proba, q, v = self.agent_forward(env,
                                                        shared_model,
                                                        grad=False)
                # average_policy_proba, _, _ = self.agent_forward(env, self.shared_average_model)
                # q_ret = v.detach()

                # Sample action
                # action = torch.multinomial(policy, 1)[0, 0]

                # Step
                action, behav_policy_proba = self.get_action(
                    policy_proba, q, v, policy='off_uniform', device=device)
                state_, reward = env.execute_action(action,
                                                    self.global_count.value())

                memory.push(state, action,
                            reward.to(shared_model.module.device), state_,
                            behav_policy_proba, episode_length, env.done)

                # Train the network
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                episode_length += 1  # Increase episode counter
                state = state_

            # Break graph for last values calculated (used for targets, not directly as model outputs)
            self.global_count.increment()
            # Qret = 0 for terminal s

            while len(memory) > 0:
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp:
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            else:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.eps = 0
                    while not env.done:
                        # Calculate policy and values
                        policy_proba, q, v = self.agent_forward(env,
                                                                shared_model,
                                                                grad=False)
                        action, behav_policy_proba = self.get_action(
                            policy_proba,
                            q,
                            v,
                            policy='off_uniform',
                            device=device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                writer.add_text("time_needed", str((time.time() - start_time)))
        self.cleanup()
Example #7
    def train(self, rank, start_time, return_dict):

        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)
        transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count,
                       discrete_action_space=False)
        # Create shared network
        model = GcnEdgeAngle1dPQA_dueling_1(self.args.n_raw_channels,
                                            self.args.n_embedding_features,
                                            self.args.n_edge_features,
                                            1,
                                            self.args.exp_steps,
                                            self.args.p_sigma,
                                            device,
                                            self.args.density_eval_range,
                                            writer=writer)
        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        optimizer = torch.optim.Adam(shared_model.parameters(),
                                     lr=self.args.lr,
                                     betas=self.args.Adam_betas,
                                     weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_average_model.load_state_dict(shared_model.state_dict())

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            while self.global_count.value() <= self.args.T_max:
                self.update_env_data(env, dloader, device)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]

                self.b_sigma = self.b_sigma_rule.apply(
                    self.global_count.value(), quality)
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)

                # Re-read config.yaml each episode so selected hyperparameters
                # can be tuned while training is running.
                with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'eps' in args_dict:
                            if self.args.eps != args_dict['eps']:
                                self.eps = args_dict['eps']
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']

                if writer is not None:
                    writer.add_scalar("step/b_variance", self.b_sigma,
                                      env.writer_counter.value())

                if self.args.safe_model:
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                while not env.done:
                    # Post extra logging input on every 50th episode, at the
                    # first step of the episode.
                    post_input = (self.global_count.value() % 50 == 0
                                  and env.counter == 0)
                    # Calculate policy and values
                    policy_means, p_dis = self.agent_forward(
                        env,
                        shared_model,
                        grad=False,
                        stats_only=True,
                        post_input=post_input)

                    # Step
                    action, b_rvs = self.get_action(policy_means, p_dis,
                                                    device)
                    state_, reward, quality = env.execute_action(action)

                    if self.args.add_noise:
                        # After episode 110, perturb the reward with Gaussian
                        # noise in four of every five episodes.
                        if (self.global_count.value() > 110
                                and self.global_count.value() % 5 != 0):
                            noise = torch.randn_like(reward) * 0.8
                            reward = reward + noise

                    memory.push(state, action, reward, b_rvs, env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = state_

                # Break graph for last values calculated (used for targets, not directly as model outputs)
                self.global_count.increment()

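                # One gradient update per episode over the entire episode
                # buffer; the commented-out loop below is the older variant
                # that consumed transitions one at a time.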
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.clear()
                # while len(memory) > 0:
                #     self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
                #     memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            if self.args.cross_validate_hp or self.args.test_score_only:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.b_sigma = self.args.p_sigma
                    env.stop_quality = 40
                    while not env.done:
                        # Calculate policy and values
                        policy_means, p_dis = self.agent_forward(
                            env, shared_model, grad=False, stats_only=True)
                        action, b_rvs = self.get_action(
                            policy_means, p_dis, device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())

                    # import matplotlib.pyplot as plt;
                    # plt.imshow(env.get_current_soln());
                    # plt.show()
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                writer.add_text("time_needed", str((time.time() - start_time)))
        self.cleanup()
Example #8
class GcnEdgeAC(torch.nn.Module):
    def __init__(self, cfg, device, writer=None):
        super(GcnEdgeAC, self).__init__()
        self.writer = writer
        self.cfg = cfg
        self.log_std_bounds = self.cfg.sac.diag_gaussian_actor.log_std_bounds
        self.device = device
        self.writer_counter = 0

        self.fe_ext = SpVecsUnet(self.cfg.fe.n_raw_channels,
                                 self.cfg.fe.n_embedding_features, device,
                                 writer)

        self.actor = PolicyNet(self.cfg.fe.n_embedding_features, 2,
                               cfg.model.n_hidden, cfg.model.hl_factor, device,
                               writer)
        self.critic = DoubleQValueNet(self.cfg.sac.s_subgraph,
                                      self.cfg.fe.n_embedding_features, 1,
                                      cfg.model.n_hidden, cfg.model.hl_factor,
                                      device, writer)
        self.critic_tgt = DoubleQValueNet(self.cfg.sac.s_subgraph,
                                          self.cfg.fe.n_embedding_features, 1,
                                          cfg.model.n_hidden,
                                          cfg.model.hl_factor, device, writer)

        self.log_alpha = torch.tensor([np.log(self.cfg.sac.init_temperature)] *
                                      len(self.cfg.sac.s_subgraph)).to(device)
        self.log_alpha.requires_grad = True

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @alpha.setter
    def alpha(self, value):
        self.log_alpha = torch.tensor(np.log(value)).to(self.device)
        self.log_alpha.requires_grad = True

    def forward(self,
                raw,
                sp_seg,
                gt_edges=None,
                sp_indices=None,
                edge_index=None,
                angles=None,
                round_n=None,
                sub_graphs=None,
                sep_subgraphs=None,
                actions=None,
                post_input=False,
                policy_opt=False,
                embeddings_opt=False):

        if sp_indices is None:
            return self.fe_ext(raw, post_input)
        with torch.set_grad_enabled(embeddings_opt):
            embeddings = self.fe_ext(raw, post_input)
        node_feats = []
        for i, sp_ind in enumerate(sp_indices):
            n_f = self.fe_ext.get_node_features(embeddings[i], sp_ind)
            node_feats.append(n_f)

        node_features = torch.cat(node_feats, dim=0)

        edge_index = torch.cat(
            [edge_index,
             torch.stack([edge_index[1], edge_index[0]], dim=0)],
            dim=1)  # gcnn expects two directed edges for one undirected edge

        if actions is None:
            with torch.set_grad_enabled(policy_opt):
                out, side_loss = self.actor(node_features, edge_index, angles,
                                            gt_edges, post_input)
                mu, log_std = out.chunk(2, dim=-1)
                mu, log_std = mu.squeeze(), log_std.squeeze()

                if post_input and self.writer is not None:
                    self.writer.add_histogram(
                        "hist_logits/loc",
                        mu.view(-1).detach().cpu().numpy(),
                        self.writer_counter)
                    self.writer.add_histogram(
                        "hist_logits/scale",
                        log_std.view(-1).detach().cpu().numpy(),
                        self.writer_counter)
                    self.writer_counter += 1

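                # constrain log_std inside [log_std_min, log_std_max]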
                log_std = torch.tanh(log_std)
                log_std_min, log_std_max = self.log_std_bounds
                log_std = log_std_min + 0.5 * (log_std_max -
                                               log_std_min) * (log_std + 1)

                std = log_std.exp()

                dist = SigmNorm(mu, std)
                actions = dist.rsample()

            q1, q2, sl = self.critic_tgt(node_features, actions, edge_index,
                                         angles, sub_graphs, sep_subgraphs,
                                         gt_edges, post_input)
            side_loss = (side_loss + sl) / 2
            if policy_opt:
                return dist, q1, q2, actions, side_loss
            else:
                # this means either exploration, critic opt or embedding opt
                return dist, q1, q2, actions, embeddings, side_loss

        q1, q2, side_loss = self.critic(node_features, actions, edge_index,
                                        angles, sub_graphs, sep_subgraphs,
                                        gt_edges, post_input)
        return q1, q2, side_loss
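
This forward has three call modes: with actions=None and policy_opt=True it returns the policy distribution and target-critic values for the actor update; with actions=None and policy_opt=False it additionally returns the embeddings (exploration, critic or embedding optimisation); with explicit actions it evaluates the online critic only. A sketch of the critic-evaluation call, where all tensors are hypothetical placeholders of matching shapes:

    q1, q2, side_loss = model(raw, sp_seg, sp_indices=sp_indices,
                              edge_index=edge_index, sub_graphs=sub_graphs,
                              sep_subgraphs=sep_subgraphs, actions=actions)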
Example #9
    def train(self):
        step_counter = 0
        device = torch.device("cuda:" + str(0))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        # Create shared network
        model = GcnEdgeAngle1dQ(self.args.n_raw_channels,
                                self.args.n_embedding_features,
                                self.args.n_edge_features,
                                1,
                                device,
                                writer=writer)

        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = Adam(model.parameters(), lr=self.args.lr)
        loss = GraphDiceLoss()

        if self.args.fe_extr_warmup and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            model.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        if self.args.model_name != "":
            model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        while step_counter <= self.args.T_max:
            post_input = (step_counter + 1) % 1000 == 0
            # Re-read config.yaml each step so the learning rate can be tuned
            # while training is running.
            with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'lr' in args_dict:
                        self.args.lr = args_dict['lr']
                        adjust_learning_rate(optimizer, self.args.lr)

            round_n = 0

            raw, gt, sp_seg, sp_indices, edge_ids, edge_weights, gt_edges, edge_features = \
                self._get_data(dloader, device)

            inp = [
                obj.float().to(model.device)
                for obj in [edge_weights, sp_seg, raw + gt, sp_seg]
            ]
            pred, side_loss = model(inp,
                                    sp_indices=sp_indices,
                                    edge_index=edge_ids.to(model.device),
                                    angles=None,
                                    edge_features_1d=edge_features.to(
                                        model.device),
                                    round_n=round_n,
                                    post_input=post_input)

            pred = pred.squeeze()

            loss_val = loss(pred, gt_edges.to(device))

            ttl_loss = loss_val + side_loss
            quality = (pred - gt_edges.to(device)).abs().sum()

            optimizer.zero_grad()
            ttl_loss.backward()
            optimizer.step()

            if writer is not None:
                writer.add_scalar("step/lr", self.args.lr, step_counter)
                writer.add_scalar("step/dice_loss", loss_val.item(),
                                  step_counter)
                writer.add_scalar("step/side_loss", side_loss.item(),
                                  step_counter)
                writer.add_scalar("step/quality", quality.item(), step_counter)

            step_counter += 1

Example #10
class GcnEdgeAC(torch.nn.Module):
    def __init__(self, cfg, args, device, writer=None):
        super(GcnEdgeAC, self).__init__()
        self.writer = writer
        self.args = args
        self.cfg = cfg
        self.log_std_bounds = cfg.diag_gaussian_actor.log_std_bounds
        self.device = device
        self.writer_counter = 0
        n_q_vals = args.s_subgraph
        if "sg_rew" in args.algorithm:
            n_q_vals = 1

        self.fe_ext = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device,
                                 writer)

        self.actor = PolicyNet(
            self.args.n_embedding_features * self.args.s_subgraph,
            self.args.s_subgraph * 2, args, device, writer)
        self.critic = DoubleQValueNet(
            (1 + self.args.n_embedding_features) * self.args.s_subgraph,
            n_q_vals, args, device, writer)
        self.critic_tgt = DoubleQValueNet(
            (1 + self.args.n_embedding_features) * self.args.s_subgraph,
            n_q_vals, args, device, writer)

    def forward(self,
                raw,
                gt_edges=None,
                sp_indices=None,
                edge_index=None,
                angles=None,
                round_n=None,
                sub_graphs=None,
                actions=None,
                post_input=False,
                policy_opt=False):

        if sp_indices is None:
            return self.fe_ext(raw)
        embeddings = self.fe_ext(raw)
        node_features = []
        for i, sp_ind in enumerate(sp_indices):
            post_inp = False
            if post_input and i == 0:
                post_inp = True
            node_features.append(
                self.fe_ext.get_node_features(raw[i].squeeze(),
                                              embeddings[i].squeeze(),
                                              sp_ind,
                                              post_input=post_inp))

        # create one large unconnected graph where each connected component corresponds to one image
        node_features = torch.cat(node_features, dim=0)
        node_features = torch.cat([
            node_features,
            torch.ones([node_features.shape[0], 1],
                       device=node_features.device) * round_n
        ], -1)

        edge_index = torch.cat(
            [edge_index,
             torch.stack([edge_index[1], edge_index[0]], dim=0)],
            dim=1)  # gcnn expects two directed edges for one undirected edge

        if actions is None:
            with torch.set_grad_enabled(policy_opt):
                out = self.actor(node_features, edge_index, angles, sub_graphs,
                                 gt_edges, post_input)
                mu, log_std = out.chunk(2, dim=-1)
                mu, log_std = mu.squeeze(), log_std.squeeze()

                if post_input and self.writer is not None:
                    self.writer.add_scalar("mean_logits/loc",
                                           mu.mean().item(),
                                           self.writer_counter)
                    self.writer.add_scalar("mean_logits/scale",
                                           log_std.mean().item(),
                                           self.writer_counter)
                    self.writer_counter += 1

                # constrain log_std inside [log_std_min, log_std_max]
                log_std = torch.tanh(log_std)
                log_std_min, log_std_max = self.log_std_bounds
                log_std = log_std_min + 0.5 * (log_std_max -
                                               log_std_min) * (log_std + 1)

                std = log_std.exp()

                # dist = TruncNorm(mu, std, 0, 1, 0.005)
                dist = SigmNorm(mu, std)
                actions = dist.rsample()

            q1, q2 = self.critic_tgt(node_features, actions, edge_index,
                                     angles, sub_graphs, gt_edges, post_input)
            return dist, q1, q2, actions

        q1, q2 = self.critic(node_features, actions, edge_index, angles,
                             sub_graphs, gt_edges, post_input)
        return q1, q2