class SAC(Algorithm):
    """
    Soft Actor-Critic (SAC) variant with stochastic policy and two Q-functions and two Q-targets (no V-function)

    .. seealso::
        [1] T. Haarnoja, A. Zhou, P. Abbeel, S. Levine, "Soft Actor-Critic: Off-Policy Maximum Entropy Deep
            Reinforcement Learning with a Stochastic Actor", ICML, 2018
        [2] This implementation was inspired by https://github.com/pranz24/pytorch-soft-actor-critic
            which seems to be based on https://github.com/vitchyr/rlkit
    """

    name: str = 'sac'

    def __init__(self,
                 save_dir: str,
                 env: Env,
                 policy: TwoHeadedPolicy,
                 q_fcn_1: Policy,
                 q_fcn_2: Policy,
                 memory_size: int,
                 gamma: float,
                 max_iter: int,
                 num_batch_updates: int,
                 tau: float = 0.995,
                 alpha_init: float = 0.2,
                 learn_alpha: bool = True,
                 target_update_intvl: int = 1,
                 standardize_rew: bool = True,
                 batch_size: int = 500,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 num_sampler_envs: int = 4,
                 max_grad_norm: float = 5.,
                 lr: float = 3e-4,
                 lr_scheduler=None,
                 lr_scheduler_hparam: [dict, None] = None,
                 logger: StepLogger = None):
        """
        Constructor

        :param save_dir: directory to save the snapshots i.e. the results in
        :param env: the environment which the policy operates
        :param policy: policy to be updated
        :param q_fcn_1: state-action value function $Q(s,a)$, the associated target Q-function is created from a
                        re-initialized copy of this one
        :param q_fcn_2: state-action value function $Q(s,a)$, the associated target Q-function is created from a
                        re-initialized copy of this one
        :param memory_size: number of transitions in the replay memory buffer, e.g. 1000000
        :param gamma: temporal discount factor for the state values
        :param max_iter: number of iterations (policy updates)
        :param num_batch_updates: number of batch updates per algorithm step
        :param tau: interpolation factor in averaging for target networks, used for the soft update a.k.a. polyak
                    update, between 0 and 1
        :param alpha_init: initial weighting factor of the entropy term in the loss function
        :param learn_alpha: adapt the weighting factor of the entropy term
        :param target_update_intvl: number of iterations that pass before updating the target network
        :param standardize_rew: bool to flag if the rewards should be standardized
        :param batch_size: number of samples per policy update batch
        :param min_rollouts: minimum number of rollouts sampled per policy update batch
        :param min_steps: minimum number of state transitions sampled per policy update batch
        :param num_sampler_envs: number of environments for parallel sampling
        :param max_grad_norm: maximum L2 norm of the gradients for clipping, set to `None` to disable gradient clipping
        :param lr: (initial) learning rate for the optimizer which can be modified by the scheduler.
                   By default, the learning rate is constant.
        :param lr_scheduler: learning rate scheduler type for the policy and the Q-functions that does one step
                             per `update()` call
        :param lr_scheduler_hparam: hyper-parameters for the learning rate scheduler
        :param logger: logger for every step of the algorithm, if `None` the default logger will be created
        """
        if not isinstance(env, Env):
            raise pyrado.TypeErr(given=env, expected_type=Env)
        if typed_env(env, ActNormWrapper) is None:
            raise pyrado.TypeErr(msg='SAC required an environment wrapped by an ActNormWrapper!')
        if not isinstance(q_fcn_1, Policy):
            raise pyrado.TypeErr(given=q_fcn_1, expected_type=Policy)
        if not isinstance(q_fcn_2, Policy):
            raise pyrado.TypeErr(given=q_fcn_2, expected_type=Policy)

        if logger is None:
            # Create logger that only logs every 100 steps of the algorithm
            logger = StepLogger(print_interval=100)
            logger.printers.append(ConsolePrinter())
            logger.printers.append(CSVPrinter(osp.join(save_dir, 'progress.csv')))

        # Call Algorithm's constructor
        super().__init__(save_dir, max_iter, policy, logger)

        # Store the inputs
        self._env = env
        self.q_fcn_1 = q_fcn_1
        self.q_fcn_2 = q_fcn_2
        self.q_targ_1 = deepcopy(self.q_fcn_1)
        self.q_targ_2 = deepcopy(self.q_fcn_2)
        # The target networks are only updated via soft updates, never by an optimizer
        self.q_targ_1.eval()
        self.q_targ_2.eval()
        self.gamma = gamma
        self.tau = tau
        self.learn_alpha = learn_alpha
        self.target_update_intvl = target_update_intvl
        self.standardize_rew = standardize_rew
        self.num_batch_updates = num_batch_updates
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm

        # Initialize
        self._memory = ReplayMemory(memory_size)
        if policy.is_recurrent:
            init_expl_policy = RecurrentDummyPolicy(env.spec, policy.hidden_size)
        else:
            init_expl_policy = DummyPolicy(env.spec)
        self.sampler_init = ParallelSampler(
            env,
            init_expl_policy,  # samples uniformly random from the action space
            num_envs=num_sampler_envs,
            min_steps=memory_size,
        )
        self._expl_strat = SACExplStrat(self._policy, std_init=1.)  # std_init will be overwritten by 2nd policy head
        self.sampler = ParallelSampler(
            env,
            self._expl_strat,
            num_envs=1,
            min_steps=min_steps,  # in [2] this would be 1
            min_rollouts=min_rollouts  # in [2] this would be None
        )
        self.sampler_eval = ParallelSampler(
            env,
            self._policy,
            num_envs=num_sampler_envs,
            min_steps=100 * env.max_steps,
            min_rollouts=None
        )
        self._optim_policy = to.optim.Adam([{'params': self._policy.parameters()}], lr=lr)
        self._optim_q_fcn_1 = to.optim.Adam([{'params': self.q_fcn_1.parameters()}], lr=lr)
        self._optim_q_fcn_2 = to.optim.Adam([{'params': self.q_fcn_2.parameters()}], lr=lr)

        log_alpha_init = to.log(to.tensor(alpha_init, dtype=to.get_default_dtype()))
        if learn_alpha:
            # Automatic entropy tuning
            self._log_alpha = nn.Parameter(log_alpha_init, requires_grad=True)
            self._alpha_optim = to.optim.Adam([{'params': self._log_alpha}], lr=lr)
            self.target_entropy = -to.prod(to.tensor(env.act_space.shape))
        else:
            self._log_alpha = log_alpha_init

        # Initialize all scheduler attributes unconditionally, otherwise reset() raises an AttributeError on
        # self._lr_scheduler_q_fcn_1/2 when no lr_scheduler was passed (they were previously only set in the if-branch)
        self._lr_scheduler_policy = None
        self._lr_scheduler_q_fcn_1 = None
        self._lr_scheduler_q_fcn_2 = None
        self._lr_scheduler_hparam = lr_scheduler_hparam
        if lr_scheduler is not None:
            self._lr_scheduler_policy = lr_scheduler(self._optim_policy, **lr_scheduler_hparam)
            self._lr_scheduler_q_fcn_1 = lr_scheduler(self._optim_q_fcn_1, **lr_scheduler_hparam)
            self._lr_scheduler_q_fcn_2 = lr_scheduler(self._optim_q_fcn_2, **lr_scheduler_hparam)

    @property
    def expl_strat(self) -> SACExplStrat:
        """ Get the exploration strategy wrapping the policy. """
        return self._expl_strat

    @property
    def memory(self) -> ReplayMemory:
        """ Get the replay memory. """
        return self._memory

    @property
    def alpha(self) -> to.Tensor:
        """ Get the detached entropy coefficient. """
        return to.exp(self._log_alpha.detach())

    def step(self, snapshot_mode: str, meta_info: dict = None):
        if self._memory.isempty:
            # Warm-up phase: fill the replay memory with uniformly random actions
            print_cbt('Collecting samples until the replay memory is full.', 'w')
            # Sample steps and store them in the replay memory
            ros = self.sampler_init.sample()
            self._memory.push(ros)
        else:
            # Sample steps and store them in the replay memory
            ros = self.sampler.sample()
            self._memory.push(ros)

        # Log return-based metrics
        if self._curr_iter % self.logger.print_interval == 0:
            ros = self.sampler_eval.sample()
            rets = [ro.undiscounted_return() for ro in ros]
            ret_max = np.max(rets)
            ret_med = np.median(rets)
            ret_avg = np.mean(rets)
            ret_min = np.min(rets)
            ret_std = np.std(rets)
        else:
            ret_max, ret_med, ret_avg, ret_min, ret_std = 5 * [-pyrado.inf]  # dummy values
        self.logger.add_value('max return', np.round(ret_max, 4))
        self.logger.add_value('median return', np.round(ret_med, 4))
        self.logger.add_value('avg return', np.round(ret_avg, 4))
        self.logger.add_value('min return', np.round(ret_min, 4))
        self.logger.add_value('std return', np.round(ret_std, 4))
        self.logger.add_value('avg rollout length', np.round(np.mean([ro.length for ro in ros]), 2))
        self.logger.add_value('num rollouts', len(ros))
        self.logger.add_value('avg memory reward', np.round(self._memory.avg_reward(), 4))

        # Use data in the memory to update the policy and the Q-functions
        self.update()

        # Save snapshot data
        self.make_snapshot(snapshot_mode, float(ret_avg), meta_info)

    @staticmethod
    def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.995):
        """
        Moving average update, a.k.a. Polyak update.
        Modifies the input argument `target`.

        :param target: PyTorch module with parameters to be updated
        :param source: PyTorch module with parameters to update to
        :param tau: interpolation factor for averaging, between 0 and 1; the closer to 1, the slower the target moves
        """
        if not 0 < tau < 1:
            raise pyrado.ValueErr(given=tau, g_constraint='0', l_constraint='1')

        for targ_param, src_param in zip(target.parameters(), source.parameters()):
            targ_param.data = targ_param.data * tau + src_param.data * (1. - tau)

    def update(self):
        """ Update the policy's and Q-functions' parameters on transitions sampled from the replay memory. """
        # Containers for logging
        policy_losses = to.zeros(self.num_batch_updates)
        expl_strat_stds = to.zeros(self.num_batch_updates)
        q_fcn_1_losses = to.zeros(self.num_batch_updates)
        q_fcn_2_losses = to.zeros(self.num_batch_updates)
        policy_grad_norm = to.zeros(self.num_batch_updates)
        q_fcn_1_grad_norm = to.zeros(self.num_batch_updates)
        q_fcn_2_grad_norm = to.zeros(self.num_batch_updates)

        for b in tqdm(range(self.num_batch_updates), total=self.num_batch_updates, desc='Updating',
                      unit='batches', file=sys.stdout, leave=False):
            # Sample steps and the associated next step from the replay memory
            steps, next_steps = self._memory.sample(self.batch_size)
            steps.torch(data_type=to.get_default_dtype())
            next_steps.torch(data_type=to.get_default_dtype())

            # Standardize rewards
            if self.standardize_rew:
                rewards = standardize(steps.rewards).unsqueeze(1)
            else:
                rewards = steps.rewards.unsqueeze(1)
            rew_scale = 1.  # manual reward scaling knob, disabled by default
            rewards *= rew_scale

            with to.no_grad():
                # Create masks for the non-final observations
                not_done = to.tensor(1. - steps.done, dtype=to.get_default_dtype()).unsqueeze(1)

                # Compute the (next)state-(next)action values Q(s',a') from the target networks
                if self.policy.is_recurrent:
                    next_act_expl, next_log_probs, _ = self._expl_strat(next_steps.observations,
                                                                        next_steps.hidden_states)
                else:
                    next_act_expl, next_log_probs = self._expl_strat(next_steps.observations)
                next_q_val_target_1 = self.q_targ_1(to.cat([next_steps.observations, next_act_expl], dim=1))
                next_q_val_target_2 = self.q_targ_2(to.cat([next_steps.observations, next_act_expl], dim=1))
                next_q_val_target_min = to.min(next_q_val_target_1,
                                               next_q_val_target_2) - self.alpha * next_log_probs
                next_q_val = rewards + not_done * self.gamma * next_q_val_target_min

            # Compute the two Q-function losses
            # E_{(s_t, a_t) ~ D} [1/2 * (Q_i(s_t, a_t) - r_t - gamma * E_{s_{t+1} ~ p} [V(s_{t+1})] )^2]
            q_val_1 = self.q_fcn_1(to.cat([steps.observations, steps.actions], dim=1))
            q_val_2 = self.q_fcn_2(to.cat([steps.observations, steps.actions], dim=1))
            q_1_loss = nn.functional.mse_loss(q_val_1, next_q_val)
            q_2_loss = nn.functional.mse_loss(q_val_2, next_q_val)
            q_fcn_1_losses[b] = q_1_loss.data
            q_fcn_2_losses[b] = q_2_loss.data

            # Compute the policy loss
            # E_{s_t ~ D, eps_t ~ N} [log( pi( f(eps_t; s_t) ) ) - Q(s_t, f(eps_t; s_t))]
            if self.policy.is_recurrent:
                act_expl, log_probs, _ = self._expl_strat(steps.observations, steps.hidden_states)
            else:
                act_expl, log_probs = self._expl_strat(steps.observations)
            q1_pi = self.q_fcn_1(to.cat([steps.observations, act_expl], dim=1))
            q2_pi = self.q_fcn_2(to.cat([steps.observations, act_expl], dim=1))
            min_q_pi = to.min(q1_pi, q2_pi)
            policy_loss = to.mean(self.alpha * log_probs - min_q_pi)
            policy_losses[b] = policy_loss.data
            expl_strat_stds[b] = to.mean(self._expl_strat.std.data)

            # Do one optimization step for each optimizer, and clip the gradients if desired
            # Q-fcn 1
            self._optim_q_fcn_1.zero_grad()
            q_1_loss.backward()
            q_fcn_1_grad_norm[b] = self.clip_grad(self.q_fcn_1, None)  # only record the norm, do not clip
            self._optim_q_fcn_1.step()
            # Q-fcn 2
            self._optim_q_fcn_2.zero_grad()
            q_2_loss.backward()
            q_fcn_2_grad_norm[b] = self.clip_grad(self.q_fcn_2, None)  # only record the norm, do not clip
            self._optim_q_fcn_2.step()
            # Policy
            self._optim_policy.zero_grad()
            policy_loss.backward()
            policy_grad_norm[b] = self.clip_grad(self._expl_strat.policy, self.max_grad_norm)
            self._optim_policy.step()

            if self.learn_alpha:
                # Compute entropy coefficient loss
                alpha_loss = -to.mean(self._log_alpha * (log_probs.detach() + self.target_entropy))
                # Do one optimizer step for the entropy coefficient optimizer
                self._alpha_optim.zero_grad()
                alpha_loss.backward()
                self._alpha_optim.step()

            # Soft-update the target networks
            if (self._curr_iter * self.num_batch_updates + b) % self.target_update_intvl == 0:
                SAC.soft_update(self.q_targ_1, self.q_fcn_1, self.tau)
                SAC.soft_update(self.q_targ_2, self.q_fcn_2, self.tau)

        # Update the learning rate if the schedulers have been specified
        if self._lr_scheduler_policy is not None:
            self._lr_scheduler_policy.step()
            self._lr_scheduler_q_fcn_1.step()
            self._lr_scheduler_q_fcn_2.step()

        # Logging
        self.logger.add_value('Q1 loss', to.mean(q_fcn_1_losses).item())
        self.logger.add_value('Q2 loss', to.mean(q_fcn_2_losses).item())
        self.logger.add_value('policy loss', to.mean(policy_losses).item())
        self.logger.add_value('avg policy grad norm', to.mean(policy_grad_norm).item())
        self.logger.add_value('avg expl strat std', to.mean(expl_strat_stds).item())
        self.logger.add_value('alpha', self.alpha.item())
        if self._lr_scheduler_policy is not None:
            self.logger.add_value('learning rate', self._lr_scheduler_policy.get_lr())

    def save_snapshot(self, meta_info: dict = None):
        super().save_snapshot(meta_info)

        if meta_info is None:
            # This instance is not a subroutine of a meta-algorithm
            joblib.dump(self._env, osp.join(self._save_dir, 'env.pkl'))
            to.save(self.q_targ_1, osp.join(self._save_dir, 'target1.pt'))
            to.save(self.q_targ_2, osp.join(self._save_dir, 'target2.pt'))
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                to.save(self.q_targ_1,
                        osp.join(self._save_dir, f"{meta_info['prefix']}_target1_{meta_info['suffix']}.pt"))
                to.save(self.q_targ_2,
                        osp.join(self._save_dir, f"{meta_info['prefix']}_target2_{meta_info['suffix']}.pt"))
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                to.save(self.q_targ_1, osp.join(self._save_dir, f"{meta_info['prefix']}_target1.pt"))
                to.save(self.q_targ_2, osp.join(self._save_dir, f"{meta_info['prefix']}_target2.pt"))
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                to.save(self.q_targ_1, osp.join(self._save_dir, f"target1_{meta_info['suffix']}.pt"))
                to.save(self.q_targ_2, osp.join(self._save_dir, f"target2_{meta_info['suffix']}.pt"))
            else:
                raise NotImplementedError

    def load_snapshot(self, load_dir: str = None, meta_info: dict = None):
        # Get the directory to load from
        ld = load_dir if load_dir is not None else self._save_dir
        super().load_snapshot(ld, meta_info)

        if meta_info is None:
            # This algorithm instance is not a subroutine of a meta-algorithm
            self._env = joblib.load(osp.join(ld, 'env.pkl'))
            self.q_targ_1.load_state_dict(to.load(osp.join(ld, 'target1.pt')).state_dict())
            self.q_targ_2.load_state_dict(to.load(osp.join(ld, 'target2.pt')).state_dict())
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                self.q_targ_1.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target1_{meta_info['suffix']}.pt")).state_dict())
                self.q_targ_2.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target2_{meta_info['suffix']}.pt")).state_dict())
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                self.q_targ_1.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target1.pt")).state_dict())
                self.q_targ_2.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target2.pt")).state_dict())
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                self.q_targ_1.load_state_dict(
                    to.load(osp.join(ld, f"target1_{meta_info['suffix']}.pt")).state_dict())
                self.q_targ_2.load_state_dict(
                    to.load(osp.join(ld, f"target2_{meta_info['suffix']}.pt")).state_dict())
            else:
                raise NotImplementedError

    def reset(self, seed: int = None):
        # Reset the exploration strategy, internal variables and the random seeds
        super().reset(seed)

        # Re-initialize sampler in case env or policy changed
        self.sampler.reinit()

        # Reset the replay memory
        self._memory.reset()

        # Reset the learning rate schedulers (all three attributes are guaranteed to exist, see __init__)
        if self._lr_scheduler_policy is not None:
            self._lr_scheduler_policy.last_epoch = -1
        if self._lr_scheduler_q_fcn_1 is not None:
            self._lr_scheduler_q_fcn_1.last_epoch = -1
        if self._lr_scheduler_q_fcn_2 is not None:
            self._lr_scheduler_q_fcn_2.last_epoch = -1
