コード例 #1
0
    def instantiate(self, model: Model) -> torch.optim.Adam:
        """Create a ``torch.optim.Adam`` optimizer for ``model``.

        If ``self.layer_groups`` is truthy, the model's layer groups are turned
        into parameter groups with ``mu.to_parameter_groups``. In that mode
        ``self.lr`` and ``self.weight_decay`` may each be a sequence: entry ``i``
        is written into parameter group ``i``, and entry ``0`` is also used as
        the optimizer-wide default. A non-sequence value is only used as the
        default.

        Otherwise every parameter of ``model`` with ``requires_grad`` set is
        optimized as a single group, with ``self.lr`` / ``self.weight_decay``
        passed through unchanged.
        """
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc`` (available since Python 3.3).
        import collections.abc

        if not self.layer_groups:
            parameters = filter(lambda p: p.requires_grad, model.parameters())

            return torch.optim.Adam(parameters,
                                    lr=self.lr,
                                    betas=self.betas,
                                    eps=self.eps,
                                    weight_decay=self.weight_decay,
                                    amsgrad=self.amsgrad)

        # Presumably a list of per-group option dicts (it is indexed and
        # item-assigned below) — TODO confirm against mu.to_parameter_groups.
        parameters = mu.to_parameter_groups(model.get_layer_groups())

        def apply_per_group(key, value):
            # Spread a per-group sequence over the parameter groups under
            # ``key`` and return the value to use as the optimizer default.
            if isinstance(value, collections.abc.Sequence):
                for idx, group_value in enumerate(value):
                    parameters[idx][key] = group_value

                return value[0]

            return value

        default_lr = apply_per_group('lr', self.lr)
        default_weight_decay = apply_per_group('weight_decay', self.weight_decay)

        return torch.optim.Adam(parameters,
                                lr=default_lr,
                                betas=self.betas,
                                eps=self.eps,
                                weight_decay=default_weight_decay,
                                amsgrad=self.amsgrad)
コード例 #2
0
ファイル: sgd.py プロジェクト: ryan-leung/ml_monorepo
    def instantiate(self, model: Model) -> torch.optim.SGD:
        """Create a ``torch.optim.SGD`` optimizer for ``model``.

        With ``self.layer_groups`` truthy, the model's layer groups are converted
        into parameter groups via ``mu.to_parameter_groups``; otherwise every
        parameter that has ``requires_grad`` set is optimized as one group.
        """
        parameters = (
            mu.to_parameter_groups(model.get_layer_groups())
            if self.layer_groups
            else (param for param in model.parameters() if param.requires_grad)
        )

        hyperparams = dict(
            lr=self.lr,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
            nesterov=self.nesterov,
        )

        return torch.optim.SGD(parameters, **hyperparams)
コード例 #3
0
    def resume(self, train_info: TrainingInfo, model: Model) -> dict:
        """
        Reload the checkpoint saved for ``train_info.start_epoch_idx``.

        The model weights are loaded into ``model``; the hidden-state dict is
        then restored into the checkpoint strategy and into ``train_info`` (in
        that order) and returned to the caller.
        """
        epoch_idx = train_info.start_epoch_idx

        weights = torch.load(self.checkpoint_filename(epoch_idx))
        model.load_state_dict(weights)

        state = torch.load(self.checkpoint_hidden_filename(epoch_idx))

        for consumer in (self.checkpoint_strategy, train_info):
            consumer.restore(state)

        return state
