Example 1: fe_extr_warm_start (contrastive warm start of the superpixel feature extractor, with optional L2 weight regularization)
    def fe_extr_warm_start(self, sp_feature_ext, writer=None):
        # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(length=self.args.fe_warmup_iterations * 10), batch_size=10,
        #                      shuffle=True, pin_memory=True)
        dloader = DataLoader(
            MultiDiscSpGraphDset(length=self.args.fe_warmup_iterations * 10),
            batch_size=10,
            shuffle=True,
            pin_memory=True)
        criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
        optimizer = torch.optim.Adam(sp_feature_ext.parameters(), lr=2e-3)
        for i, (data, gt) in enumerate(dloader):
            data, gt = data.to(sp_feature_ext.device), gt.to(
                sp_feature_ext.device)
            pred = sp_feature_ext(data)

            l2_reg = None
            if self.args.l2_reg_params_weight != 0:
                for W in list(sp_feature_ext.parameters()):
                    if l2_reg is None:
                        l2_reg = W.norm(2)
                    else:
                        l2_reg = l2_reg + W.norm(2)
            if l2_reg is None:
                l2_reg = 0

            loss = criterion(pred,
                             gt) + l2_reg * self.args.l2_reg_params_weight
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if writer is not None:
                writer.add_scalar("loss/fe_warm_start", loss.item(),
                                  self.writer_idx_warmup_loss)
                self.writer_idx_warmup_loss += 1
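
The warm start above is invoked from the training loops in Examples 3 to 5; the following is only a call-site sketch mirroring that pattern (it assumes the surrounding trainer's self.args, device, and writer, and the SpVecsUnet constructor shown in those examples):

    # Hypothetical call site, mirroring Examples 3 to 5: build the superpixel
    # feature extractor, move it to the GPU, then warm-start it.
    fe_extr = SpVecsUnet(self.args.n_raw_channels, self.args.n_embedding_features, device)
    fe_extr.cuda(device)
    self.fe_extr_warm_start(fe_extr, writer=writer)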
Example 2: fe_extr_warm_start (minimal variant that trains on the first input channel only, without L2 regularization)
    def fe_extr_warm_start(self, sp_feature_ext, writer=None):
        dataloader = DataLoader(MultiDiscSpGraphDset(length=100), batch_size=10,
                                shuffle=True, pin_memory=True)
        criterion = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
        optimizer = torch.optim.Adam(sp_feature_ext.parameters())
        for i, (data, gt) in enumerate(dataloader):
            data, gt = data.to(sp_feature_ext.device), gt.to(sp_feature_ext.device)
            # train on the first raw channel only
            pred = sp_feature_ext(data[:, 0, :, :].unsqueeze(1))
            loss = criterion(pred, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if writer is not None:
                writer.add_scalar("loss/fe_warm_start", loss.item(), self.writer_idx_warmup_loss)
                self.writer_idx_warmup_loss += 1
Example 3: train (distributed off-policy actor-critic loop with DDP, a transition memory, and an optional feature-extractor warm start)
    def train(self, rank, start_time, return_dict):
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)

        transition = namedtuple('Transition',
                                ('state', 'action', 'reward', 'state_',
                                 'behav_policy_proba', 'time', 'terminal'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create shared network
        model = GcnEdgeAngle1dPQV(self.args.n_raw_channels,
                                  self.args.n_embedding_features,
                                  self.args.n_edge_features,
                                  self.args.n_actions, device)
        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        # Create optimizer for shared network parameters with shared statistics
        optimizer = CstmAdam(shared_model.parameters(),
                             lr=self.args.lr,
                             betas=self.args.Adam_betas,
                             weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))
        dist.barrier()
        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_damped_model.load_state_dict(shared_model.state_dict())
        env.done = True  # Start new episode
        while self.global_count.value() <= self.args.T_max:
            if env.done:
                edges, edge_feat, diff_to_gt, gt_edge_weights, node_labeling, raw, nodes, angles, affinities, gt = \
                    next(iter(dloader))
                # squeeze the batch dimension and move everything to the training
                # device; affinities stay on the CPU as a numpy array
                edges = edges.squeeze().to(device)
                edge_feat = edge_feat.squeeze()[:, 0:self.args.n_edge_features].to(device)
                diff_to_gt = diff_to_gt.squeeze().to(device)
                gt_edge_weights = gt_edge_weights.squeeze().to(device)
                node_labeling = node_labeling.squeeze().to(device)
                raw = raw.squeeze().to(device)
                nodes = nodes.squeeze().to(device)
                angles = angles.squeeze().to(device)
                affinities = affinities.squeeze().numpy()
                gt = gt.squeeze()
                env.update_data(edges, edge_feat, diff_to_gt, gt_edge_weights,
                                node_labeling, raw, nodes, angles, affinities,
                                gt)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]
                episode_length = 0

                self.eps = self.eps_rule.apply(self.global_count.value())
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value())
                if writer is not None:
                    writer.add_scalar("step/epsilon", self.eps,
                                      env.writer_counter.value())

            while not env.done:
                # Calculate policy and values
                policy_proba, q, v = self.agent_forward(env,
                                                        shared_model,
                                                        grad=False)
                # average_policy_proba, _, _ = self.agent_forward(env, self.shared_average_model)
                # q_ret = v.detach()

                # Sample action
                # action = torch.multinomial(policy, 1)[0, 0]

                # Step
                action, behav_policy_proba = self.get_action(
                    policy_proba, q, v, policy='off_uniform', device=device)
                state_, reward = env.execute_action(action,
                                                    self.global_count.value())

                memory.push(state, action,
                            reward.to(shared_model.module.device), state_,
                            behav_policy_proba, episode_length, env.done)

                # Train the network
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)

                # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                episode_length += 1  # Increase episode counter
                state = state_

            # Break graph for last values calculated (used for targets, not directly as model outputs)
            self.global_count.increment()
            # Qret = 0 for terminal s

            while len(memory) > 0:
                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp:
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            else:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.eps = 0
                    while not env.done:
                        # Calculate policy and values
                        policy_proba, q, v = self.agent_forward(env,
                                                                shared_model,
                                                                grad=False)
                        action, behav_policy_proba = self.get_action(
                            policy_proba,
                            q,
                            v,
                            policy='off_uniform',
                            device=device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                writer.add_text("time_needed", str((time.time() - start_time)))
        self.cleanup()
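
The setup and cleanup calls used above are not shown in these examples; the following is only a sketch of the usual torch.distributed pairing they presumably wrap (os and torch.distributed as dist are imported as in the examples; master address, port, and backend are assumptions):

    # Sketch of the assumed distributed setup/teardown (not shown in the examples).
    def setup(self, rank, world_size):
        os.environ.setdefault('MASTER_ADDR', 'localhost')  # placeholder address
        os.environ.setdefault('MASTER_PORT', '12355')      # placeholder port
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

    def cleanup(self):
        dist.destroy_process_group()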
Example 4: train (distributed loop with a dueling edge model, continuous actions, runtime config reloading, and optional reward noise)
    def train(self, rank, start_time, return_dict):

        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)
        transition = namedtuple(
            'Transition',
            ('state', 'action', 'reward', 'behav_policy_proba', 'done'))
        memory = TransitionData(capacity=self.args.t_max,
                                storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count,
                       discrete_action_space=False)
        # Create shared network
        model = GcnEdgeAngle1dPQA_dueling_1(self.args.n_raw_channels,
                                            self.args.n_embedding_features,
                                            self.args.n_edge_features,
                                            1,
                                            self.args.exp_steps,
                                            self.args.p_sigma,
                                            device,
                                            self.args.density_eval_range,
                                            writer=writer)
        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        shared_model = DDP(model, device_ids=[model.device])
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        optimizer = torch.optim.Adam(shared_model.parameters(),
                                     lr=self.args.lr,
                                     betas=self.args.Adam_betas,
                                     weight_decay=self.args.Adam_weight_decay)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            shared_model.module.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(shared_model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        self.shared_average_model.load_state_dict(shared_model.state_dict())

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            while self.global_count.value() <= self.args.T_max:
                if self.global_count.value() == 190:
                    a = 1  # no-op; kept as a debugger breakpoint anchor
                self.update_env_data(env, dloader, device)
                env.reset()
                state = [env.state[0].clone(), env.state[1].clone()]

                self.b_sigma = self.b_sigma_rule.apply(
                    self.global_count.value(), quality)
                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)

                with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'eps' in args_dict:
                            if self.args.eps != args_dict['eps']:
                                self.eps = args_dict['eps']
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']

                if writer is not None:
                    writer.add_scalar("step/b_variance", self.b_sigma,
                                      env.writer_counter.value())

                if self.args.safe_model:
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                while not env.done:
                    # truthy for global counts that are not a multiple of 50, at the first step of an episode
                    post_input = bool(self.global_count.value() % 50) and env.counter == 0
                    # Calculate policy and values
                    policy_means, p_dis = self.agent_forward(
                        env,
                        shared_model,
                        grad=False,
                        stats_only=True,
                        post_input=post_input)

                    # Step
                    action, b_rvs = self.get_action(policy_means, p_dis,
                                                    device)
                    state_, reward, quality = env.execute_action(action)

                    if self.args.add_noise:
                        if self.global_count.value() > 110 and self.global_count.value() % 5:
                            noise = torch.randn_like(reward) * 0.8
                            reward = reward + noise

                    memory.push(state, action, reward, b_rvs, env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = state_

                # Break graph for last values calculated (used for targets, not directly as model outputs)
                self.global_count.increment()

                self._step(memory,
                           shared_model,
                           env,
                           optimizer,
                           off_policy=True,
                           writer=writer)
                memory.clear()
                # while len(memory) > 0:
                #     self._step(memory, shared_model, env, optimizer, off_policy=True, writer=writer)
                #     memory.pop(0)

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))
            if self.args.cross_validate_hp or self.args.test_score_only:
                test_score = 0
                env.writer = None
                for i in range(20):
                    self.update_env_data(env, dloader, device)
                    env.reset()
                    self.b_sigma = self.args.p_sigma
                    env.stop_quality = 40
                    while not env.done:
                        # Calculate policy and values
                        policy_means, p_dis = self.agent_forward(
                            env, shared_model, grad=False, stats_only=True)
                        action, b_rvs = self.get_action(
                            policy_means, p_dis, device)
                        _, _ = env.execute_action(action,
                                                  self.global_count.value())

                    # import matplotlib.pyplot as plt;
                    # plt.imshow(env.get_current_soln());
                    # plt.show()
                    if env.win:
                        test_score += 1
                return_dict['test_score'] = test_score
                writer.add_text("time_needed", str((time.time() - start_time)))
        self.cleanup()
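
shared_average_model is initialized from shared_model above, but its update rule lives in _step, which is not shown; the following is only a sketch of a soft (Polyak style) update under that assumption, with a hypothetical mixing factor tau:

    # Hypothetical soft update of the average network towards the learner.
    with torch.no_grad():
        for p_avg, p in zip(self.shared_average_model.parameters(),
                            shared_model.parameters()):
            p_avg.mul_(1.0 - tau).add_(tau * p)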
Example 5: train (single-GPU supervised training of the edge model with a graph Dice loss)
    def train(self):
        step_counter = 0
        device = torch.device("cuda:" + str(0))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        # Create shared network
        model = GcnEdgeAngle1dQ(self.args.n_raw_channels,
                                self.args.n_embedding_features,
                                self.args.n_edge_features,
                                1,
                                device,
                                writer=writer)

        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = Adam(model.parameters(), lr=self.args.lr)
        loss = GraphDiceLoss()

        if self.args.fe_extr_warmup and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            model.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        if self.args.model_name != "":
            model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        while step_counter <= self.args.T_max:
            if step_counter == 78:
                a = 1  # no-op; kept as a debugger breakpoint anchor
            if (step_counter + 1) % 1000 == 0:
                post_input = True
            else:
                post_input = False
            with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'lr' in args_dict:
                        self.args.lr = args_dict['lr']
                        adjust_learning_rate(optimizer, self.args.lr)

            round_n = 0

            raw, gt, sp_seg, sp_indices, edge_ids, edge_weights, gt_edges, edge_features = \
                self._get_data(dloader, device)

            inp = [
                obj.float().to(model.device)
                for obj in [edge_weights, sp_seg, raw + gt, sp_seg]
            ]
            pred, side_loss = model(inp,
                                    sp_indices=sp_indices,
                                    edge_index=edge_ids.to(model.device),
                                    angles=None,
                                    edge_features_1d=edge_features.to(
                                        model.device),
                                    round_n=round_n,
                                    post_input=post_input)

            pred = pred.squeeze()

            loss_val = loss(pred, gt_edges.to(device))

            ttl_loss = loss_val + side_loss
            quality = (pred - gt_edges.to(device)).abs().sum()

            optimizer.zero_grad()
            ttl_loss.backward()
            optimizer.step()

            if writer is not None:
                writer.add_scalar("step/lr", self.args.lr, step_counter)
                writer.add_scalar("step/dice_loss", loss_val.item(),
                                  step_counter)
                writer.add_scalar("step/side_loss", side_loss.item(),
                                  step_counter)
                writer.add_scalar("step/quality", quality.item(), step_counter)

            step_counter += 1

        a = 1  # no-op; kept as a debugger breakpoint anchor
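
adjust_learning_rate is called above (and again in Example 6) but not defined in these examples; a minimal sketch of the assumed helper:

    def adjust_learning_rate(optimizer, lr):
        # assumed helper: overwrite the learning rate of every parameter group
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr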
Example 6: train (distributed soft actor-critic style loop with separate actor, critic, and temperature optimizers and an optional episodic MC_DQL memory)
    def train(self, rank, start_time, return_dict):
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        self.setup(rank, self.args.num_processes)
        if self.cfg.MC_DQL:
            # single-field transition holding a whole episode
            transition = namedtuple('Transition', ('episode',))
        else:
            transition = namedtuple(
                'Transition',
                ('state', 'action', 'reward', 'next_state', 'done'))
        memory = TransitionData_ts(capacity=self.args.t_max,
                                   storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        # Create shared network

        model = GcnEdgeAC_1(self.cfg,
                            self.args.n_raw_channels,
                            self.args.n_embedding_features,
                            1,
                            device,
                            writer=writer)

        model.cuda(device)
        shared_model = DDP(model,
                           device_ids=[model.device],
                           find_unused_parameters=True)

        # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(no_suppix=False, create=False), batch_size=1, shuffle=True, pin_memory=True,
        #                      num_workers=0)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
        #                      weight_decay=self.args.Adam_weight_decay)
        ######################
        self.action_range = 1
        self.device = torch.device(device)
        self.discount = 0.5
        self.critic_tau = self.cfg.critic_tau
        self.actor_update_frequency = self.cfg.actor_update_frequency
        self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
        self.batch_size = self.cfg.batch_size

        self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(
            self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        ######################
        # optimizers
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'temperature'))
        actor_optimizer = torch.optim.Adam(
            shared_model.module.actor.parameters(),
            lr=self.cfg.actor_lr,
            betas=self.cfg.actor_betas)

        critic_optimizer = torch.optim.Adam(
            shared_model.module.critic.parameters(),
            lr=self.cfg.critic_lr,
            betas=self.cfg.critic_betas)

        temp_optimizer = torch.optim.Adam([self.log_alpha],
                                          lr=self.cfg.alpha_lr,
                                          betas=self.cfg.alpha_betas)

        optimizers = OptimizerContainer(actor_optimizer, critic_optimizer,
                                        temp_optimizer)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = shared_model.module.fe_ext
            fe_extr.cuda(device)
            self.fe_extr_warm_start_1(fe_extr, writer=writer)
            if self.args.model_name == "" and not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model_fe_extr'))
            elif not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()
        for param in model.fe_ext.parameters():
            param.requires_grad = False

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.model_fe_name != "":
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir,
                                        self.args.model_fe_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            while self.global_count.value() <= self.args.T_max:
                if self.global_count.value() == 78:
                    a = 1  # no-op; kept as a debugger breakpoint anchor
                self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                waff_dis = torch.softmax(torch.ones_like(env.gt_edge_weights),
                                         dim=0)
                loss_weight = torch.softmax(env.gt_edge_weights + 1, dim=0)
                env.reset()
                self.target_entropy = -float(env.gt_edge_weights.shape[0])

                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)
                if self.cfg.temperature_regulation == 'follow_quality':
                    self.alpha = self.eps_rule.apply(self.global_count.value(),
                                                     quality)
                    print(self.alpha.item())

                with open(os.path.join(self.save_dir,
                                       'runtime_cfg.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                        if 'critic_lr' in args_dict and args_dict[
                                'critic_lr'] != self.cfg.critic_lr:
                            self.cfg.critic_lr = args_dict['critic_lr']
                            adjust_learning_rate(critic_optimizer,
                                                 self.cfg.critic_lr)
                        if 'actor_lr' in args_dict and args_dict[
                                'actor_lr'] != self.cfg.actor_lr:
                            self.cfg.actor_lr = args_dict['actor_lr']
                            adjust_learning_rate(actor_optimizer,
                                                 self.cfg.actor_lr)
                        if 'alpha_lr' in args_dict and args_dict[
                                'alpha_lr'] != self.cfg.alpha_lr:
                            self.cfg.alpha_lr = args_dict['alpha_lr']
                            adjust_learning_rate(temp_optimizer,
                                                 self.cfg.alpha_lr)

                if self.args.safe_model and not self.args.no_save:
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))
                if self.cfg.MC_DQL:
                    state_pixels, edge_ids, sp_indices, edge_angles, counter = env.get_state()
                    state_ep = [
                        state_pixels, edge_ids, sp_indices, edge_angles
                    ]
                    episode = [state_ep]
                    state = counter
                    while not env.done:
                        # Calculate policy and values
                        post_input = (self.global_count.value() + 1) % 15 == 0 and env.counter == 0
                        round_n = env.counter
                        # sample action for data collection
                        if self.global_count.value() < self.cfg.num_seed_steps:
                            action = torch.rand_like(env.current_edge_weights)
                        else:
                            _, _, _, action = self.agent_forward(
                                env,
                                shared_model,
                                state=state_ep + [state],
                                grad=False,
                                post_input=post_input)

                        action = action.cpu()

                        (_, _, _, _, next_state), reward, quality = env.execute_action(action)

                        episode.append(
                            (state, action, reward, next_state, env.done))
                        state = next_state

                    memory.push(episode)
                    if self.global_count.value() >= self.cfg.num_seed_steps and memory.is_full():
                        self._step_episodic_mem(memory,
                                                optimizers,
                                                env,
                                                shared_model,
                                                self.global_count.value(),
                                                writer=writer)
                        self.global_writer_loss_count.increment()
                else:
                    state = env.get_state()
                    while not env.done:
                        # Calculate policy and values
                        post_input = (self.global_count.value() + 1) % 15 == 0 and env.counter == 0
                        round_n = env.counter
                        # sample action for data collection
                        if self.global_count.value() < self.cfg.num_seed_steps:
                            action = torch.rand_like(env.current_edge_weights)
                        else:
                            _, _, _, action = self.agent_forward(
                                env,
                                shared_model,
                                state=state,
                                grad=False,
                                post_input=post_input)

                        action = action.cpu()
                        if self.global_count.value() >= self.cfg.num_seed_steps and memory.is_full():
                            self._step(memory,
                                       optimizers,
                                       env,
                                       shared_model,
                                       self.global_count.value(),
                                       writer=writer)
                            self.global_writer_loss_count.increment()

                        next_state, reward, quality = env.execute_action(
                            action)

                        memory.push(state, action, reward, next_state,
                                    env.done)

                        # Train the network
                        # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                        # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                        # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                        state = next_state

                self.global_count.increment()
                if "self_reg" in self.args.eps_rule and quality <= 2:
                    break

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name_dest != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name_dest))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))

        self.cleanup()
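
Example 6 creates a learnable log_alpha with its own optimizer and sets target_entropy per episode, but the temperature update itself happens inside _step, which is not shown; the following is only a sketch of the usual soft actor-critic temperature loss under that assumption (log_prob stands for the log-probability of the sampled actions, which these examples do not expose):

    # Hypothetical SAC-style temperature update; log_prob is assumed to come
    # from the actor's action distribution.
    alpha = self.log_alpha.exp()
    alpha_loss = (alpha * (-log_prob - self.target_entropy).detach()).mean()
    temp_optimizer.zero_grad()
    alpha_loss.backward()
    temp_optimizer.step()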
Example 7: fe_extr_warm_start (combined contrastive pixel-embedding loss, edge Dice loss through a small linear head, and GCN side loss, with optional L2 regularization)
    def fe_extr_warm_start(self, sp_feature_ext, writer=None):
        # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(length=self.args.fe_warmup_iterations * 10), batch_size=10,
        #                      shuffle=True, pin_memory=True)
        dloader = DataLoader(MultiDiscSpGraphDset(
            length=self.args.fe_warmup_iterations * 10,
            less=True,
            no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)
        contrastive_l = ContrastiveLoss(delta_var=0.5, delta_dist=1.5)
        dice = GraphDiceLoss()
        small_lcf = nn.Sequential(
            nn.Linear(sp_feature_ext.n_embedding_channels, 256),
            nn.Linear(256, 512),
            nn.Linear(512, 1024),
            nn.Linear(1024, 256),
            nn.Linear(256, 1),
        )
        small_lcf.cuda(device=sp_feature_ext.device)
        # note: only sp_feature_ext is optimized here; small_lcf's parameters are
        # left out (see the note after this example)
        optimizer = torch.optim.Adam(sp_feature_ext.parameters(), lr=1e-3)
        for i, (data, node_labeling, gt_pix, gt_edges,
                edge_index) in enumerate(dloader):
            # squeeze and move the batch to the feature extractor's device
            data = data.to(sp_feature_ext.device)
            node_labeling = node_labeling.squeeze().to(sp_feature_ext.device)
            gt_pix = gt_pix.to(sp_feature_ext.device)
            gt_edges = gt_edges.squeeze().to(sp_feature_ext.device)
            edge_index = edge_index.squeeze().to(sp_feature_ext.device)
            stacked_superpixels = [
                node_labeling == n for n in node_labeling.unique()
            ]
            sp_indices = [sp.nonzero() for sp in stacked_superpixels]

            edge_features, pred_embeddings, side_loss = sp_feature_ext(
                data, edge_index,
                torch.zeros_like(gt_edges, dtype=torch.float), sp_indices)

            pred_edge_weights = small_lcf(edge_features)

            l2_reg = None
            if self.args.l2_reg_params_weight != 0:
                for W in list(sp_feature_ext.parameters()):
                    if l2_reg is None:
                        l2_reg = W.norm(2)
                    else:
                        l2_reg = l2_reg + W.norm(2)
            if l2_reg is None:
                l2_reg = 0

            loss_pix = contrastive_l(pred_embeddings.unsqueeze(0), gt_pix)
            loss_edge = dice(pred_edge_weights.squeeze(), gt_edges.squeeze())
            loss = loss_pix + self.args.weight_edge_loss * loss_edge + \
                   self.args.weight_side_loss * side_loss + l2_reg * self.args.l2_reg_params_weight
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if writer is not None:
                writer.add_scalar("loss/fe_warm_start/ttl", loss.item(),
                                  self.writer_idx_warmup_loss)
                writer.add_scalar("loss/fe_warm_start/pix_embeddings",
                                  loss_pix.item(), self.writer_idx_warmup_loss)
                writer.add_scalar("loss/fe_warm_start/edge_embeddings",
                                  loss_edge.item(),
                                  self.writer_idx_warmup_loss)
                writer.add_scalar("loss/fe_warm_start/gcn_sideloss",
                                  side_loss.item(),
                                  self.writer_idx_warmup_loss)
                self.writer_idx_warmup_loss += 1
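
Note that the optimizer in Example 7 covers only sp_feature_ext, so the linear head small_lcf keeps its initial weights; if the head is meant to be trained as well, one possible variant (an assumption, not what the example does) is:

    # Hypothetical variant that also optimizes the linear edge head.
    optimizer = torch.optim.Adam(
        list(sp_feature_ext.parameters()) + list(small_lcf.parameters()), lr=1e-3)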