class TAC(SarlOffPolicy): """Tsallis Actor Critic, TAC with V neural Network. https://arxiv.org/abs/1902.00137 """ policy_mode = 'off-policy' def __init__(self, alpha=0.2, annealing=True, last_alpha=0.01, polyak=0.995, entropic_index=1.5, discrete_tau=1.0, network_settings={ 'actor_continuous': { 'share': [128, 128], 'mu': [64], 'log_std': [64], 'soft_clip': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [64, 32], 'q': [128, 128] }, auto_adaption=True, actor_lr=5.0e-4, critic_lr=1.0e-3, alpha_lr=5.0e-4, **kwargs): super().__init__(**kwargs) self.polyak = polyak self.discrete_tau = discrete_tau self.entropic_index = 2 - entropic_index self.auto_adaption = auto_adaption self.annealing = annealing self.critic = TargetTwin(CriticQvalueOne(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']), self.polyak).to(self.device) self.critic2 = deepcopy(self.critic) if self.is_continuous: self.actor = ActorCts(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to(self.device) else: self.actor = ActorDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) # entropy = -log(1/|A|) = log |A| self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else np.log(self.a_dim)) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) if self.annealing: self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6)) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic2=self.critic2, log_alpha=self.log_alpha, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor(obs, rnncs=self.rnncs) # [B, A] pi = td.Normal(mu, log_std.exp()).sample().tanh() # [B, A] mu.tanh_() # squash mu # [B, A] else: logits = self.actor(obs, rnncs=self.rnncs) # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] self.rnncs_ = self.actor.get_rnncs() actions = pi if self._is_train_mode else mu return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1) target_pi = dist.sample() # [T, B, A] target_pi, target_log_pi = squash_action(target_pi, dist.log_prob( target_pi).unsqueeze(-1), is_independent=False) # [T, B, A] target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index) # [T, B, 1] else: target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1) # [T, B, 1] target_pi = F.one_hot(target_pi, self.a_dim).float() # [T, B, A] q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q2 = 
self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask) # [T, B, 1] q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask) # [T, B, 1] q_target = th.minimum(q1_target, q2_target) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, (q_target - self.alpha * target_log_pi), BATCH.begin_mask).detach() # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * q1_loss + 0.5 * q2_loss self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False) # [T, B, A] log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index) # [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask), self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask)) # [T, B, 1] actor_loss = -(q_s_pi - self.alpha * log_pi).mean() # 1 self.actor_oplr.optimize(actor_loss) summaries = { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/q1_loss': q1_loss, 'LOSS/q2_loss': q2_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/entropy': entropy, 'Statistics/q_min': th.minimum(q1, q2).min(), 'Statistics/q_mean': th.minimum(q1, q2).mean(), 'Statistics/q_max': th.maximum(q1, q2).max() } if self.auto_adaption: alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean() # 1 self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries def _after_train(self): super()._after_train() self.critic.sync() self.critic2.sync() if self.annealing and not self.auto_adaption: self.log_alpha.copy_(self.alpha_annealing(self._cur_train_step).log())
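# The only TAC-specific ingredient above is `tsallis_entropy_log_q`, which maps
# log pi(a|s) to a q-logarithm before it enters the soft targets. That helper is
# defined elsewhere in the library; the NumPy sketch below is an assumption about
# the underlying math (log_q(x) = (x^(1-q) - 1) / (1 - q), recovering ln x as
# q -> 1), not the library's implementation. Note that the constructor stores
# `2 - entropic_index`, so the q actually applied is 2 minus the value passed in.
import numpy as np


def log_q(x, q):
    """q-logarithm of x > 0 for entropic index q; equals ln(x) when q == 1."""
    if np.isclose(q, 1.0):
        return np.log(x)
    return (np.power(x, 1.0 - q) - 1.0) / (1.0 - q)


# Usage: convert an ordinary log-probability into its Tsallis counterpart, as
# used in targets of the form q_target - alpha * log_q(pi).
log_pi = -1.2  # hypothetical log pi(a|s)
print(log_q(np.exp(log_pi), q=0.5))  # q = 2 - entropic_index when entropic_index=1.5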
class PPO(SarlOnPolicy): """ Proximal Policy Optimization, https://arxiv.org/abs/1707.06347 Emergence of Locomotion Behaviours in Rich Environments, http://arxiv.org/abs/1707.02286, DPPO """ policy_mode = 'on-policy' def __init__(self, agent_spec, ent_coef: float = 1.0e-2, vf_coef: float = 0.5, lr: float = 5.0e-4, lambda_: float = 0.95, epsilon: float = 0.2, use_duel_clip: bool = False, duel_epsilon: float = 0., use_vclip: bool = False, value_epsilon: float = 0.2, share_net: bool = True, actor_lr: float = 3e-4, critic_lr: float = 1e-3, kl_reverse: bool = False, kl_target: float = 0.02, kl_target_cutoff: float = 2, kl_target_earlystop: float = 4, kl_beta: List[float] = [0.7, 1.3], kl_alpha: float = 1.5, kl_coef: float = 1.0, extra_coef: float = 1000.0, use_kl_loss: bool = False, use_extra_loss: bool = False, use_early_stop: bool = False, network_settings: Dict = { 'share': { 'continuous': { 'condition_sigma': False, 'log_std_bound': [-20, 2], 'share': [32, 32], 'mu': [32, 32], 'v': [32, 32] }, 'discrete': { 'share': [32, 32], 'logits': [32, 32], 'v': [32, 32] } }, 'actor_continuous': { 'hidden_units': [64, 64], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32], 'critic': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) self._ent_coef = ent_coef self.lambda_ = lambda_ assert 0.0 <= lambda_ <= 1.0, "GAE lambda should be in [0, 1]." self._epsilon = epsilon self._use_vclip = use_vclip self._value_epsilon = value_epsilon self._share_net = share_net self._kl_reverse = kl_reverse self._kl_target = kl_target self._kl_alpha = kl_alpha self._kl_coef = kl_coef self._extra_coef = extra_coef self._vf_coef = vf_coef self._use_duel_clip = use_duel_clip self._duel_epsilon = duel_epsilon if self._use_duel_clip: assert - \ self._epsilon < self._duel_epsilon < self._epsilon, "duel_epsilon should be set in the range of (-epsilon, epsilon)." 
self._kl_cutoff = kl_target * kl_target_cutoff self._kl_stop = kl_target * kl_target_earlystop self._kl_low = kl_target * kl_beta[0] self._kl_high = kl_target * kl_beta[-1] self._use_kl_loss = use_kl_loss self._use_extra_loss = use_extra_loss self._use_early_stop = use_early_stop if self._share_net: if self.is_continuous: self.net = ActorCriticValueCts(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['share']['continuous']).to(self.device) else: self.net = ActorCriticValueDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['share']['discrete']).to(self.device) self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) else: if self.is_continuous: self.actor = ActorMuLogstd(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to(self.device) else: self.actor = ActorDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) self.critic = CriticValue(self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @iton def select_action(self, obs): if self.is_continuous: if self._share_net: mu, log_std, value = self.net(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.net.get_rnncs() else: mu, log_std = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() value = self.critic(obs, rnncs=self.rnncs) # [B, 1] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] log_prob = dist.log_prob(action).unsqueeze(-1) # [B, 1] else: if self._share_net: logits, value = self.net(obs, rnncs=self.rnncs) # [B, A], [B, 1] self.rnncs_ = self.net.get_rnncs() else: logits = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() value = self.critic(obs, rnncs=self.rnncs) # [B, 1] norm_dist = td.Categorical(logits=logits) action = norm_dist.sample() # [B,] log_prob = norm_dist.log_prob(action).unsqueeze(-1) # [B, 1] acts_info = Data(action=action, value=value, log_prob=log_prob + th.finfo().eps) if self.use_rnn: acts_info.update(rnncs=self.rnncs) return action, acts_info @iton def _get_value(self, obs, rnncs=None): if self._share_net: if self.is_continuous: _, _, value = self.net(obs, rnncs=rnncs) # [B, 1] else: _, value = self.net(obs, rnncs=rnncs) # [B, 1] else: value = self.critic(obs, rnncs=rnncs) # [B, 1] return value def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=value) td_error = calculate_td_error(BATCH.reward, self.gamma, BATCH.done, value=BATCH.value, next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0)) BATCH.gae_adv = discounted_sum(td_error, self.lambda_ * self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH def learn(self, BATCH: Data): BATCH = self._preprocess_BATCH(BATCH) # [T, B, *] for _ in 
range(self._epochs): kls = [] for _BATCH in BATCH.sample(self._chunk_length, self.batch_size, repeat=self._sample_allow_repeat): _BATCH = self._before_train(_BATCH) summaries, kl = self._train(_BATCH) kls.append(kl) self.summaries.update(summaries) self._after_train() if self._use_early_stop and sum(kls) / len(kls) > self._kl_stop: break def _train(self, BATCH): if self._share_net: summaries, kl = self.train_share(BATCH) else: summaries = dict() actor_summaries, kl = self.train_actor(BATCH) critic_summaries = self.train_critic(BATCH) summaries.update(actor_summaries) summaries.update(critic_summaries) if self._use_kl_loss: # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L93 if kl > self._kl_high: self._kl_coef *= self._kl_alpha elif kl < self._kl_low: self._kl_coef /= self._kl_alpha summaries.update({ 'Statistics/kl_coef': self._kl_coef }) return summaries, kl @iton def train_share(self, BATCH): if self.is_continuous: # [T, B, A], [T, B, A], [T, B, 1] mu, log_std, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: # [T, B, A], [T, B, 1] logits, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) logp_all = logits.log_softmax(-1) # [T, B, 1] new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True) # [T, B, 1] ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1] surrogate = ratio * BATCH.gae_adv # [T, B, 1] clipped_surrogate = th.minimum( surrogate, ratio.clamp(1.0 - self._epsilon, 1.0 + self._epsilon) * BATCH.gae_adv ) # [T, B, 1] # ref: https://github.com/thu-ml/tianshou/blob/c97aa4065ee8464bd5897bb86f1f81abd8e2cff9/tianshou/policy/modelfree/ppo.py#L159 if self._use_duel_clip: clipped_surrogate2 = th.maximum( clipped_surrogate, (1.0 + self._duel_epsilon) * BATCH.gae_adv ) # [T, B, 1] clipped_surrogate = th.where(BATCH.gae_adv < 0, clipped_surrogate2, clipped_surrogate) # [T, B, 1] actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean() # 1 # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L40 # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L185 if self._kl_reverse: # TODO: kl = .5 * (new_log_prob - BATCH.log_prob).square().mean() # 1 else: # a sample estimate for KL-divergence, easy to compute kl = .5 * (BATCH.log_prob - new_log_prob).square().mean() if self._use_kl_loss: kl_loss = self._kl_coef * kl # 1 actor_loss += kl_loss if self._use_extra_loss: extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean() # 1 actor_loss += extra_loss td_error = BATCH.discounted_reward - value # [T, B, 1] if self._use_vclip: # ref: https://github.com/llSourcell/OpenAI_Five_vs_Dota2_Explained/blob/c5def7e57aa70785c2394ea2eeb3e5f66ad59a53/train.py#L154 # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L172 value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon, self._value_epsilon) # [T, B, 1] td_error_clip = BATCH.discounted_reward - value_clip # [T, B, 1] td_square = th.maximum(td_error.square(), td_error_clip.square()) # [T, B, 1] else: td_square = td_error.square() # [T, B, 1] critic_loss = 
0.5 * td_square.mean() # 1 loss = actor_loss + self._vf_coef * critic_loss # 1 self.oplr.optimize(loss) return { 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/kl': kl, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/lr': self.oplr.lr }, kl @iton def train_actor(self, BATCH): if self.is_continuous: # [T, B, A], [T, B, A] mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True) # [T, B, 1] ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1] kl = (BATCH.log_prob - new_log_prob).square().mean() # 1 surrogate = ratio * BATCH.gae_adv # [T, B, 1] clipped_surrogate = th.minimum( surrogate, th.where(BATCH.gae_adv > 0, (1 + self._epsilon) * BATCH.gae_adv, (1 - self._epsilon) * BATCH.gae_adv) ) # [T, B, 1] if self._use_duel_clip: clipped_surrogate = th.maximum( clipped_surrogate, (1.0 + self._duel_epsilon) * BATCH.gae_adv ) # [T, B, 1] actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean() # 1 if self._use_kl_loss: kl_loss = self._kl_coef * kl # 1 actor_loss += kl_loss if self._use_extra_loss: extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean() # 1 actor_loss += extra_loss self.actor_oplr.optimize(actor_loss) return { 'LOSS/actor_loss': actor_loss, 'Statistics/kl': kl, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/actor_lr': self.actor_oplr.lr }, kl @iton def train_critic(self, BATCH): value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - value # [T, B, 1] if self._use_vclip: value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon, self._value_epsilon) # [T, B, 1] td_error_clip = BATCH.discounted_reward - value_clip # [T, B, 1] td_square = th.maximum(td_error.square(), td_error_clip.square()) # [T, B, 1] else: td_square = td_error.square() # [T, B, 1] critic_loss = 0.5 * td_square.mean() # 1 self.critic_oplr.optimize(critic_loss) return { 'LOSS/critic_loss': critic_loss, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr }
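# A minimal, self-contained sketch of the clipped surrogate objective that
# `train_share` / `train_actor` above optimize, written against plain PyTorch
# (aliased `th`, matching this file). The tensors are toy stand-ins; shapes
# follow the [T, B, 1] convention used in the comments above.
import torch as th


def ppo_clip_loss(new_log_prob, old_log_prob, adv, epsilon=0.2):
    ratio = (new_log_prob - old_log_prob).exp()            # pi_new / pi_old
    surrogate = ratio * adv
    clipped = ratio.clamp(1.0 - epsilon, 1.0 + epsilon) * adv
    return -th.minimum(surrogate, clipped).mean()          # maximize the clipped surrogate


T, B = 4, 3
adv = th.randn(T, B, 1)
old_log_prob = th.randn(T, B, 1)
new_log_prob = old_log_prob + 0.05 * th.randn(T, B, 1)
print(ppo_clip_loss(new_log_prob, old_log_prob, adv))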
class SAC_V(SarlOffPolicy): """ Soft Actor Critic with Value neural network. https://arxiv.org/abs/1812.05905 Soft Actor-Critic for Discrete Action Settings. https://arxiv.org/abs/1910.07207 """ policy_mode = 'off-policy' def __init__( self, alpha=0.2, annealing=True, last_alpha=0.01, polyak=0.995, use_gumbel=True, discrete_tau=1.0, network_settings={ 'actor_continuous': { 'share': [128, 128], 'mu': [64], 'log_std': [64], 'soft_clip': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [64, 32], 'q': [128, 128], 'v': [128, 128] }, actor_lr=5.0e-4, critic_lr=1.0e-3, alpha_lr=5.0e-4, auto_adaption=True, **kwargs): super().__init__(**kwargs) self.polyak = polyak self.use_gumbel = use_gumbel self.discrete_tau = discrete_tau self.auto_adaption = auto_adaption self.annealing = annealing self.v_net = TargetTwin( CriticValue(self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['v']), self.polyak).to(self.device) if self.is_continuous: self.actor = ActorCts( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) # entropy = -log(1/|A|) = log |A| self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else np.log(self.a_dim)) if self.is_continuous or self.use_gumbel: self.q_net = CriticQvalueOne( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']).to(self.device) else: self.q_net = CriticQvalueAll( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['q']).to(self.device) self.q_net2 = deepcopy(self.q_net) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR([self.q_net, self.q_net2, self.v_net], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) if self.annealing: self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6)) self._trainer_modules.update(actor=self.actor, v_net=self.v_net, q_net=self.q_net, q_net2=self.q_net2, log_alpha=self.log_alpha, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor(obs, rnncs=self.rnncs) # [B, A] pi = td.Normal(mu, log_std.exp()).sample().tanh() # [B, A] mu.tanh_() # squash mu # [B, A] else: logits = self.actor(obs, rnncs=self.rnncs) # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] self.rnncs_ = self.actor.get_rnncs() actions = pi if self._is_train_mode else mu return actions, Data(action=actions) def _train(self, BATCH): if self.is_continuous or self.use_gumbel: td_error, summaries = self._train_continuous(BATCH) else: td_error, summaries = self._train_discrete(BATCH) return td_error, summaries @iton def _train_continuous(self, BATCH): v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, 1] if self.is_continuous: mu, log_std = 
self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action( pi, dist.log_prob(pi).unsqueeze(-1)) # [T, B, A], [T, B, 1] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] q1 = self.q_net(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q2 = self.q_net2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask) # [T, B, 1] q2_pi = self.q_net2(BATCH.obs, pi, begin_mask=BATCH.begin_mask) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target, BATCH.begin_mask).detach() # [T, B, 1] v_from_q_stop = (th.minimum(q1_pi, q2_pi) - self.alpha * log_pi).detach() # [T, B, 1] td_v = v - v_from_q_stop # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action( pi, dist.log_prob(pi).unsqueeze(-1)) # [T, B, A], [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -(q1_pi - self.alpha * log_pi).mean() # 1 self.actor_oplr.optimize(actor_loss) summaries = { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/q1_loss': q1_loss, 'LOSS/q2_loss': q2_loss, 'LOSS/v_loss': v_loss_stop, 'LOSS/critic_loss': critic_loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/entropy': entropy, 'Statistics/q_min': th.minimum(q1, q2).min(), 'Statistics/q_mean': th.minimum(q1, q2).mean(), 'Statistics/q_max': th.maximum(q1, q2).max(), 'Statistics/v_mean': v.mean() } if self.auto_adaption: alpha_loss = -(self.alpha * (log_pi.detach() + self.target_entropy)).mean() self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries @iton def _train_discrete(self, BATCH): v = self.v_net(BATCH.obs, 
begin_mask=BATCH.begin_mask) # [T, B, 1] v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, 1] q1_all = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2_all = self.q_net2(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q1 = (q1_all * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q2 = (q2_all * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target, BATCH.begin_mask).detach() # [T, B, 1] td_v = v - (th.minimum((logp_all.exp() * q1_all).sum(-1, keepdim=True), (logp_all.exp() * q2_all).sum( -1, keepdim=True))).detach() # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop self.critic_oplr.optimize(critic_loss) q1_all = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2_all = self.q_net2(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True) # [T, B, 1] q_all = th.minimum(q1_all, q2_all) # [T, B, A] actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum( -1) # [T, B, A] => [T, B] actor_loss = actor_loss.mean() # 1 self.actor_oplr.optimize(actor_loss) summaries = { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/q1_loss': q1_loss, 'LOSS/q2_loss': q2_loss, 'LOSS/v_loss': v_loss_stop, 'LOSS/critic_loss': critic_loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/entropy': entropy.mean(), 'Statistics/v_mean': v.mean() } if self.auto_adaption: corr = (self.target_entropy - entropy).detach() # [T, B, 1] # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach() alpha_loss = -(self.alpha * corr) # [T, B, 1] alpha_loss = alpha_loss.mean() # 1 self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries def _after_train(self): super()._after_train() if self.annealing and not self.auto_adaption: self.log_alpha.copy_( self.alpha_annealing(self._cur_train_step).log()) self.v_net.sync()
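# `v_net` above is wrapped in `TargetTwin(..., polyak)` and refreshed by
# `self.v_net.sync()` in `_after_train`. `TargetTwin` is defined elsewhere in
# the library; the snippet below is only a sketch of the soft (polyak) update
# such a wrapper is assumed to perform, not its actual implementation.
import torch as th
import torch.nn as nn


@th.no_grad()
def polyak_update(online: nn.Module, target: nn.Module, polyak: float = 0.995):
    # target <- polyak * target + (1 - polyak) * online, parameter by parameter
    for p, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(polyak).add_((1.0 - polyak) * p)


online_v = nn.Linear(8, 1)
target_v = nn.Linear(8, 1)
target_v.load_state_dict(online_v.state_dict())  # start from identical weights
polyak_update(online_v, target_v)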
class DPG(SarlOffPolicy): """ Deterministic Policy Gradient, https://hal.inria.fr/file/index/docid/938992/filename/dpg-icml2014.pdf """ policy_mode = 'off-policy' def __init__(self, actor_lr=5.0e-4, critic_lr=1.0e-3, use_target_action_noise=False, noise_action='ou', noise_params={ 'sigma': 0.2 }, discrete_tau=1.0, network_settings={ 'actor_continuous': [32, 32], 'actor_discrete': [32, 32], 'q': [32, 32] }, **kwargs): super().__init__(**kwargs) self.discrete_tau = discrete_tau self.use_target_action_noise = use_target_action_noise if self.is_continuous: self.target_noised_action = ClippedNormalNoisedAction(sigma=0.2, noise_bound=0.2) self.noised_action = Noise_action_REGISTER[noise_action](**noise_params) self.actor = ActorDPG(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to(self.device) else: self.actor = ActorDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) self.critic = CriticQvalueOne(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) def episode_reset(self): super().episode_reset() if self.is_continuous: self.noised_action.reset() @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() if self.is_continuous: mu = output # [B, A] pi = self.noised_action(mu) # [B, A] else: logits = output # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] actions = pi if self._is_train_mode else mu return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: action_target = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] if self.use_target_action_noise: action_target = self.target_noised_action(action_target) # [T, B, A] else: target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] action_target = F.one_hot(target_pi, self.a_dim).float() # [T, B, A] q_target = self.critic(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target, BATCH.begin_mask).detach() # [T, B, 1] q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, A] td_error = dc_r - q # [T, B, A] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.critic_oplr.optimize(q_loss) if self.is_continuous: mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] _pi = logits.softmax(-1) # [T, B, A] _pi_true_one_hot = F.one_hot( logits.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] mu = _pi_diff + _pi # [T, B, A] q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -q_actor.mean() # 1 self.actor_oplr.optimize(actor_loss) return td_error, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 
'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': q_loss, 'Statistics/q_min': q.min(), 'Statistics/q_mean': q.mean(), 'Statistics/q_max': q.max() }
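# DPG looks its exploration noise up from `Noise_action_REGISTER` by name
# ('ou' by default, with sigma=0.2). The register lives elsewhere in the
# library; the class below is a generic NumPy sketch of an Ornstein-Uhlenbeck
# process of the kind such an entry typically provides. theta and dt are
# illustrative values, not the library's defaults.
import numpy as np


class OUNoiseSketch:
    def __init__(self, size, sigma=0.2, theta=0.15, dt=1e-2):
        self.size, self.sigma, self.theta, self.dt = size, sigma, theta, dt
        self.reset()

    def reset(self):
        self.x = np.zeros(self.size)

    def __call__(self):
        # mean-reverting random walk: temporally correlated, zero-mean noise
        dx = -self.theta * self.x * self.dt + self.sigma * np.sqrt(self.dt) * np.random.randn(self.size)
        self.x = self.x + dx
        return self.x


noise = OUNoiseSketch(size=2)
exploratory_action = np.tanh(np.zeros(2)) + noise()  # noise added to the deterministic action
print(exploratory_action)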
class A2C(SarlOnPolicy):
    """
    Synchronous Advantage Actor-Critic, A2C, http://arxiv.org/abs/1602.01783
    """
    policy_mode = 'on-policy'

    def __init__(self,
                 agent_spec,
                 beta=1.0e-3,
                 lambda_=0.95,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 network_settings={
                     'actor_continuous': {
                         'hidden_units': [64, 64],
                         'condition_sigma': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [32, 32],
                     'critic': [32, 32]
                 },
                 **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self.beta = beta
        self.lambda_ = lambda_  # GAE lambda used in _preprocess_BATCH
        if self.is_continuous:
            self.actor = ActorMuLogstd(self.obs_spec,
                                       rep_net_params=self._rep_net_params,
                                       output_shape=self.a_dim,
                                       network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.actor = ActorDct(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_discrete']).to(self.device)
        self.critic = CriticValue(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  network_settings=network_settings['critic']).to(self.device)
        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        value = self.critic(obs, rnncs=self.rnncs)  # [B, 1], stored so _preprocess_BATCH can compute GAE
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
        acts_info = Data(action=action, value=value)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        value = self.critic(obs, rnncs=rnncs)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(BATCH.reward,
                                      self.gamma,
                                      BATCH.done,
                                      value=BATCH.value,
                                      next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):
        v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        td_error = BATCH.discounted_reward - v  # [T, B, 1]
        critic_loss = td_error.square().mean()  # 1
        self.critic_oplr.optimize(critic_loss)
        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (BATCH.action * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
        # advantage = BATCH.discounted_reward - v.detach()  # [T, B, 1]
        actor_loss = -(log_act_prob * BATCH.gae_adv + self.beta * entropy).mean()  # 1
        self.actor_oplr.optimize(actor_loss)
        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
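# `_preprocess_BATCH` above delegates to the library's `calculate_td_error` and
# `discounted_sum` helpers to build a GAE advantage. Those helpers are not shown
# in this file, so the NumPy function below is only an illustration of the same
# math for a single environment copy: one-step TD errors followed by a
# (gamma * lambda)-discounted backward sum.
import numpy as np


def gae_advantage(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    deltas = rewards + gamma * (1.0 - dones) * next_values - values  # TD errors
    adv = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * (1.0 - dones[t]) * running
        adv[t] = running
    return adv


rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.3])
next_values = np.array([0.4, 0.3, 0.2])
dones = np.array([0.0, 0.0, 1.0])
print(gae_advantage(rewards, values, next_values, dones))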
class NPG(SarlOnPolicy): """ Natural Policy Gradient, NPG https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf """ policy_mode = 'on-policy' def __init__( self, agent_spec, actor_step_size=0.5, beta=1.0e-3, lambda_=0.95, cg_iters=10, damping_coeff=0.1, epsilon=0.2, critic_lr=1e-3, train_critic_iters=10, network_settings={ 'actor_continuous': { 'hidden_units': [64, 64], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32], 'critic': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) self.actor_step_size = actor_step_size self.beta = beta self.lambda_ = lambda_ self._epsilon = epsilon self._cg_iters = cg_iters self._damping_coeff = damping_coeff self._train_critic_iters = train_critic_iters if self.is_continuous: self.actor = ActorMuLogstd( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) self.critic = CriticValue( self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic_oplr=self.critic_oplr) @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() value = self.critic(obs, rnncs=self.rnncs) # [B, 1] if self.is_continuous: mu, log_std = output # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] log_prob = dist.log_prob(action).unsqueeze(-1) # [B, 1] else: logits = output # [B, A] logp_all = logits.log_softmax(-1) # [B, A] norm_dist = td.Categorical(logits=logp_all) action = norm_dist.sample() # [B,] log_prob = norm_dist.log_prob(action).unsqueeze(-1) # [B, 1] acts_info = Data(action=action, value=value, log_prob=log_prob + th.finfo().eps) if self.use_rnn: acts_info.update(rnncs=self.rnncs) if self.is_continuous: acts_info.update(mu=mu, log_std=log_std) else: acts_info.update(logp_all=logp_all) return action, acts_info @iton def _get_value(self, obs, rnncs=None): value = self.critic(obs, rnncs=rnncs) # [B, 1] return value def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=value) td_error = calculate_td_error( BATCH.reward, self.gamma, BATCH.done, value=BATCH.value, next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0)) BATCH.gae_adv = discounted_sum(td_error, self.lambda_ * self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH @iton def _train(self, BATCH): output = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] if self.is_continuous: mu, log_std = output # [T, B, A], [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = output # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] new_log_prob = (BATCH.action * logp_all).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 ratio = 
(new_log_prob - BATCH.log_prob).exp() # [T, B, 1] actor_loss = -(ratio * BATCH.gae_adv).mean() # 1 flat_grads = grads_flatten(actor_loss, self.actor, retain_graph=True).detach() # [1,] if self.is_continuous: kl = td.kl_divergence( td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1), td.Independent(td.Normal(mu, log_std.exp()), 1)).mean() else: kl = (BATCH.logp_all.exp() * (BATCH.logp_all - logp_all)).sum(-1).mean() # 1 flat_kl_grad = grads_flatten(kl, self.actor, create_graph=True) search_direction = -self._conjugate_gradients( flat_grads, flat_kl_grad, cg_iters=self._cg_iters) # [1,] with th.no_grad(): flat_params = th.cat( [param.data.view(-1) for param in self.actor.parameters()]) new_flat_params = flat_params + self.actor_step_size * search_direction set_from_flat_params(self.actor, new_flat_params) for _ in range(self._train_critic_iters): value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - value # [T, B, 1] critic_loss = td_error.square().mean() # 1 self.critic_oplr.optimize(critic_loss) return { 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/critic_lr': self.critic_oplr.lr } def _conjugate_gradients(self, flat_grads, flat_kl_grad, cg_iters: int = 10, residual_tol: float = 1e-10): """ Conjugate gradient algorithm (see https://en.wikipedia.org/wiki/Conjugate_gradient_method) """ x = th.zeros_like(flat_grads) r, p = flat_grads.clone(), flat_grads.clone() # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0. # Change if doing warm start. rdotr = r.dot(r) for i in range(cg_iters): z = self._MVP(p, flat_kl_grad) alpha = rdotr / (p.dot(z) + th.finfo().eps) x += alpha * p r -= alpha * z new_rdotr = r.dot(r) if new_rdotr < residual_tol: break p = r + new_rdotr / rdotr * p rdotr = new_rdotr return x def _MVP(self, v, flat_kl_grad): """Matrix vector product.""" # caculate second order gradient of kl with respect to theta kl_v = (flat_kl_grad * v).sum() mvp = grads_flatten(kl_v, self.actor, retain_graph=True).detach() mvp += max(0, self._damping_coeff) * v return mvp
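# `_conjugate_gradients` above solves F x = g using only Fisher-vector products,
# so the returned x approximates F^{-1} g, the natural-gradient direction. The
# standalone check below runs the same iteration against an explicit symmetric
# positive-definite matrix (an arbitrary example, nothing from the class above)
# and compares the result with a direct solve.
import torch as th


def conjugate_gradients(mvp, b, iters=10, tol=1e-10):
    x = th.zeros_like(b)
    r, p = b.clone(), b.clone()
    rdotr = r.dot(r)
    for _ in range(iters):
        z = mvp(p)
        alpha = rdotr / (p.dot(z) + th.finfo(b.dtype).eps)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


A = th.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 2.0]])  # SPD stand-in for the Fisher matrix
g = th.tensor([1.0, 2.0, 3.0])
x = conjugate_gradients(lambda v: A @ v, g)
print(x, th.linalg.solve(A, g))  # the two should agree closely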
class PG(SarlOnPolicy):
    """
    Vanilla Policy Gradient (REINFORCE-style), trained on normalized discounted returns.
    """
    policy_mode = 'on-policy'

    def __init__(self,
                 agent_spec,
                 lr=5.0e-4,
                 network_settings={
                     'actor_continuous': {
                         'hidden_units': [32, 32],
                         'condition_sigma': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [32, 32]
                 },
                 **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        if self.is_continuous:
            self.net = ActorMuLogstd(self.obs_spec,
                                     rep_net_params=self._rep_net_params,
                                     output_shape=self.a_dim,
                                     network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.net = ActorDct(self.obs_spec,
                                rep_net_params=self._rep_net_params,
                                output_shape=self.a_dim,
                                network_settings=network_settings['actor_discrete']).to(self.device)
        self.oplr = OPLR(self.net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        output = self.net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.net.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
        acts_info = Data(action=action)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=0.,
                                                 normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):  # [T, B, *]
        output = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.is_continuous:
            mu, log_std = output  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = output  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (logp_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
        loss = -(log_act_prob * BATCH.discounted_reward).mean()
        self.oplr.optimize(loss)
        return {
            'LOSS/loss': loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/lr': self.oplr.lr
        }
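# PG's `_preprocess_BATCH` asks `discounted_sum(..., init_value=0., normalize=True)`
# for normalized reward-to-go returns. The helper lives elsewhere in the library,
# so the NumPy version below is an assumption about its behaviour for a single
# trajectory, shown only to make the REINFORCE target concrete.
import numpy as np


def discounted_returns(rewards, gamma=0.99, normalize=True):
    out = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    if normalize:
        out = (out - out.mean()) / (out.std() + 1e-8)
    return out


print(discounted_returns(np.array([0.0, 0.0, 1.0])))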
class AC(SarlOffPolicy):
    policy_mode = 'off-policy'  # off-policy actor-critic

    def __init__(self,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 network_settings={
                     'actor_continuous': {
                         'hidden_units': [64, 64],
                         'condition_sigma': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [32, 32],
                     'critic': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        if self.is_continuous:
            self.actor = ActorMuLogstd(self.obs_spec,
                                       rep_net_params=self._rep_net_params,
                                       output_shape=self.a_dim,
                                       network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.actor = ActorDct(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_discrete']).to(self.device)
        self.critic = CriticQvalueOne(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      action_dim=self.a_dim,
                                      network_settings=network_settings['critic']).to(self.device)
        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, *]
            log_prob = dist.log_prob(action)  # [B,]
        else:
            logits = output  # [B, *]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action)  # [B,]
        return action, Data(action=action, log_prob=log_prob)

    def random_action(self):
        actions = super().random_action()
        if self.is_continuous:
            self._acts_info.update(log_prob=np.full(self.n_copies, np.log(0.5)))  # [B,]
        else:
            self._acts_info.update(log_prob=np.full(self.n_copies, np.log(1. / self.a_dim)))  # [B,]
        return actions

    @iton
    def _train(self, BATCH):
        q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        if self.is_continuous:
            next_mu, _ = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_q_next = self.critic(BATCH.obs_, next_mu, begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_a = logits.argmax(-1)  # [T, B]
            max_a_one_hot = F.one_hot(max_a, self.a_dim).float()  # [T, B, A]
            max_q_next = self.critic(BATCH.obs_, max_a_one_hot, begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q - n_step_return(BATCH.reward,
                                     self.gamma,
                                     BATCH.done,
                                     max_q_next,
                                     BATCH.begin_mask).detach()  # [T, B, 1]
        critic_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_prob = dist.log_prob(BATCH.action)  # [T, B]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
            logp_all = logits.log_softmax(-1)  # [T, B, *]
            log_prob = (logp_all * BATCH.action).sum(-1)  # [T, B]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        ratio = (log_prob - BATCH.log_prob).exp().detach()  # [T, B]
        actor_loss = -(ratio * log_prob * q.squeeze(-1).detach()).mean()  # [T, B] => 1
        self.actor_oplr.optimize(actor_loss)
        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_max': q.max(),
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/ratio': ratio.mean(),
            'Statistics/entropy': entropy
        }
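# Toy illustration of the importance weight in AC's `_train`: the actor loss
# reweights log pi * Q by exp(log pi_current - log pi_behaviour), with the
# ratio detached and Q treated as a constant, so gradients flow only through
# the current log-prob. All numbers below are made up for demonstration.
import torch as th

log_prob_now = th.tensor([-0.9, -1.6], requires_grad=True)  # current policy
log_prob_old = th.tensor([-1.2, -1.1])                      # behaviour policy (from the replay buffer)
q = th.tensor([0.8, 0.3])                                   # critic estimate, treated as a constant

ratio = (log_prob_now - log_prob_old).exp().detach()
actor_loss = -(ratio * log_prob_now * q).mean()
actor_loss.backward()
print(ratio, log_prob_now.grad)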