def instantiate(self, model: Model) -> torch.optim.Adam:
    """ Create an Adam optimizer for the model, optionally with per-layer-group hyperparameters """
    if self.layer_groups:
        parameters = mu.to_parameter_groups(model.get_layer_groups())

        if isinstance(self.lr, collections.abc.Sequence):
            for idx, lr in enumerate(self.lr):
                parameters[idx]['lr'] = lr

            default_lr = self.lr[0]
        else:
            default_lr = self.lr

        if isinstance(self.weight_decay, collections.abc.Sequence):
            for idx, weight_decay in enumerate(self.weight_decay):
                parameters[idx]['weight_decay'] = weight_decay

            default_weight_decay = self.weight_decay[0]
        else:
            default_weight_decay = self.weight_decay

        return torch.optim.Adam(
            parameters,
            lr=default_lr, betas=self.betas, eps=self.eps,
            weight_decay=default_weight_decay, amsgrad=self.amsgrad
        )
    else:
        parameters = filter(lambda p: p.requires_grad, model.parameters())

        return torch.optim.Adam(
            parameters,
            lr=self.lr, betas=self.betas, eps=self.eps,
            weight_decay=self.weight_decay, amsgrad=self.amsgrad
        )
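# Illustrative, standalone sketch of the per-group behaviour the layer-group branch above
# relies on: torch.optim.Adam accepts a list of parameter-group dicts, and any key set on
# a group (e.g. 'lr', 'weight_decay') overrides the optimizer-wide default for that group.
# The two-layer model and the learning rates below are hypothetical, not from the source.
import torch
import torch.nn as nn

example_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

parameter_groups = [
    {'params': example_model[0].parameters(), 'lr': 1e-4},  # earlier layers train more slowly
    {'params': example_model[2].parameters(), 'lr': 1e-3},  # the head trains faster
]

example_optimizer = torch.optim.Adam(parameter_groups, lr=1e-3, weight_decay=0.0)  # defaults for unset keys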
def instantiate(self, model: Model) -> torch.optim.SGD:
    if self.layer_groups:
        parameters = mu.to_parameter_groups(model.get_layer_groups())
    else:
        parameters = filter(lambda p: p.requires_grad, model.parameters())

    return torch.optim.SGD(
        parameters,
        lr=self.lr, momentum=self.momentum, dampening=self.dampening,
        weight_decay=self.weight_decay, nesterov=self.nesterov
    )
def resume(self, train_info: TrainingInfo, model: Model) -> dict:
    """ Resume learning process and return loaded hidden state dictionary """
    last_epoch = train_info.start_epoch_idx

    model.load_state_dict(torch.load(self.checkpoint_filename(last_epoch)))
    hidden_state = torch.load(self.checkpoint_hidden_filename(last_epoch))

    self.checkpoint_strategy.restore(hidden_state)
    train_info.restore(hidden_state)

    return hidden_state
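# Hypothetical caller-side sketch (names such as `checkpointer` are illustrative, not from
# the source): resume() returns the hidden state written by checkpoint() below, so the
# training loop can rebuild e.g. the optimizer from the 'optimizer' key stored there.
hidden_state = checkpointer.resume(train_info, model)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
if 'optimizer' in hidden_state:
    optimizer.load_state_dict(hidden_state['optimizer'])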
def checkpoint(self, epoch_info: EpochInfo, model: Model, state_dict: dict = None):
    """ When an epoch is done, persist the training state """
    state_dict = state_dict if state_dict is not None else {}

    self.clean(epoch_info.global_epoch_idx)
    self._make_sure_dir_exists()

    # Checkpoint latest model weights
    torch.save(model.state_dict(), self.checkpoint_filename(epoch_info.global_epoch_idx))

    # Assemble the hidden state: optimizer, callbacks and the checkpoint strategy itself
    hidden_state = state_dict.copy()

    if epoch_info.optimizer is not None:
        hidden_state['optimizer'] = epoch_info.optimizer.state_dict()

    for callback in epoch_info.callbacks:
        callback.write_state_dict(hidden_state)

    self.checkpoint_strategy.write_state_dict(hidden_state)

    torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

    # Optionally drop the previous checkpoint
    if epoch_info.global_epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(
            epoch_info.global_epoch_idx):
        prev_epoch_idx = epoch_info.global_epoch_idx - 1

        os.remove(self.checkpoint_filename(prev_epoch_idx))
        os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

    # Optionally store a copy of the best checkpoint so far
    if self.checkpoint_strategy.should_store_best_checkpoint(epoch_info.global_epoch_idx, epoch_info.result):
        best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

        if best_checkpoint_idx is not None:
            os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

        torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_info.global_epoch_idx))

        self.checkpoint_strategy.store_best_checkpoint_idx(epoch_info.global_epoch_idx)

    self.backend.store(epoch_info.result)
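# A minimal, hypothetical checkpoint strategy compatible with the hooks called in
# checkpoint() and resume() above. The class name, the 'val_loss' metric key and the
# lower-is-better assumption are illustrative; the real strategy classes are not shown here.
class KeepLatestAndBestStrategy:
    def __init__(self, metric: str = 'val_loss'):
        self.metric = metric
        self.best_value = None
        self.current_best_checkpoint_idx = None

    def should_delete_previous_checkpoint(self, epoch_idx: int) -> bool:
        # Keep only the most recent regular checkpoint on disk
        return True

    def should_store_best_checkpoint(self, epoch_idx: int, result: dict) -> bool:
        value = result.get(self.metric)

        if value is None:
            return False

        if self.best_value is None or value < self.best_value:
            self.best_value = value  # lower-is-better metric assumed (e.g. a validation loss)
            return True

        return False

    def store_best_checkpoint_idx(self, epoch_idx: int):
        self.current_best_checkpoint_idx = epoch_idx

    def write_state_dict(self, hidden_state: dict):
        hidden_state['checkpoint_strategy/best_value'] = self.best_value
        hidden_state['checkpoint_strategy/best_idx'] = self.current_best_checkpoint_idx

    def restore(self, hidden_state: dict):
        self.best_value = hidden_state.get('checkpoint_strategy/best_value')
        self.current_best_checkpoint_idx = hidden_state.get('checkpoint_strategy/best_idx')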
def instantiate(self, model: Model) -> torch.optim.Adadelta:
    return torch.optim.Adadelta(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=self.lr, rho=self.rho, eps=self.eps, weight_decay=self.weight_decay
    )
def instantiate(self, model: Model) -> torch.optim.RMSprop:
    return torch.optim.RMSprop(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=self.lr, alpha=self.alpha, eps=self.eps, weight_decay=self.weight_decay,
        momentum=self.momentum, centered=self.centered
    )
def __init__(self, device: torch.device, settings: OnPolicyIterationReinforcerSettings,
             model: Model, algo: AlgoBase, env_roller: EnvRollerBase) -> None:
    self.device = device
    self.settings = settings

    self._trained_model = model.to(self.device)

    self.env_roller = env_roller
    self.algo = algo
def __init__(self, device: torch.device, settings: BufferedSingleOffPolicyIterationReinforcerSettings,
             environment: gym.Env, model: Model, algo: AlgoBase, env_roller: ReplayEnvRollerBase):
    self.device = device
    self.settings = settings

    self.environment = environment
    self._trained_model = model.to(self.device)

    self.algo = algo
    self.env_roller = env_roller
def __init__(self, device: torch.device, settings: BufferedMixedPolicyIterationReinforcerSettings,
             env: VecEnv, model: Model, env_roller: ReplayEnvRollerBase, algo: AlgoBase) -> None:
    self.device = device
    self.settings = settings
    self.environment = env

    self._trained_model = model.to(self.device)

    self.env_roller = env_roller
    self.algo = algo
def checkpoint(self, epoch_info: EpochInfo, model: Model):
    """ When an epoch is done, persist the training state """
    self.clean(epoch_info.global_epoch_idx - 1)
    self._make_sure_dir_exists()

    # Checkpoint latest model weights
    torch.save(model.state_dict(), self.checkpoint_filename(epoch_info.global_epoch_idx))

    # Hidden state comes from the epoch info, plus the checkpoint strategy state
    hidden_state = epoch_info.state_dict()
    self.checkpoint_strategy.write_state_dict(hidden_state)

    torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

    # Optionally drop the previous checkpoint
    if epoch_info.global_epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(
            epoch_info.global_epoch_idx):
        prev_epoch_idx = epoch_info.global_epoch_idx - 1

        os.remove(self.checkpoint_filename(prev_epoch_idx))
        os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

    # Optionally store a copy of the best checkpoint so far
    if self.checkpoint_strategy.should_store_best_checkpoint(epoch_info.global_epoch_idx, epoch_info.result):
        best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

        if best_checkpoint_idx is not None:
            os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

        torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_info.global_epoch_idx))

        self.checkpoint_strategy.store_best_checkpoint_idx(epoch_info.global_epoch_idx)

    self.backend.store(epoch_info.result)