Example #1
    def update_rt_vars(self, critic_optimizer, actor_optimizer):
        # Re-read runtime_cfg.yaml under a file lock, apply any learning-rate
        # changes to the optimizers and reset the one-shot 'safe_model' flag.
        with portalocker.Lock(os.path.join(self.log_dir, 'runtime_cfg.yaml'),
                              'rb+',
                              timeout=60) as fh:
            with open(os.path.join(self.log_dir, 'runtime_cfg.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'safe_model' in args_dict:
                        self.cfg.rt_vars.safe_model = args_dict['safe_model']
                        args_dict['safe_model'] = False
                    if 'add_noise' in args_dict:
                        self.cfg.rt_vars.add_noise = args_dict['add_noise']
                    if 'critic_lr' in args_dict and args_dict[
                            'critic_lr'] != self.cfg.sac.critic_lr:
                        self.cfg.sac.critic_lr = args_dict['critic_lr']
                        adjust_learning_rate(critic_optimizer,
                                             self.cfg.sac.critic_lr)
                    if 'actor_lr' in args_dict and args_dict[
                            'actor_lr'] != self.cfg.sac.actor_lr:
                        self.cfg.sac.actor_lr = args_dict['actor_lr']
                        adjust_learning_rate(actor_optimizer,
                                             self.cfg.sac.actor_lr)
            with open(os.path.join(self.log_dir, 'runtime_cfg.yaml'),
                      "w") as info:
                yaml.dump(args_dict, info)

            # flush and sync to filesystem
            fh.flush()
            os.fsync(fh.fileno())
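
This snippet (and the later examples) relies on an adjust_learning_rate helper that is not shown. A minimal sketch, assuming it does nothing more than overwrite the learning rate of every parameter group, could look like this:

    def adjust_learning_rate(optimizer, new_lr):
        # Overwrite the learning rate of every parameter group in place.
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
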
Example #2
    def _update_networks(self, loss, optimizer, shared_model, writer=None):
        # Zero shared and local grads
        optimizer.zero_grad()
        """
        Calculate gradients for gradient descent on loss functions
        Note that math comments follow the paper, which is formulated for gradient ascent
        """
        loss.backward()
        # Gradient L2 normalisation
        nn.utils.clip_grad_norm_(shared_model.parameters(),
                                 self.args.max_gradient_norm)
        optimizer.step()
        with open(os.path.join(self.save_dir, 'runtime_cfg.yaml')) as info:
            args_dict = yaml.full_load(info)
            if args_dict is not None:
                if 'lr' in args_dict:
                    if self.args.lr != args_dict['lr']:
                        print("lr changed from ", self.args.lr, " to ",
                              args_dict['lr'], " at loss step ",
                              self.global_writer_loss_count.value())
                        self.args.lr = args_dict['lr']
        new_lr = self.args.lr
        if self.args.min_lr != 0 and self.eps <= 0.6:
            # Linearly decay learning rate
            # new_lr = self.args.lr - ((self.args.lr - self.args.min_lr) * (1 - (self.eps * 2))) # (1 - max((self.args.T_max - self.global_count.value()) / self.args.T_max, 1e-32)))
            new_lr = self.args.lr * 10**(-(0.6 - self.eps))
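            # e.g. with self.args.lr = 1e-3 and self.eps = 0.1 this yields
            # 1e-3 * 10 ** (-0.5), i.e. roughly 3.2e-4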

        adjust_learning_rate(optimizer, new_lr)
        if writer is not None:
            writer.add_scalar("loss/learning_rate", new_lr,
                              self.global_writer_loss_count.value())
Example #3
    def _update_networks(self, loss, optimizer, shared_model, writer=None):
        # Zero shared and local grads
        optimizer.zero_grad()
        """
        Calculate gradients for gradient descent on loss functions
        Note that math comments follow the paper, which is formulated for gradient ascent
        """
        loss.backward()
        # Gradient L2 normalisation
        nn.utils.clip_grad_norm_(shared_model.parameters(), self.args.max_gradient_norm)
        optimizer.step()
        if self.args.min_lr != 0:
            # Linearly decay learning rate
            new_lr = self.args.lr - ((self.args.lr - self.args.min_lr) *
                                     (1 - max((self.args.T_max - self.global_count.value()) / self.args.T_max, 1e-32)))
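            # e.g. with lr = 1e-3, min_lr = 1e-5 and T_max = 10000 this decays
            # linearly from 1e-3 to 1e-5, reaching about 5.05e-4 halfway through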
            adjust_learning_rate(optimizer, new_lr)

            if writer is not None:
                writer.add_scalar("loss/learning_rate", new_lr, self.global_writer_loss_count.value())
    def train_step(self, rank, start_time, return_dict, writer):
        # Per-process SAC-style training loop: sets up the distributed model,
        # the environment and the optimizers, then alternates data collection
        # and gradient updates until T_max episodes have been processed.
        device = torch.device("cuda:" + str(rank))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        self.setup(rank, self.args.num_processes)
        if self.cfg.MC_DQL:
            transition = namedtuple('Transition', ('episode',))
        else:
            transition = namedtuple(
                'Transition',
                ('state', 'action', 'reward', 'next_state', 'done'))
        memory = TransitionData_ts(capacity=self.args.t_max,
                                   storage_object=transition)

        env = SpGcnEnv(self.args,
                       device,
                       writer=writer,
                       writer_counter=self.global_writer_quality_count,
                       win_event_counter=self.global_win_event_count)
        # Create shared network

        # model = GcnEdgeAC_1(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)
        model = GcnEdgeAC(self.cfg, self.args, device, writer=writer)
        # model = GcnEdgeAC(self.cfg, self.args.n_raw_channels, self.args.n_embedding_features, 1, device, writer=writer)

        model.cuda(device)
        shared_model = DDP(model,
                           device_ids=[model.device],
                           find_unused_parameters=True)

        # dloader = DataLoader(MultiDiscSpGraphDsetBalanced(no_suppix=False, create=False), batch_size=1, shuffle=True, pin_memory=True,
        #                      num_workers=0)
        dloader = DataLoader(SpgDset(),
                             batch_size=self.cfg.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        # Create optimizer for shared network parameters with shared statistics
        # optimizer = CstmAdam(shared_model.parameters(), lr=self.args.lr, betas=self.args.Adam_betas,
        #                      weight_decay=self.args.Adam_weight_decay)
        ######################
        self.action_range = 1
        self.device = torch.device(device)
        self.discount = 0.5
        self.critic_tau = self.cfg.critic_tau
        self.actor_update_frequency = self.cfg.actor_update_frequency
        self.critic_target_update_frequency = self.cfg.critic_target_update_frequency
        self.batch_size = self.cfg.batch_size

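        # Learnable entropy temperature (log alpha), optimised by temp_optimizer below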
        self.log_alpha = torch.tensor(np.log(self.cfg.init_temperature)).to(
            self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        ######################
        # optimizers
        OptimizerContainer = namedtuple('OptimizerContainer',
                                        ('actor', 'critic', 'temperature'))
        actor_optimizer = torch.optim.Adam(
            shared_model.module.actor.parameters(),
            lr=self.cfg.actor_lr,
            betas=self.cfg.actor_betas)

        critic_optimizer = torch.optim.Adam(
            shared_model.module.critic.parameters(),
            lr=self.cfg.critic_lr,
            betas=self.cfg.critic_betas)

        temp_optimizer = torch.optim.Adam([self.log_alpha],
                                          lr=self.cfg.alpha_lr,
                                          betas=self.cfg.alpha_betas)

        optimizers = OptimizerContainer(actor_optimizer, critic_optimizer,
                                        temp_optimizer)

        if self.args.fe_extr_warmup and rank == 0 and not self.args.test_score_only:
            fe_extr = shared_model.module.fe_ext
            fe_extr.cuda(device)
            self.fe_extr_warm_start_1(fe_extr, writer=writer)
            # self.fe_extr_warm_start(fe_extr, writer=writer)
            if self.args.model_name == "" and not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, 'agent_model_fe_extr'))
            elif not self.args.no_save:
                torch.save(fe_extr.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        dist.barrier()
        for param in model.fe_ext.parameters():
            param.requires_grad = False

        if self.args.model_name != "":
            shared_model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.model_fe_name != "":
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir,
                                        self.args.model_fe_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            shared_model.module.fe_ext.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model_fe_extr')))

        if not self.args.test_score_only:
            quality = self.args.stop_qual_scaling + self.args.stop_qual_offset
            best_quality = np.inf
            last_quals = []
            while self.global_count.value() <= self.args.T_max:
                self.update_env_data(env, dloader, device)
                # waff_dis = torch.softmax(env.edge_features[:, 0].squeeze() + 1e-30, dim=0)
                # waff_dis = torch.softmax(env.gt_edge_weights + 0.5, dim=0)
                waff_dis = torch.softmax(torch.ones_like(
                    env.b_gt_edge_weights),
                                         dim=0)
                loss_weight = torch.softmax(env.b_gt_edge_weights + 1, dim=0)
                env.reset()
                # self.target_entropy = - float(env.gt_edge_weights.shape[0])
                self.target_entropy = -8.0

                env.stop_quality = self.stop_qual_rule.apply(
                    self.global_count.value(), quality)
                if self.cfg.temperature_regulation == 'follow_quality':
                    self.alpha = self.eps_rule.apply(self.global_count.value(),
                                                     quality)
                    print(self.alpha.item())

                with open(os.path.join(self.save_dir,
                                       'runtime_cfg.yaml')) as info:
                    args_dict = yaml.full_load(info)
                    if args_dict is not None:
                        if 'safe_model' in args_dict:
                            self.args.safe_model = args_dict['safe_model']
                            args_dict['safe_model'] = False
                        if 'add_noise' in args_dict:
                            self.args.add_noise = args_dict['add_noise']
                        if 'critic_lr' in args_dict and args_dict[
                                'critic_lr'] != self.cfg.critic_lr:
                            self.cfg.critic_lr = args_dict['critic_lr']
                            adjust_learning_rate(critic_optimizer,
                                                 self.cfg.critic_lr)
                        if 'actor_lr' in args_dict and args_dict[
                                'actor_lr'] != self.cfg.actor_lr:
                            self.cfg.actor_lr = args_dict['actor_lr']
                            adjust_learning_rate(actor_optimizer,
                                                 self.cfg.actor_lr)
                        if 'alpha_lr' in args_dict and args_dict[
                                'alpha_lr'] != self.cfg.alpha_lr:
                            self.cfg.alpha_lr = args_dict['alpha_lr']
                            adjust_learning_rate(temp_optimizer,
                                                 self.cfg.alpha_lr)
                with open(os.path.join(self.save_dir, 'runtime_cfg.yaml'),
                          "w") as info:
                    yaml.dump(args_dict, info)

                if self.args.safe_model:
                    best_quality = quality
                    if rank == 0:
                        if self.args.model_name_dest != "":
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir,
                                             self.args.model_name_dest))
                        else:
                            torch.save(
                                shared_model.state_dict(),
                                os.path.join(self.save_dir, 'agent_model'))

                state = env.get_state()
                while not env.done:
                    # Calculate policy and values
                    post_input = ((self.global_count.value() + 1) % 15 == 0
                                  and env.counter == 0)
                    round_n = env.counter
                    # sample action for data collection
                    distr = None
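                    # While global_count is below num_seed_steps, explore with uniformly random actions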
                    if self.global_count.value() < self.cfg.num_seed_steps:
                        action = torch.rand_like(env.b_current_edge_weights)
                    else:
                        distr, _, _, action = self.agent_forward(
                            env,
                            shared_model,
                            state=state,
                            grad=False,
                            post_input=post_input)

                    logg_dict = {'temperature': self.alpha.item()}
                    if distr is not None:
                        logg_dict['mean_loc'] = distr.loc.mean().item()
                        logg_dict['mean_scale'] = distr.scale.mean().item()

                    if (self.global_count.value() >= self.cfg.num_seed_steps
                            and memory.is_full()):
                        self._step(memory,
                                   optimizers,
                                   env,
                                   shared_model,
                                   self.global_count.value(),
                                   writer=writer)
                        self.global_writer_loss_count.increment()

                    next_state, reward, quality = env.execute_action(
                        action, logg_dict)

                    last_quals.append(quality)
                    if len(last_quals) > 10:
                        last_quals.pop(0)

                    if self.args.add_noise:
                        noise = torch.randn_like(reward) * self.alpha.item()
                        reward = reward + noise

                    memory.push(self.state_to_cpu(state), action, reward,
                                self.state_to_cpu(next_state), env.done)

                    # Train the network
                    # self._step(memory, shared_model, env, optimizer, loss_weight, off_policy=True, writer=writer)

                    # reward = self.args.reward_clip and min(max(reward, -1), 1) or reward  # Optionally clamp rewards
                    # done = done or episode_length >= self.args.max_episode_length  # Stop episodes at a max length
                    state = next_state

                self.global_count.increment()

        dist.barrier()
        if rank == 0:
            if not self.args.cross_validate_hp and not self.args.test_score_only and not self.args.no_save:
                # pass
                if self.args.model_name_dest != "":
                    torch.save(
                        shared_model.state_dict(),
                        os.path.join(self.save_dir, self.args.model_name_dest))
                    print('saved')
                else:
                    torch.save(shared_model.state_dict(),
                               os.path.join(self.save_dir, 'agent_model'))

        self.cleanup()
        return sum(last_quals) / max(len(last_quals), 1)
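
For reference, the runtime_cfg.yaml polled by the loop above only needs the keys that are actually read from it. A purely illustrative way to seed such a file (the values and the save_dir location are assumptions, not taken from the original project):

    import os
    import yaml

    save_dir = '.'  # stands in for self.save_dir; point it at the real log directory
    runtime_cfg = {
        'safe_model': False,   # one-shot flag, reset to False after being read
        'add_noise': False,
        'critic_lr': 3e-4,     # illustrative values only
        'actor_lr': 3e-4,
        'alpha_lr': 3e-4,
    }
    with open(os.path.join(save_dir, 'runtime_cfg.yaml'), 'w') as fh:
        yaml.dump(runtime_cfg, fh)
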
Example #5
    def train(self):
        # Supervised training loop: predicts edge weights with GcnEdgeAngle1dQ
        # and optimises a graph Dice loss (plus a side loss) for T_max steps.
        step_counter = 0
        device = torch.device("cuda:" + str(0))
        print('Running on device: ', device)
        torch.cuda.set_device(device)
        torch.set_default_tensor_type(torch.FloatTensor)

        writer = None
        if not self.args.cross_validate_hp:
            writer = SummaryWriter(logdir=os.path.join(self.save_dir, 'logs'))
            # posting parameters
            param_string = ""
            for k, v in vars(self.args).items():
                param_string += ' ' * 10 + k + ': ' + str(v) + '\n'
            writer.add_text("params", param_string)

        # Create shared network
        model = GcnEdgeAngle1dQ(self.args.n_raw_channels,
                                self.args.n_embedding_features,
                                self.args.n_edge_features,
                                1,
                                device,
                                writer=writer)

        if self.args.no_fe_extr_optim:
            for param in model.fe_ext.parameters():
                param.requires_grad = False

        model.cuda(device)
        dloader = DataLoader(MultiDiscSpGraphDset(no_suppix=False),
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
        optimizer = Adam(model.parameters(), lr=self.args.lr)
        loss = GraphDiceLoss()

        if self.args.fe_extr_warmup and not self.args.test_score_only:
            fe_extr = SpVecsUnet(self.args.n_raw_channels,
                                 self.args.n_embedding_features, device)
            fe_extr.cuda(device)
            self.fe_extr_warm_start(fe_extr, writer=writer)
            model.fe_ext.load_state_dict(fe_extr.state_dict())
            if self.args.model_name == "":
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, 'agent_model'))
            else:
                torch.save(model.state_dict(),
                           os.path.join(self.save_dir, self.args.model_name))

        if self.args.model_name != "":
            model.load_state_dict(
                torch.load(os.path.join(self.save_dir, self.args.model_name)))
        elif self.args.fe_extr_warmup:
            print('loaded fe extractor')
            model.load_state_dict(
                torch.load(os.path.join(self.save_dir, 'agent_model')))

        while step_counter <= self.args.T_max:
            if (step_counter + 1) % 1000 == 0:
                post_input = True
            else:
                post_input = False
            with open(os.path.join(self.save_dir, 'config.yaml')) as info:
                args_dict = yaml.full_load(info)
                if args_dict is not None:
                    if 'lr' in args_dict:
                        self.args.lr = args_dict['lr']
                        adjust_learning_rate(optimizer, self.args.lr)

            round_n = 0

            raw, gt, sp_seg, sp_indices, edge_ids, edge_weights, gt_edges, edge_features = \
                self._get_data(dloader, device)

            inp = [
                obj.float().to(model.device)
                for obj in [edge_weights, sp_seg, raw + gt, sp_seg]
            ]
            pred, side_loss = model(inp,
                                    sp_indices=sp_indices,
                                    edge_index=edge_ids.to(model.device),
                                    angles=None,
                                    edge_features_1d=edge_features.to(
                                        model.device),
                                    round_n=round_n,
                                    post_input=post_input)

            pred = pred.squeeze()

            loss_val = loss(pred, gt_edges.to(device))

            ttl_loss = loss_val + side_loss
            quality = (pred - gt_edges.to(device)).abs().sum()

            optimizer.zero_grad()
            ttl_loss.backward()
            optimizer.step()

            if writer is not None:
                writer.add_scalar("step/lr", self.args.lr, step_counter)
                writer.add_scalar("step/dice_loss", loss_val.item(),
                                  step_counter)
                writer.add_scalar("step/side_loss", side_loss.item(),
                                  step_counter)
                writer.add_scalar("step/quality", quality.item(), step_counter)

            step_counter += 1