コード例 #4
0
ファイル: classic.py プロジェクト: wanjinchang/vel
    def checkpoint(self,
                   epoch_info: EpochInfo,
                   model: Model,
                   state_dict: dict = None):
        """ When epoch is done, we persist the training state

        Saves the model weights and a "hidden state" dict for
        ``epoch_info.global_epoch_idx``. When the checkpoint strategy asks for
        it, this also deletes the previous epoch's two checkpoint files and/or
        stores a separate "best" copy of the weights. Finally,
        ``epoch_info.result`` is handed to ``self.backend.store``.

        :param epoch_info: the finished epoch; its ``global_epoch_idx``,
            ``optimizer``, ``callbacks`` and ``result`` are read here
        :param model: model whose ``state_dict()`` is written to disk
        :param state_dict: optional extra entries for the hidden state; new keys
            are added to a shallow copy, not to the caller's dict
        """
        state_dict = state_dict if state_dict is not None else {}

        # What clean() removes is defined outside this view.
        # NOTE(review): confirm whether it expects the current epoch index or
        # the previous one (global_epoch_idx - 1).
        self.clean(epoch_info.global_epoch_idx)

        self._make_sure_dir_exists()

        # Checkpoint latest
        torch.save(model.state_dict(),
                   self.checkpoint_filename(epoch_info.global_epoch_idx))

        # Shallow copy so the additions below do not leak into the caller's dict.
        hidden_state = state_dict.copy()

        if epoch_info.optimizer is not None:
            hidden_state['optimizer'] = epoch_info.optimizer.state_dict()

        # Callbacks and then the strategy are given the dict to write into
        # (presumably each adds its own keys — confirm in their implementations).
        for callback in epoch_info.callbacks:
            callback.write_state_dict(hidden_state)

        self.checkpoint_strategy.write_state_dict(hidden_state)

        torch.save(
            hidden_state,
            self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

        # Only epochs after the first have a predecessor to delete
        # (epoch indices appear to start at 1 — TODO confirm).
        if epoch_info.global_epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(
                epoch_info.global_epoch_idx):
            prev_epoch_idx = epoch_info.global_epoch_idx - 1

            os.remove(self.checkpoint_filename(prev_epoch_idx))
            os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

        if self.checkpoint_strategy.should_store_best_checkpoint(
                epoch_info.global_epoch_idx, epoch_info.result):
            best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

            # Keep a single "best" file: drop the previous best before writing.
            if best_checkpoint_idx is not None:
                os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

            torch.save(
                model.state_dict(),
                self.checkpoint_best_filename(epoch_info.global_epoch_idx))

            self.checkpoint_strategy.store_best_checkpoint_idx(
                epoch_info.global_epoch_idx)

        self.backend.store(epoch_info.result)
コード例 #5
0
 def instantiate(self, model: Model) -> torch.optim.Adadelta:
     """Create a ``torch.optim.Adadelta`` optimizer over the parameters of
     ``model`` that have ``requires_grad`` set, configured from this object's
     ``lr``, ``rho``, ``eps`` and ``weight_decay`` attributes."""
     trainable = (param for param in model.parameters() if param.requires_grad)

     return torch.optim.Adadelta(
         trainable,
         lr=self.lr,
         rho=self.rho,
         eps=self.eps,
         weight_decay=self.weight_decay,
     )
コード例 #6
0
 def instantiate(self, model: Model) -> torch.optim.RMSprop:
     """Create a ``torch.optim.RMSprop`` optimizer over the parameters of
     ``model`` that have ``requires_grad`` set, configured from this object's
     ``lr``, ``alpha``, ``eps``, ``weight_decay``, ``momentum`` and
     ``centered`` attributes."""
     trainable = (param for param in model.parameters() if param.requires_grad)

     return torch.optim.RMSprop(
         trainable,
         lr=self.lr,
         alpha=self.alpha,
         eps=self.eps,
         weight_decay=self.weight_decay,
         momentum=self.momentum,
         centered=self.centered,
     )
コード例 #7
0
    def __init__(self, device: torch.device, settings: OnPolicyIterationReinforcerSettings, model: Model,
                 algo: AlgoBase, env_roller: EnvRollerBase) -> None:
        """Store the collaborators of an on-policy iteration reinforcer.

        ``model`` is moved with ``model.to(self.device)`` and the value that
        call returns is kept as the model being trained.
        """
        self.device = device
        self.settings = settings
        self.algo = algo
        self.env_roller = env_roller

        self._trained_model = model.to(self.device)
    def __init__(self, device: torch.device, settings: BufferedSingleOffPolicyIterationReinforcerSettings,
                 environment: gym.Env, model: Model, algo: AlgoBase, env_roller: ReplayEnvRollerBase):
        """Store the collaborators of a buffered single off-policy iteration
        reinforcer.

        ``model`` is moved with ``model.to(self.device)`` and the value that
        call returns is kept as the model being trained.
        """
        self.device = device
        self.settings = settings
        self.environment = environment
        self.env_roller = env_roller
        self.algo = algo

        self._trained_model = model.to(self.device)
コード例 #9
0
    def __init__(self, device: torch.device,
                 settings: BufferedMixedPolicyIterationReinforcerSettings,
                 env: VecEnv, model: Model, env_roller: ReplayEnvRollerBase,
                 algo: AlgoBase) -> None:
        """Store the collaborators of a buffered mixed-policy iteration
        reinforcer.

        The vectorized ``env`` is kept as ``self.environment``; ``model`` is
        moved with ``model.to(self.device)`` and the value that call returns is
        kept as the model being trained.
        """
        self.device = device
        self.settings = settings
        self.environment = env
        self.algo = algo
        self.env_roller = env_roller

        self._trained_model = model.to(self.device)
コード例 #10
0
    def checkpoint(self, epoch_info: EpochInfo, model: Model):
        """ When epoch is done, we persist the training state

        Saves the model weights and the hidden state returned by
        ``epoch_info.state_dict()`` (extended by the checkpoint strategy) for
        ``epoch_info.global_epoch_idx``. When the strategy asks for it, this also
        deletes the previous epoch's two checkpoint files and/or stores a
        separate "best" copy of the weights. Finally, ``epoch_info.result`` is
        handed to ``self.backend.store``.

        :param epoch_info: the finished epoch; ``global_epoch_idx``, ``result``
            and ``state_dict()`` are used here
        :param model: model whose ``state_dict()`` is written to disk
        """
        # What clean() removes is defined outside this view; it is given the
        # index of the epoch before the current one.
        self.clean(epoch_info.global_epoch_idx - 1)

        self._make_sure_dir_exists()

        # Checkpoint latest
        torch.save(model.state_dict(),
                   self.checkpoint_filename(epoch_info.global_epoch_idx))

        hidden_state = epoch_info.state_dict()
        # The strategy is given the dict to write into (presumably adding its
        # own keys — confirm in the strategy implementation).
        self.checkpoint_strategy.write_state_dict(hidden_state)

        torch.save(
            hidden_state,
            self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

        # Only epochs after the first have a predecessor to delete
        # (epoch indices appear to start at 1 — TODO confirm).
        if epoch_info.global_epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(
                epoch_info.global_epoch_idx):
            prev_epoch_idx = epoch_info.global_epoch_idx - 1

            os.remove(self.checkpoint_filename(prev_epoch_idx))
            os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

        if self.checkpoint_strategy.should_store_best_checkpoint(
                epoch_info.global_epoch_idx, epoch_info.result):
            best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

            # Keep a single "best" file: drop the previous best before writing.
            if best_checkpoint_idx is not None:
                os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

            torch.save(
                model.state_dict(),
                self.checkpoint_best_filename(epoch_info.global_epoch_idx))

            self.checkpoint_strategy.store_best_checkpoint_idx(
                epoch_info.global_epoch_idx)

        self.backend.store(epoch_info.result)