class DQL(Algorithm):
    """
    Deep Q-Learning (without bells and whistles)

    .. seealso::
        [1] V. Mnih et.al., "Human-level control through deep reinforcement learning", Nature, 2015
    """

    name: str = 'dql'

    def __init__(self,
                 save_dir: str,
                 env: Env,
                 policy: DiscrActQValFNNPolicy,
                 memory_size: int,
                 eps_init: float,
                 eps_schedule_gamma: float,
                 gamma: float,
                 max_iter: int,
                 num_batch_updates: int,
                 target_update_intvl: int = 5,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 batch_size: int = 256,
                 num_sampler_envs: int = 4,
                 max_grad_norm: float = 0.5,
                 lr: float = 5e-4,
                 lr_scheduler=None,
                 lr_scheduler_hparam: [dict, None] = None,
                 logger: StepLogger = None):
        """
        Constructor

        :param save_dir: directory to save the snapshots i.e. the results in
        :param env: environment which the policy operates
        :param policy: (current) Q-network updated by this algorithm
        :param memory_size: number of transitions in the replay memory buffer
        :param eps_init: initial value for the probability of taking a random action, constant if
                         `eps_schedule_gamma==1`
        :param eps_schedule_gamma: temporal discount factor for the exponential decay of epsilon
        :param gamma: temporal discount factor for the state values
        :param max_iter: number of iterations (policy updates)
        :param num_batch_updates: number of batch updates per algorithm step
        :param target_update_intvl: number of iterations that pass before updating the target network
        :param min_rollouts: minimum number of rollouts sampled per policy update batch
        :param min_steps: minimum number of state transitions sampled per policy update batch
        :param batch_size: number of samples per policy update batch
        :param num_sampler_envs: number of environments for parallel sampling
        :param max_grad_norm: maximum L2 norm of the gradients for clipping, set to `None` to disable gradient clipping
        :param lr: (initial) learning rate for the optimizer which can be modified by the scheduler.
                   By default, the learning rate is constant.
        :param lr_scheduler: learning rate scheduler that does one step per epoch (pass through the whole data set)
        :param lr_scheduler_hparam: hyper-parameters for the learning rate scheduler
        :param logger: logger for every step of the algorithm
        """
        if not isinstance(env, Env):
            raise pyrado.TypeErr(given=env, expected_type=Env)
        if not isinstance(policy, DiscrActQValFNNPolicy):
            raise pyrado.TypeErr(given=policy, expected_type=DiscrActQValFNNPolicy)

        if logger is None:
            # Create logger that only logs every 100 steps of the algorithm
            logger = StepLogger(print_interval=100)
            logger.printers.append(ConsolePrinter())
            logger.printers.append(CSVPrinter(osp.join(save_dir, 'progress.csv')))

        # Call Algorithm's constructor
        super().__init__(save_dir, max_iter, policy, logger)

        # Store the inputs
        self._env = env
        self.target = deepcopy(self._policy)
        self.target.eval()  # will not be trained using the optimizer
        self._memory_size = memory_size
        self.eps = eps_init
        self.gamma = gamma
        self.target_update_intvl = target_update_intvl
        self.num_batch_updates = num_batch_updates
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm

        # Initialize
        self._expl_strat = EpsGreedyExplStrat(self._policy, eps_init, eps_schedule_gamma)
        self._memory = ReplayMemory(memory_size)
        self.sampler = ParallelSampler(
            env, self._expl_strat,
            num_envs=1,
            min_steps=min_steps,
            min_rollouts=min_rollouts
        )
        self.sampler_eval = ParallelSampler(
            env, self._policy,
            num_envs=num_sampler_envs,
            min_steps=100 * env.max_steps,
            min_rollouts=None
        )
        self.optim = to.optim.RMSprop([{'params': self._policy.parameters()}], lr=lr)
        self._lr_scheduler = lr_scheduler
        self._lr_scheduler_hparam = lr_scheduler_hparam
        if lr_scheduler is not None:
            self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam)

    @property
    def expl_strat(self) -> EpsGreedyExplStrat:
        """ Get the epsilon-greedy exploration strategy wrapping the policy. """
        return self._expl_strat

    @property
    def memory(self) -> ReplayMemory:
        """ Get the replay memory. """
        return self._memory

    def step(self, snapshot_mode: str, meta_info: dict = None):
        # Sample steps and store them in the replay memory
        ros = self.sampler.sample()
        self._memory.push(ros)

        while len(self._memory) < self.memory.capacity:
            # Warm-up phase: keep sampling until the replay memory is completely filled
            print_cbt('Collecting samples until the replay memory is full.', 'w')
            # Sample steps and store them in the replay memory
            ros = self.sampler.sample()
            self._memory.push(ros)

        # Log return-based metrics
        if self._curr_iter % self.logger.print_interval == 0:
            ros = self.sampler_eval.sample()
            rets = [ro.undiscounted_return() for ro in ros]
            ret_max = np.max(rets)
            ret_med = np.median(rets)
            ret_avg = np.mean(rets)
            ret_min = np.min(rets)
            ret_std = np.std(rets)
        else:
            ret_max, ret_med, ret_avg, ret_min, ret_std = 5 * [-pyrado.inf]  # dummy values
        self.logger.add_value('max return', np.round(ret_max, 4))
        self.logger.add_value('median return', np.round(ret_med, 4))
        self.logger.add_value('avg return', np.round(ret_avg, 4))
        self.logger.add_value('min return', np.round(ret_min, 4))
        self.logger.add_value('std return', np.round(ret_std, 4))
        self.logger.add_value('avg rollout length', np.round(np.mean([ro.length for ro in ros]), 2))
        self.logger.add_value('num rollouts', len(ros))
        self.logger.add_value('avg memory reward', np.round(self._memory.avg_reward(), 4))

        # Use data in the memory to update the policy and the target Q-function
        self.update()

        # Save snapshot data
        self.make_snapshot(snapshot_mode, float(ret_avg), meta_info)

    def loss_fcn(self, q_vals: to.Tensor, expected_q_vals: to.Tensor) -> to.Tensor:
        r"""
        The Huber loss function on the one-step TD error $\delta = Q(s,a) - (r + \gamma \max_a Q(s^\prime, a))$.

        :param q_vals: state-action values $Q(s,a)$, from policy network
        :param expected_q_vals: expected state-action values $r + \gamma \max_a Q(s^\prime, a)$, from target network
        :return: loss value
        """
        return nn.functional.smooth_l1_loss(q_vals, expected_q_vals)

    def update(self):
        """ Update the policy's and target Q-function's parameters on transitions sampled from the replay memory. """
        losses = to.zeros(self.num_batch_updates)
        policy_grad_norm = to.zeros(self.num_batch_updates)

        for b in tqdm(range(self.num_batch_updates), total=self.num_batch_updates, desc='Updating',
                      unit='batches', file=sys.stdout, leave=False):
            # Sample steps and the associated next step from the replay memory
            steps, next_steps = self._memory.sample(self.batch_size)
            steps.torch(data_type=to.get_default_dtype())
            next_steps.torch(data_type=to.get_default_dtype())

            # Create masks for the non-final observations
            not_done = to.tensor(1. - steps.done, dtype=to.get_default_dtype())

            # Compute the state-action values Q(s,a) using the current DQN policy
            q_vals = self.expl_strat.policy.q_values_chosen(steps.observations)

            # Compute the second term of TD-error
            next_v_vals = self.target.q_values_chosen(next_steps.observations).detach()
            expected_q_val = steps.rewards + not_done * self.gamma * next_v_vals

            # Compute the loss, clip the gradients if desired, and do one optimization step
            loss = self.loss_fcn(q_vals, expected_q_val)
            losses[b] = loss.data
            self.optim.zero_grad()
            loss.backward()
            policy_grad_norm[b] = self.clip_grad(self.expl_strat.policy, self.max_grad_norm)
            self.optim.step()

            # Update the target network by copying all weights and biases from the DQN policy
            if (self._curr_iter * self.num_batch_updates + b) % self.target_update_intvl == 0:
                self.target.load_state_dict(self.expl_strat.policy.state_dict())

        # Schedule the exploration parameter epsilon
        self.expl_strat.schedule_eps(self._curr_iter)

        # Update the learning rate if a scheduler has been specified
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

        # Logging
        with to.no_grad():
            self.logger.add_value('loss after', to.mean(losses).item())
        self.logger.add_value('expl strat eps', self.expl_strat.eps.item())
        self.logger.add_value('avg policy grad norm', to.mean(policy_grad_norm).item())
        if self._lr_scheduler is not None:
            self.logger.add_value('learning rate', self._lr_scheduler.get_lr())

    def save_snapshot(self, meta_info: dict = None):
        super().save_snapshot(meta_info)

        if meta_info is None:
            # This instance is not a subroutine of a meta-algorithm
            joblib.dump(self._env, osp.join(self._save_dir, 'env.pkl'))
            to.save(self.target, osp.join(self._save_dir, 'target.pt'))
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                to.save(self.target,
                        osp.join(self._save_dir, f"{meta_info['prefix']}_target_{meta_info['suffix']}.pt"))
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                to.save(self.target, osp.join(self._save_dir, f"{meta_info['prefix']}_target.pt"))
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                to.save(self.target, osp.join(self._save_dir, f"target_{meta_info['suffix']}.pt"))
            else:
                raise NotImplementedError

    def load_snapshot(self, load_dir: str = None, meta_info: dict = None):
        # Get the directory to load from
        ld = load_dir if load_dir is not None else self._save_dir
        super().load_snapshot(ld, meta_info)

        if meta_info is None:
            # This algorithm instance is not a subroutine of a meta-algorithm
            self._env = joblib.load(osp.join(ld, 'env.pkl'))
            self.target.load_state_dict(to.load(osp.join(ld, 'target.pt')).state_dict())
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                self.target.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target_{meta_info['suffix']}.pt")).state_dict()
                )
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                self.target.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target.pt")).state_dict()
                )
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                self.target.load_state_dict(
                    to.load(osp.join(ld, f"target_{meta_info['suffix']}.pt")).state_dict()
                )
            else:
                raise NotImplementedError

    def reset(self, seed: int = None):
        # Reset the exploration strategy, internal variables and the random seeds
        super().reset(seed)

        # Re-initialize sampler in case env or policy changed
        self.sampler.reinit()

        # Reset the replay memory
        self._memory.reset()

        # Reset the learning rate scheduler
        if self._lr_scheduler is not None:
            self._lr_scheduler.last_epoch = -1