def instantiate(self, model: Model) -> torch.optim.Adam:
    if self.layer_groups:
        parameters = mu.to_parameter_groups(model.get_layer_groups())

        if isinstance(self.lr, collections.abc.Sequence):
            # Assign a separate learning rate to each parameter group
            for idx, lr in enumerate(self.lr):
                parameters[idx]['lr'] = lr

            default_lr = self.lr[0]
        else:
            default_lr = float(self.lr)

        if isinstance(self.weight_decay, collections.abc.Sequence):
            # Assign a separate weight decay to each parameter group
            for idx, weight_decay in enumerate(self.weight_decay):
                parameters[idx]['weight_decay'] = weight_decay

            default_weight_decay = self.weight_decay[0]
        else:
            default_weight_decay = float(self.weight_decay)

        return torch.optim.Adam(
            parameters, lr=default_lr, betas=self.betas, eps=self.eps,
            weight_decay=default_weight_decay, amsgrad=self.amsgrad
        )
    else:
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        return torch.optim.Adam(
            parameters, lr=self.lr, betas=self.betas, eps=self.eps,
            weight_decay=self.weight_decay, amsgrad=self.amsgrad
        )
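# A minimal, runnable sketch of how the per-group learning rates set above are consumed by the
# optimizer. The parameter-group dict structure is standard torch.optim behaviour; the two-group
# split below is a hypothetical stand-in for what mu.to_parameter_groups(model.get_layer_groups())
# returns.
import torch

body = torch.nn.Linear(4, 8)
head = torch.nn.Linear(8, 2)

parameters = [
    {'params': list(body.parameters())},
    {'params': list(head.parameters())},
]

# Mirror the loop above: each group gets its own 'lr' key, earlier layers learning more slowly
for idx, lr in enumerate([1e-4, 1e-3]):
    parameters[idx]['lr'] = lr

optimizer = torch.optim.Adam(parameters, lr=1e-4)  # lr here is only the default
print([group['lr'] for group in optimizer.param_groups])  # [0.0001, 0.001]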
def checkpoint(self, epoch_info: EpochInfo, model: Model):
    """ When an epoch is done, persist the training state """
    self.clean(epoch_info.global_epoch_idx - 1)

    self._make_sure_dir_exists()

    # Checkpoint the latest model state
    torch.save(model.state_dict(), self.checkpoint_filename(epoch_info.global_epoch_idx))

    hidden_state = epoch_info.state_dict()
    self.checkpoint_strategy.write_state_dict(hidden_state)

    torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_info.global_epoch_idx))

    if (epoch_info.global_epoch_idx > 1
            and self.checkpoint_strategy.should_delete_previous_checkpoint(epoch_info.global_epoch_idx)):
        prev_epoch_idx = epoch_info.global_epoch_idx - 1

        os.remove(self.checkpoint_filename(prev_epoch_idx))
        os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

    if self.checkpoint_strategy.should_store_best_checkpoint(epoch_info.global_epoch_idx, epoch_info.result):
        best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

        if best_checkpoint_idx is not None:
            os.remove(self.checkpoint_best_filename(best_checkpoint_idx))

        torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_info.global_epoch_idx))
        self.checkpoint_strategy.store_best_checkpoint_idx(epoch_info.global_epoch_idx)

    self.backend.store(epoch_info.result)
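# A minimal sketch of a strategy object satisfying the interface used by checkpoint() above.
# The method and attribute names come from the calls in that function; the "track the lowest
# validation loss" policy and the 'val:loss' metric key are assumptions for illustration only.
class KeepBestCheckpointStrategy:
    def __init__(self):
        self.current_best_checkpoint_idx = None
        self._best_value = None

    def write_state_dict(self, hidden_state: dict):
        # Persist strategy state inside the hidden-state checkpoint so it survives a resume
        hidden_state['checkpoint_strategy'] = {
            'best_idx': self.current_best_checkpoint_idx,
            'best_value': self._best_value,
        }

    def should_delete_previous_checkpoint(self, epoch_idx: int) -> bool:
        return True  # keep only the latest regular checkpoint

    def should_store_best_checkpoint(self, epoch_idx: int, result: dict) -> bool:
        value = result.get('val:loss')
        if value is not None and (self._best_value is None or value < self._best_value):
            self._best_value = value  # remember the new best metric value
            return True
        return False

    def store_best_checkpoint_idx(self, epoch_idx: int):
        self.current_best_checkpoint_idx = epoch_idx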
def rollout(self, batch_info: BatchInfo, model: Model, number_of_steps: int) -> Rollout:
    """ Calculate env rollout """
    accumulator = TensorAccumulator()
    episode_information = []  # List of dictionaries with episode information

    if self.hidden_state is None and model.is_recurrent:
        self.hidden_state = model.zero_state(self.last_observation.size(0)).to(self.device)

    # Remember the rollout initial state, we'll use that for training as well
    initial_hidden_state = self.hidden_state

    for step_idx in range(number_of_steps):
        if model.is_recurrent:
            step = model.step(self.last_observation.to(self.device), state=self.hidden_state)
            self.hidden_state = step['state']
        else:
            step = model.step(self.last_observation.to(self.device))

        # Add step to the tensor accumulator
        for name, tensor in step.items():
            accumulator.add(name, tensor.cpu())

        accumulator.add('observations', self.last_observation)

        actions_numpy = step['actions'].detach().cpu().numpy()
        new_obs, new_rewards, new_dones, new_infos = self.environment.step(actions_numpy)

        # Done is flagged true when the episode has ended AND the frame we see is already a first frame of the
        # next episode
        dones_tensor = torch.from_numpy(new_dones.astype(np.float32)).clone()
        accumulator.add('dones', dones_tensor)

        self.last_observation = torch.from_numpy(new_obs).clone()

        if model.is_recurrent:
            # Zero out hidden state in environments that have finished
            self.hidden_state = self.hidden_state * (1.0 - dones_tensor.unsqueeze(-1)).to(self.device)

        accumulator.add('rewards', torch.from_numpy(new_rewards.astype(np.float32)).clone())

        episode_information.append(new_infos)

    if model.is_recurrent:
        final_values = model.value(self.last_observation.to(self.device), state=self.hidden_state).cpu()
    else:
        final_values = model.value(self.last_observation.to(self.device)).cpu()

    accumulated_tensors = accumulator.result()

    return Trajectories(
        num_steps=accumulated_tensors['observations'].size(0),
        num_envs=accumulated_tensors['observations'].size(1),
        environment_information=episode_information,
        transition_tensors=accumulated_tensors,
        rollout_tensors={
            'initial_hidden_state': initial_hidden_state.cpu() if initial_hidden_state is not None else None,
            'final_values': final_values
        }
    )
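# A minimal sketch of the TensorAccumulator contract assumed by rollout() above: add() collects
# one tensor per step under a name, and result() stacks each series along a new leading time
# dimension, yielding (num_steps, num_envs, ...) tensors. The real class may differ in details.
import torch

class TensorAccumulator:
    def __init__(self):
        self._buffers = {}

    def add(self, name: str, tensor: torch.Tensor):
        # Append one timestep worth of data under the given key
        self._buffers.setdefault(name, []).append(tensor)

    def result(self) -> dict:
        # Stack along dim 0 so the index order is (step, env, ...)
        return {name: torch.stack(tensors, dim=0) for name, tensors in self._buffers.items()}

acc = TensorAccumulator()
for _ in range(5):                      # 5 steps
    acc.add('rewards', torch.zeros(8))  # 8 parallel environments
print(acc.result()['rewards'].shape)    # torch.Size([5, 8])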
def instantiate(self, model: Model) -> torch.optim.SGD:
    if self.layer_groups:
        parameters = mu.to_parameter_groups(model.get_layer_groups())
    else:
        parameters = filter(lambda p: p.requires_grad, model.parameters())

    return torch.optim.SGD(
        parameters,
        lr=self.lr,
        momentum=self.momentum,
        dampening=self.dampening,
        weight_decay=self.weight_decay,
        nesterov=self.nesterov
    )
def instantiate(self, model: Model) -> torch.optim.Adadelta:
    return torch.optim.Adadelta(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=self.lr, rho=self.rho, eps=self.eps, weight_decay=self.weight_decay
    )
def instantiate(self, model: Model) -> RMSpropTF:
    return RMSpropTF(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=self.lr, alpha=self.alpha, eps=self.eps,
        weight_decay=self.weight_decay, momentum=self.momentum, centered=self.centered
    )
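# A short, runnable illustration of the requires_grad filter shared by the factories above:
# parameters of frozen layers are excluded so the optimizer never allocates state or applies
# updates for them. The two-layer model is a hypothetical stand-in for an actual Model instance.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False  # freeze the first layer, e.g. a pretrained backbone

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1, momentum=0.9)
print(len(trainable), 'of', len(list(model.parameters())), 'tensors optimized')  # 2 of 4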
def __init__(self, device: torch.device, settings: BufferedOffPolicyIterationReinforcerSettings,
             environment: VecEnv, model: Model, algo: AlgoBase, env_roller: ReplayEnvRollerBase):
    self.device = device
    self.settings = settings

    self.environment = environment
    self._trained_model = model.to(self.device)

    self.algo = algo
    self.env_roller = env_roller
def __init__(self, device: torch.device, settings: OnPolicyIterationReinforcerSettings,
             model: Model, algo: AlgoBase, env_roller: EnvRollerBase) -> None:
    self.device = device
    self.settings = settings

    self._trained_model = model.to(self.device)
    self.env_roller = env_roller
    self.algo = algo
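# A small, runnable illustration of the model.to(self.device) pattern used in both constructors:
# for torch.nn.Module, .to() moves parameters in place and also returns the module itself, so
# keeping the returned reference (as _trained_model) is safe and idiomatic.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(4, 2)
trained_model = model.to(device)

print(trained_model is model)                                   # True: same module object
print(next(trained_model.parameters()).device.type == device.type)  # True: now on the device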