class DDPG(SarlOffPolicy): """ Deep Deterministic Policy Gradient, https://arxiv.org/abs/1509.02971 """ policy_mode = 'off-policy' def __init__(self, polyak=0.995, noise_action='ou', noise_params={'sigma': 0.2}, use_target_action_noise=False, actor_lr=5.0e-4, critic_lr=1.0e-3, discrete_tau=1.0, network_settings={ 'actor_continuous': [32, 32], 'actor_discrete': [32, 32], 'q': [32, 32] }, **kwargs): super().__init__(**kwargs) self.polyak = polyak self.discrete_tau = discrete_tau self.use_target_action_noise = use_target_action_noise if self.is_continuous: actor = ActorDPG( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']) self.target_noised_action = ClippedNormalNoisedAction( sigma=0.2, noise_bound=0.2) if noise_action in ['ou', 'clip_normal']: self.noised_action = Noise_action_REGISTER[noise_action]( **noise_params) elif noise_action == 'normal': self.noised_action = self.target_noised_action else: raise Exception( f'cannot use noised action type of {noise_action}') else: actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']) self.actor = TargetTwin(actor, self.polyak).to(self.device) self.critic = TargetTwin( CriticQvalueOne(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']), self.polyak).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) def episode_reset(self): super().episode_reset() if self.is_continuous: self.noised_action.reset() @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() if self.is_continuous: mu = output # [B, A] pi = self.noised_action(mu) # [B, A] else: logits = output # [B, A] mu = logits.argmax(-1) # [B, ] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] actions = pi if self._is_train_mode else mu return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: action_target = self.actor.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] if self.use_target_action_noise: action_target = self.target_noised_action( action_target) # [T, B, A] else: target_logits = self.actor.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] action_target = F.one_hot(target_pi, self.a_dim).float() # [T, B, A] q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q_target = self.critic.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target, BATCH.begin_mask).detach() # [T, B, 1] td_error = dc_r - q # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.critic_oplr.optimize(q_loss) if self.is_continuous: mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = 
F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] mu = _pi_diff + _pi # [T, B, A] q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -q_actor.mean() # 1 self.actor_oplr.optimize(actor_loss) return td_error, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': q_loss, 'Statistics/q_min': q.min(), 'Statistics/q_mean': q.mean(), 'Statistics/q_max': q.max() } def _after_train(self): super()._after_train() self.actor.sync() self.critic.sync()
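
# --- Illustrative sketch (not part of the repository code) ------------------
# The two numerical ingredients of the DDPG update above, shown on toy tensors:
# the simplified 1-step TD target built from the target networks, and the soft
# (polyak) target update that `TargetTwin.sync()` is assumed to perform when a
# `polyak` coefficient is supplied. All names below are hypothetical.
import torch as th


def _sketch_ddpg_targets(polyak: float = 0.995, gamma: float = 0.99):
    q_target_next = th.randn(4, 1)              # Q'(s', mu'(s')) from the target nets
    reward, done = th.randn(4, 1), th.zeros(4, 1)
    td_target = reward + gamma * (1. - done) * q_target_next   # simplified 1-step return
    online, target = th.nn.Linear(3, 1), th.nn.Linear(3, 1)
    with th.no_grad():  # theta_target <- polyak * theta_target + (1 - polyak) * theta
        for p, p_t in zip(online.parameters(), target.parameters()):
            p_t.mul_(polyak).add_((1. - polyak) * p)
    return td_target
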
class OC(SarlOffPolicy): """ The Option-Critic Architecture. http://arxiv.org/abs/1609.05140 """ policy_mode = 'off-policy' def __init__(self, q_lr=5.0e-3, intra_option_lr=5.0e-4, termination_lr=5.0e-4, use_eps_greedy=False, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, boltzmann_temperature=1.0, options_num=4, ent_coff=0.01, double_q=False, use_baseline=True, terminal_mask=True, termination_regularizer=0.01, assign_interval=1000, network_settings={ 'q': [32, 32], 'intra_option': [32, 32], 'termination': [32, 32] }, **kwargs): super().__init__(**kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.options_num = options_num self.termination_regularizer = termination_regularizer self.ent_coff = ent_coff self.use_baseline = use_baseline self.terminal_mask = terminal_mask self.double_q = double_q self.boltzmann_temperature = boltzmann_temperature self.use_eps_greedy = use_eps_greedy self.q_net = TargetTwin( CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['q'])).to( self.device) self.intra_option_net = OcIntraOption( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, options_num=self.options_num, network_settings=network_settings['intra_option']).to(self.device) self.termination_net = CriticQvalueAll( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['termination'], out_act='sigmoid').to(self.device) if self.is_continuous: # https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751 # https://blog.csdn.net/nkhgl/article/details/100047276 self.log_std = th.as_tensor( np.full((self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] self.intra_option_oplr = OPLR( [self.intra_option_net, self.log_std], intra_option_lr, **self._oplr_params) else: self.intra_option_oplr = OPLR(self.intra_option_net, intra_option_lr, **self._oplr_params) self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params) self.termination_oplr = OPLR(self.termination_net, termination_lr, **self._oplr_params) self._trainer_modules.update(q_net=self.q_net, intra_option_net=self.intra_option_net, termination_net=self.termination_net, q_oplr=self.q_oplr, intra_option_oplr=self.intra_option_oplr, termination_oplr=self.termination_oplr) self.options = self.new_options = self._generate_random_options() def _generate_random_options(self): # [B,] return th.tensor(np.random.randint(0, self.options_num, self.n_copies)).to(self.device) def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray): super().episode_step(obs, env_rets, begin_mask) self.options = self.new_options @iton def select_action(self, obs): q = self.q_net(obs, rnncs=self.rnncs) # [B, P] self.rnncs_ = self.q_net.get_rnncs() pi = self.intra_option_net(obs, rnncs=self.rnncs) # [B, P, A] beta = self.termination_net(obs, rnncs=self.rnncs) # [B, P] options_onehot = F.one_hot(self.options, self.options_num).float() # [B, P] options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1] pi = (pi * options_onehot_expanded).sum(-2) # [B, A] if self.is_continuous: mu = pi.tanh() # [B, A] log_std = self.log_std[self.options] # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) actions = dist.sample().clamp(-1, 1) # [B, A] else: 
pi = pi / self.boltzmann_temperature # [B, A] dist = td.Categorical(logits=pi) actions = dist.sample() # [B, ] max_options = q.argmax(-1).long() # [B, P] => [B, ] if self.use_eps_greedy: # epsilon greedy if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): self.new_options = self._generate_random_options() else: self.new_options = max_options else: beta_probs = (beta * options_onehot).sum(-1) # [B, P] => [B,] beta_dist = td.Bernoulli(probs=beta_probs) self.new_options = th.where(beta_dist.sample() < 1, self.options, max_options) return actions, Data(action=actions, last_options=self.options, options=self.new_options) def random_action(self): actions = super().random_action() self._acts_info.update( last_options=np.random.randint(0, self.options_num, self.n_copies), options=np.random.randint(0, self.options_num, self.n_copies)) return actions def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num) BATCH.options = int2one_hot(BATCH.options, self.options_num) return BATCH @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] beta_next = self.termination_net( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] qu_eval = (q * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] beta_s_ = (beta_next * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94 if self.double_q: q_ = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] # [T, B, P] => [T, B] => [T, B, P] max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float() q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True) # [T, B, 1] else: q_s_max = q_next.max(-1, keepdim=True)[0] # [T, B, 1] u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max # [T, B, 1] qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, u_target, BATCH.begin_mask).detach() # [T, B, 1] td_error = qu_target - qu_eval # gradient : q [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # [T, B, 1] => 1 self.q_oplr.optimize(q_loss) q_s = qu_eval.detach() # [T, B, 1] # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130 if self.use_baseline: adv = (qu_target - q_s).detach() # [T, B, 1] else: adv = qu_target.detach() # [T, B, 1] # [T, B, P] => [T, B, P, 1] options_onehot_expanded = BATCH.options.unsqueeze(-1) pi = self.intra_option_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P, A] # [T, B, P, A] => [T, B, A] pi = (pi * options_onehot_expanded).sum(-2) if self.is_continuous: mu = pi.tanh() # [T, B, A] log_std = self.log_std[BATCH.options.argmax(-1)] # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_p = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: pi = pi / self.boltzmann_temperature # [T, B, A] log_pi = pi.log_softmax(-1) # [T, B, A] entropy = -(log_pi.exp() * log_pi).sum(-1, keepdim=True) # [T, B, 1] log_p = (BATCH.action * log_pi).sum(-1, keepdim=True) # [T, B, 1] pi_loss = -(log_p * adv + self.ent_coff * entropy).mean() # 1 beta = self.termination_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True) # [T, B, 1] if 
self.use_eps_greedy: v_s = q.max( -1, keepdim=True)[0] - self.termination_regularizer # [T, B, 1] else: v_s = (1 - beta_s) * q_s + beta_s * q.max( -1, keepdim=True)[0] # [T, B, 1] # v_s = q.mean(-1, keepdim=True) # [T, B, 1] beta_loss = beta_s * (q_s - v_s).detach() # [T, B, 1] # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238 if self.terminal_mask: beta_loss *= (1 - BATCH.done) # [T, B, 1] beta_loss = beta_loss.mean() # 1 self.intra_option_oplr.optimize(pi_loss) self.termination_oplr.optimize(beta_loss) return td_error, { 'LEARNING_RATE/q_lr': self.q_oplr.lr, 'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr, 'LEARNING_RATE/termination_lr': self.termination_oplr.lr, # 'Statistics/option': self.options[0], 'LOSS/q_loss': q_loss, 'LOSS/pi_loss': pi_loss, 'LOSS/beta_loss': beta_loss, 'Statistics/q_option_max': q_s.max(), 'Statistics/q_option_min': q_s.min(), 'Statistics/q_option_mean': q_s.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
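
# --- Illustrative sketch (not part of the repository code) ------------------
# The option-value bootstrap used in OC._train above, on toy tensors:
# U(s', w) = (1 - beta(s', w)) * Q(s', w) + beta(s', w) * max_w' Q(s', w').
# Shapes follow the [T, B, P] convention of the class; names are hypothetical.
import torch as th


def _sketch_option_value_target(T: int = 5, B: int = 3, P: int = 4):
    q_next = th.randn(T, B, P)                        # Q(s', .) for every option
    beta_next = th.rand(T, B, P)                      # termination probabilities
    options = th.nn.functional.one_hot(th.randint(0, P, (T, B)), P).float()
    q_s_ = (q_next * options).sum(-1, keepdim=True)   # value of the held option
    beta_s_ = (beta_next * options).sum(-1, keepdim=True)
    q_s_max = q_next.max(-1, keepdim=True)[0]
    return (1 - beta_s_) * q_s_ + beta_s_ * q_s_max   # [T, B, 1]
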
class C51(SarlOffPolicy): """ Category 51, https://arxiv.org/abs/1707.06887 No double, no dueling, no noisy net. """ policy_mode = 'off-policy' def __init__(self, v_min=-10, v_max=10, atoms=51, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, network_settings=[128, 128], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'c51 only support discrete action space' self._v_min = v_min self._v_max = v_max self._atoms = atoms self._delta_z = (self._v_max - self._v_min) / (self._atoms - 1) self._z = th.linspace(self._v_min, self._v_max, self._atoms).float().to(self.device) # [N,] self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin( C51Distributional(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, atoms=self._atoms, network_settings=network_settings)).to( self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): feat = self.q_net(obs, rnncs=self.rnncs) # [B, A, N] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: q = (self._z * feat).sum(-1) # [B, A, N] * [N,] => [B, A] actions = q.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N] q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2) q_eval = (q_dist * self._z).sum(-1) # [T, B, N] * [N,] => [T, B] target_q_dist = self.q_net.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A, N] # [T, B, A, N] * [1, N] => [T, B, A] target_q = (target_q_dist * self._z).sum(-1) a_ = target_q.argmax(-1) # [T, B] a_onehot = F.one_hot(a_, self.a_dim).float() # [T, B, A] # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N] target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2) target = n_step_return( BATCH.reward.repeat(1, 1, self._atoms), self.gamma, BATCH.done.repeat(1, 1, self._atoms), target_q_dist, BATCH.begin_mask.repeat(1, 1, self._atoms)).detach() # [T, B, N] target = target.clamp(self._v_min, self._v_max) # [T, B, N] # An amazing trick for calculating the projection gracefully. # ref: https://github.com/ShangtongZhang/DeepRL target_dist = ( 1 - (target.unsqueeze(-1) - self._z.view(1, 1, -1, 1)).abs() / self._delta_z).clamp(0, 1) * target_q_dist.unsqueeze( -1) # [T, B, N, 1] target_dist = target_dist.sum(-1) # [T, B, N] _cross_entropy = -(target_dist * th.log(q_dist + th.finfo().eps)).sum( -1, keepdim=True) # [T, B, 1] loss = (_cross_entropy * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(loss) return _cross_entropy, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
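
# --- Illustrative sketch (not part of the repository code) ------------------
# The categorical projection used in C51._train above, written out on a single
# toy distribution: clamp the shifted atoms to [v_min, v_max], then spread each
# shifted atom's mass over its two nearest fixed atoms with the
# max(0, 1 - |Tz_j - z_i| / delta_z) weights. Names are hypothetical.
import torch as th


def _sketch_c51_projection(v_min: float = -10., v_max: float = 10.,
                           atoms: int = 51, gamma: float = 0.99):
    z = th.linspace(v_min, v_max, atoms)              # fixed support, [N,]
    delta_z = (v_max - v_min) / (atoms - 1)
    p_next = th.softmax(th.randn(atoms), -1)          # next-state distribution, [N,]
    reward, done = 1.0, 0.0
    tz = (reward + gamma * (1. - done) * z).clamp(v_min, v_max)   # shifted atoms, [N,]
    w = (1 - (tz.unsqueeze(-1) - z.unsqueeze(0)).abs() / delta_z).clamp(0, 1)  # [N, N]
    projected = (w * p_next.unsqueeze(-1)).sum(0)     # [N,], still sums to 1
    return projected
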
class CuriosityModel(nn.Module):
    """
    Model of Intrinsic Curiosity Module (ICM).
    Curiosity-driven Exploration by Self-supervised Prediction, https://arxiv.org/abs/1705.05363
    """

    def __init__(self,
                 obs_spec,
                 rep_net_params,
                 is_continuous,
                 action_dim,
                 *,
                 eta=0.2,
                 lr=1.0e-3,
                 beta=0.2):
        """
        params:
            is_continuous: specify whether the action space is continuous (True) or discrete (False)
            action_dim: dimension of action
            eta: weight of intrinsic reward
            lr: learning rate of the curiosity model
            beta: weighting factor between the inverse_dynamic_net and forward_net losses
        """
        super().__init__()
        self.eta = eta
        self.beta = beta
        self.is_continuous = is_continuous
        self.action_dim = action_dim

        self.rep_net = RepresentationNetwork(obs_spec=obs_spec,
                                             rep_net_params=rep_net_params)
        self.feat_dim = self.rep_net.h_dim

        # S, S' => A
        self.inverse_dynamic_net = nn.Sequential(
            nn.Linear(self.feat_dim * 2, self.feat_dim * 2),
            Act_REGISTER[default_act](),
            nn.Linear(self.feat_dim * 2, action_dim))
        if self.is_continuous:
            self.inverse_dynamic_net.add_module('tanh', nn.Tanh())

        # S, A => S'
        self.forward_net = nn.Sequential(
            nn.Linear(self.feat_dim + action_dim, self.feat_dim),
            Act_REGISTER[default_act](),
            nn.Linear(self.feat_dim, self.feat_dim))

        self.oplr = OPLR(
            models=[self.rep_net, self.inverse_dynamic_net, self.forward_net],
            lr=lr)

    def forward(self, BATCH):
        fs, _ = self.rep_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
        fs_, _ = self.rep_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, *]

        # [T, B, *] <S, A> => S'
        s_eval = self.forward_net(th.cat((fs, BATCH.action), -1))
        LF = 0.5 * (fs_ - s_eval).square().sum(-1, keepdim=True)  # [T, B, 1]
        intrinsic_reward = self.eta * LF
        loss_forward = LF.mean()  # 1

        a_eval = self.inverse_dynamic_net(th.cat((fs, fs_), -1))  # [T, B, *]
        if self.is_continuous:
            loss_inverse = 0.5 * (a_eval - BATCH.action).square().sum(-1).mean()
        else:
            idx = BATCH.action.argmax(-1)  # [T, B]
            loss_inverse = F.cross_entropy(a_eval.view(-1, self.action_dim),
                                           idx.view(-1))  # 1

        loss = (1 - self.beta) * loss_inverse + self.beta * loss_forward
        self.oplr.optimize(loss)

        summaries = {
            'LOSS/curiosity_loss': loss,
            'LOSS/forward_loss': loss_forward,
            'LOSS/inverse_loss': loss_inverse
        }
        return intrinsic_reward, summaries
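
# --- Illustrative sketch (not part of the repository code) ------------------
# How an ICM-style intrinsic reward is typically combined with the environment
# reward before the RL update. The 0.5 * ||phi(s') - f(phi(s), a)||^2 term
# mirrors `LF` in CuriosityModel.forward; `eta` plays the same weighting role.
# All tensor names are hypothetical.
import torch as th


def _sketch_icm_reward(eta: float = 0.2):
    feat_next = th.randn(6, 8)                  # phi(s') from the encoder
    feat_pred = th.randn(6, 8)                  # forward-model prediction of phi(s')
    extrinsic = th.randn(6, 1)                  # environment reward
    intrinsic = eta * 0.5 * (feat_next - feat_pred).pow(2).sum(-1, keepdim=True)
    return extrinsic + intrinsic                # reward the agent would train on
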
class IQN(SarlOffPolicy): """ Implicit Quantile Networks, https://arxiv.org/abs/1806.06923 Double DQN """ policy_mode = 'off-policy' def __init__(self, online_quantiles=8, target_quantiles=8, select_quantiles=32, quantiles_idx=64, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=2, network_settings={ 'q_net': [128, 64], 'quantile': [128, 64], 'tile': [64] }, **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'iqn only support discrete action space' self.online_quantiles = online_quantiles self.target_quantiles = target_quantiles self.select_quantiles = select_quantiles self.quantiles_idx = quantiles_idx self.huber_delta = huber_delta self.assign_interval = assign_interval self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.q_net = TargetTwin(IqnNet(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, quantiles_idx=self.quantiles_idx, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): _, select_quantiles_tiled = self._generate_quantiles( # [N*B, X] batch_size=self.n_copies, quantiles_num=self.select_quantiles ) q_values = self.q_net(obs, select_quantiles_tiled, rnncs=self.rnncs) # [N, B, A] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: # [N, B, A] => [B, A] => [B,] actions = q_values.mean(0).argmax(-1) return actions, Data(action=actions) def _generate_quantiles(self, batch_size, quantiles_num): _quantiles = th.rand([quantiles_num * batch_size, 1]) # [N*B, 1] _quantiles_tiled = _quantiles.repeat(1, self.quantiles_idx) # [N*B, 1] => [N*B, X] # pi * i * tau [N*B, X] * [X, ] => [N*B, X] _quantiles_tiled = th.arange(self.quantiles_idx) * np.pi * _quantiles_tiled _quantiles_tiled.cos_() # [N*B, X] _quantiles = _quantiles.view(batch_size, quantiles_num, 1) # [N*B, 1] => [B, N, 1] return _quantiles, _quantiles_tiled # [B, N, 1], [N*B, X] @iton def _train(self, BATCH): time_step = BATCH.reward.shape[0] batch_size = BATCH.reward.shape[1] quantiles, quantiles_tiled = self._generate_quantiles( # [T*B, N, 1], [N*T*B, X] batch_size=time_step * batch_size, quantiles_num=self.online_quantiles) # [T*B, N, 1] => [T, B, N, 1] quantiles = quantiles.view(time_step, batch_size, -1, 1) quantiles_tiled = quantiles_tiled.view(time_step, -1, self.quantiles_idx) # [N*T*B, X] => [T, N*B, X] quantiles_value = self.q_net(BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask) # [T, N, B, A] # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1] quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True) q_eval = quantiles_value.mean(0) # [N, T, B, 1] => [T, B, 1] _, select_quantiles_tiled = self._generate_quantiles( # [N*T*B, X] batch_size=time_step * batch_size, quantiles_num=self.select_quantiles) select_quantiles_tiled = select_quantiles_tiled.view( time_step, -1, self.quantiles_idx) # [N*T*B, X] => [T, N*B, X] q_values = self.q_net( BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask) # [T, N, B, A] q_values = q_values.mean(1) # [T, N, B, A] => [T, B, A] next_max_action = q_values.argmax(-1) # [T, B] next_max_action = F.one_hot( next_max_action, 
self.a_dim).float() # [T, B, A] _, target_quantiles_tiled = self._generate_quantiles( # [N'*T*B, X] batch_size=time_step * batch_size, quantiles_num=self.target_quantiles) target_quantiles_tiled = target_quantiles_tiled.view( time_step, -1, self.quantiles_idx) # [N'*T*B, X] => [T, N'*B, X] target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled, begin_mask=BATCH.begin_mask) # [T, N', B, A] target_quantiles_value = target_quantiles_value.swapaxes(0, 1) # [T, N', B, A] => [N', T, B, A] target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True) # [N', T, B, 1] target_q = target_quantiles_value.mean(0) # [T, B, 1] q_target = n_step_return(BATCH.reward, # [T, B, 1] self.gamma, BATCH.done, # [T, B, 1] target_q, # [T, B, 1] BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] # [N', T, B, 1] => [N', T, B] target_quantiles_value = target_quantiles_value.squeeze(-1) target_quantiles_value = target_quantiles_value.permute( 1, 2, 0) # [N', T, B] => [T, B, N'] quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles), self.gamma, BATCH.done.repeat(1, 1, self.target_quantiles), target_quantiles_value, BATCH.begin_mask.repeat(1, 1, self.target_quantiles)).detach() # [T, B, N'] # [T, B, N'] => [T, B, 1, N'] quantiles_value_target = quantiles_value_target.unsqueeze(-2) quantiles_value_online = quantiles_value.permute(1, 2, 0, 3) # [N, T, B, 1] => [T, B, N, 1] # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N'] quantile_error = quantiles_value_online - quantiles_value_target huber = F.huber_loss(quantiles_value_online, quantiles_value_target, reduction="none", delta=self.huber_delta) # [T, B, N, N] # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N'] huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs() loss = (huber_abs * huber).mean(-1) # [T, B, N, N'] => [T, B, N] loss = loss.sum(-1, keepdim=True) # [T, B, N] => [T, B, 1] loss = (loss * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
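
# --- Illustrative sketch (not part of the repository code) ------------------
# A standalone version of the quantile Huber loss computed at the end of
# IQN._train: rho_tau(delta) = |tau - 1{delta < 0}| * Huber_kappa(delta),
# averaged over pairs of online/target quantile samples. The toy shapes and
# names are hypothetical.
import torch as th
import torch.nn.functional as F


def _sketch_quantile_huber(N: int = 8, N_target: int = 8, kappa: float = 1.0):
    taus = th.rand(N, 1)                        # quantile fractions of the online samples
    online = th.randn(N, 1)                     # theta_i
    target = th.randn(1, N_target)              # theta'_j (detached in practice)
    delta = target - online                     # pairwise TD errors, [N, N']
    huber = F.huber_loss(online.expand(N, N_target), target.expand(N, N_target),
                         reduction='none', delta=kappa)            # [N, N']
    return ((taus - (delta.detach() < 0).float()).abs() * huber).mean()
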
class IOC(SarlOffPolicy): """ Learning Options with Interest Functions, https://www.aaai.org/ojs/index.php/AAAI/article/view/5114/4987 Options of Interest: Temporal Abstraction with Interest Functions, http://arxiv.org/abs/2001.00271 """ policy_mode = 'off-policy' def __init__( self, q_lr=5.0e-3, intra_option_lr=5.0e-4, termination_lr=5.0e-4, interest_lr=5.0e-4, boltzmann_temperature=1.0, options_num=4, ent_coff=0.01, double_q=False, use_baseline=True, terminal_mask=True, termination_regularizer=0.01, assign_interval=1000, network_settings={ 'q': [32, 32], 'intra_option': [32, 32], 'termination': [32, 32], 'interest': [32, 32] }, **kwargs): super().__init__(**kwargs) self.assign_interval = assign_interval self.options_num = options_num self.termination_regularizer = termination_regularizer self.ent_coff = ent_coff self.use_baseline = use_baseline self.terminal_mask = terminal_mask self.double_q = double_q self.boltzmann_temperature = boltzmann_temperature self.q_net = TargetTwin( CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['q'])).to( self.device) self.intra_option_net = OcIntraOption( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, options_num=self.options_num, network_settings=network_settings['intra_option']).to(self.device) self.termination_net = CriticQvalueAll( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['termination'], out_act='sigmoid').to(self.device) self.interest_net = CriticQvalueAll( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['interest'], out_act='sigmoid').to(self.device) if self.is_continuous: self.log_std = th.as_tensor( np.full((self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] self.intra_option_oplr = OPLR( [self.intra_option_net, self.log_std], intra_option_lr, **self._oplr_params) else: self.intra_option_oplr = OPLR(self.intra_option_net, intra_option_lr, **self._oplr_params) self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params) self.termination_oplr = OPLR(self.termination_net, termination_lr, **self._oplr_params) self.interest_oplr = OPLR(self.interest_net, interest_lr, **self._oplr_params) self._trainer_modules.update(q_net=self.q_net, intra_option_net=self.intra_option_net, termination_net=self.termination_net, interest_net=self.interest_net, q_oplr=self.q_oplr, intra_option_oplr=self.intra_option_oplr, termination_oplr=self.termination_oplr, interest_oplr=self.interest_oplr) self.options = self.new_options = th.tensor( np.random.randint(0, self.options_num, self.n_copies)).to(self.device) def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray): super().episode_step(obs, env_rets, begin_mask) self.options = self.new_options @iton def select_action(self, obs): q = self.q_net(obs, rnncs=self.rnncs) # [B, P] self.rnncs_ = self.q_net.get_rnncs() pi = self.intra_option_net(obs, rnncs=self.rnncs) # [B, P, A] options_onehot = F.one_hot(self.options, self.options_num).float() # [B, P] options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1] pi = (pi * options_onehot_expanded).sum(-2) # [B, A] if self.is_continuous: mu = pi.tanh() # [B, A] log_std = self.log_std[self.options] # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) actions = dist.sample().clamp(-1, 1) # [B, A] else: pi = pi / self.boltzmann_temperature # [B, A] dist = 
td.Categorical(logits=pi) actions = dist.sample() # [B, ] interests = self.interest_net(obs, rnncs=self.rnncs) # [B, P] op_logits = interests * q # [B, P] or q.softmax(-1) self.new_options = td.Categorical(logits=op_logits).sample() # [B, ] return actions, Data(action=actions, last_options=self.options, options=self.new_options) def random_action(self): actions = super().random_action() self._acts_info.update( last_options=np.random.randint(0, self.options_num, self.n_copies), options=np.random.randint(0, self.options_num, self.n_copies)) return actions def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num) BATCH.options = int2one_hot(BATCH.options, self.options_num) return BATCH @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] beta_next = self.termination_net( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] qu_eval = (q * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] beta_s_ = (beta_next * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] if self.double_q: q_ = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float() # [T, B, P] q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True) # [T, B, 1] else: q_s_max = q_next.max(-1, keepdim=True)[0] # [T, B, 1] u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max # [T, B, 1] qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, u_target, BATCH.begin_mask).detach() # [T, B, 1] td_error = qu_target - qu_eval # [T, B, 1] gradient : q q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.q_oplr.optimize(q_loss) q_s = qu_eval.detach() # [T, B, 1] pi = self.intra_option_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P, A] if self.use_baseline: adv = (qu_target - q_s).detach() # [T, B, 1] else: adv = qu_target.detach() # [T, B, 1] # [T, B, P] => [T, B, P, 1] options_onehot_expanded = BATCH.options.unsqueeze(-1) # [T, B, P, A] => [T, B, A] pi = (pi * options_onehot_expanded).sum(-2) if self.is_continuous: mu = pi.tanh() # [T, B, A] log_std = self.log_std[BATCH.options.argmax(-1)] # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_p = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: pi = pi / self.boltzmann_temperature # [T, B, A] log_pi = pi.log_softmax(-1) # [T, B, A] entropy = -(log_pi.exp() * log_pi).sum(-1, keepdim=True) # [T, B, 1] log_p = (log_pi * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] pi_loss = -(log_p * adv + self.ent_coff * entropy).mean() # 1 self.intra_option_oplr.optimize(pi_loss) beta = self.termination_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True) # [T, B, 1] interests = self.interest_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] # [T, B, P] or q.softmax(-1) pi_op = (interests * q.detach()).softmax(-1) interest_loss = -(beta_s.detach() * (pi_op * BATCH.options).sum(-1, keepdim=True) * q_s).mean() # 1 self.interest_oplr.optimize(interest_loss) v_s = (q * pi_op).sum(-1, keepdim=True) # [T, B, 1] beta_loss = beta_s * (q_s - v_s).detach() # [T, B, 1] if self.terminal_mask: beta_loss *= (1 - BATCH.done) # [T, B, 1] beta_loss = beta_loss.mean() # 1 
self.termination_oplr.optimize(beta_loss) return td_error, { 'LEARNING_RATE/q_lr': self.q_oplr.lr, 'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr, 'LEARNING_RATE/termination_lr': self.termination_oplr.lr, # 'Statistics/option': self.options[0], 'LOSS/q_loss': q_loss, 'LOSS/pi_loss': pi_loss, 'LOSS/beta_loss': beta_loss, 'LOSS/interest_loss': interest_loss, 'Statistics/q_option_max': q_s.max(), 'Statistics/q_option_min': q_s.min(), 'Statistics/q_option_mean': q_s.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
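
# --- Illustrative sketch (not part of the repository code) ------------------
# How the interest function reshapes the policy over options, mirroring
# `op_logits = interests * q` in IOC.select_action: options with low interest
# are sampled less often even if their Q-value is high. Toy tensors only.
import torch as th
import torch.distributions as td


def _sketch_interest_option_policy(B: int = 3, P: int = 4):
    q = th.randn(B, P)                 # option values Q(s, .)
    interests = th.rand(B, P)          # sigmoid outputs of the interest net
    op_logits = interests * q          # interest-weighted preferences
    return td.Categorical(logits=op_logits).sample()   # new options, [B,]
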
class DreamerV1(SarlOffPolicy): """ Dream to Control: Learning Behaviors by Latent Imagination, http://arxiv.org/abs/1912.01603 """ policy_mode = 'off-policy' def __init__(self, eps_init: float = 1, eps_mid: float = 0.2, eps_final: float = 0.01, init2mid_annealing_step: int = 1000, stoch_dim=30, deter_dim=200, model_lr=6e-4, actor_lr=8e-5, critic_lr=8e-5, kl_free_nats=3, action_sigma=0.3, imagination_horizon=15, lambda_=0.95, kl_scale=1.0, reward_scale=1.0, use_pcont=False, pcont_scale=10.0, network_settings=dict(), **kwargs): super().__init__(**kwargs) assert self.use_rnn == False, 'assert self.use_rnn == False' if self.obs_spec.has_visual_observation \ and len(self.obs_spec.visual_dims) == 1 \ and not self.obs_spec.has_vector_observation: visual_dim = self.obs_spec.visual_dims[0] # TODO: optimize this assert visual_dim[0] == visual_dim[ 1] == 64, 'visual dimension must be [64, 64, *]' self._is_visual = True elif self.obs_spec.has_vector_observation \ and len(self.obs_spec.vector_dims) == 1 \ and not self.obs_spec.has_visual_observation: self._is_visual = False else: raise ValueError("please check the observation type") self.stoch_dim = stoch_dim self.deter_dim = deter_dim self.kl_free_nats = kl_free_nats self.imagination_horizon = imagination_horizon self.lambda_ = lambda_ self.kl_scale = kl_scale self.reward_scale = reward_scale # https://github.com/danijar/dreamer/issues/2 self.use_pcont = use_pcont # probability of continuing self.pcont_scale = pcont_scale self._action_sigma = action_sigma self._network_settings = network_settings if not self.is_continuous: self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) if self.obs_spec.has_visual_observation: from rls.nn.dreamer import VisualDecoder, VisualEncoder self.obs_encoder = VisualEncoder( self.obs_spec.visual_dims[0], **network_settings['obs_encoder']['visual']).to(self.device) self.obs_decoder = VisualDecoder( self.decoder_input_dim, self.obs_spec.visual_dims[0], **network_settings['obs_decoder']['visual']).to(self.device) else: from rls.nn.dreamer import VectorEncoder self.obs_encoder = VectorEncoder( self.obs_spec.vector_dims[0], **network_settings['obs_encoder']['vector']).to(self.device) self.obs_decoder = DenseModel( self.decoder_input_dim, self.obs_spec.vector_dims[0], **network_settings['obs_decoder']['vector']).to(self.device) self.rssm = self._dreamer_build_rssm() """ p(r_t | s_t, h_t) Reward model to predict reward from state and rnn hidden state """ self.reward_predictor = DenseModel(self.decoder_input_dim, 1, **network_settings['reward']).to( self.device) self.actor = ActionDecoder(self.a_dim, self.decoder_input_dim, dist=self._action_dist, **network_settings['actor']).to(self.device) self.critic = self._dreamer_build_critic() _modules = [ self.obs_encoder, self.rssm, self.obs_decoder, self.reward_predictor ] if self.use_pcont: self.pcont_decoder = DenseModel(self.decoder_input_dim, 1, **network_settings['pcont']).to( self.device) _modules.append(self.pcont_decoder) self.model_oplr = OPLR(_modules, model_lr, **self._oplr_params) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(obs_encoder=self.obs_encoder, obs_decoder=self.obs_decoder, reward_predictor=self.reward_predictor, rssm=self.rssm, actor=self.actor, critic=self.critic, model_oplr=self.model_oplr, 
actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) if self.use_pcont: self._trainer_modules.update(pcont_decoder=self.pcont_decoder) @property def _action_dist(self): return 'tanh_normal' if self.is_continuous else 'one_hot' # 'relaxed_one_hot' @property def decoder_input_dim(self): return self.stoch_dim + self.deter_dim def _dreamer_build_rssm(self): return RecurrentStateSpaceModel( self.stoch_dim, self.deter_dim, self.a_dim, self.obs_encoder.h_dim, **self._network_settings['rssm']).to(self.device) def _dreamer_build_critic(self): return DenseModel(self.decoder_input_dim, 1, **self._network_settings['critic']).to(self.device) @iton def select_action(self, obs): if self._is_visual: obs = get_first_visual(obs) else: obs = get_first_vector(obs) embedded_obs = self.obs_encoder(obs) # [B, *] state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs) state = state_posterior.sample() # [B, *] actions = self.actor.sample_actions(th.cat((state, self.rnncs['hx']), -1), is_train=self._is_train_mode) actions = self._exploration(actions) _, self.rnncs_['hx'] = self.rssm.prior(state, actions, self.rnncs['hx']) if not self.is_continuous: actions = actions.argmax(-1) # [B,] return actions, Data(action=actions) def _exploration(self, action: th.Tensor) -> th.Tensor: """ :param action: action to take, shape (1,) (if categorical), or (action dim,) (if continuous) :return: action of the same shape passed in, augmented with some noise """ if self.is_continuous: sigma = self._action_sigma if self._is_train_mode else 0. noise = th.randn(*action.shape) * sigma return th.clamp(action + noise, -1, 1) else: if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): index = th.randint(0, self.a_dim, (self.n_copies, )) action = th.zeros_like(action) action[..., index] = 1 return action @iton def _train(self, BATCH): T, B = BATCH.action.shape[:2] if self._is_visual: obs_ = get_first_visual(BATCH.obs_) else: obs_ = get_first_vector(BATCH.obs_) # embed observations with CNN embedded_observations = self.obs_encoder(obs_) # [T, B, *] # initialize state and rnn hidden state with 0 vector state, rnn_hidden = self.rssm.init_state(shape=B) # [B, S], [B, D] # compute state and rnn hidden sequences and kl loss kl_loss = 0 states, rnn_hiddens = [], [] for l in range(T): # if the begin of this episode, then reset to 0. # No matther whether last episode is beened truncated of not. state = state * (1. - BATCH.begin_mask[l]) # [B, S] rnn_hidden = rnn_hidden * (1. 
- BATCH.begin_mask[l]) # [B, D] next_state_prior, next_state_posterior, rnn_hidden = self.rssm( state, BATCH.action[l], rnn_hidden, embedded_observations[l]) # a, s_ state = next_state_posterior.rsample() # [B, S] posterior of s_ states.append(state) # [B, S] rnn_hiddens.append(rnn_hidden) # [B, D] kl_loss += self._kl_loss(next_state_prior, next_state_posterior) kl_loss /= T # 1 # compute reconstructed observations and predicted rewards post_feat = th.cat([th.stack(states, 0), th.stack(rnn_hiddens, 0)], -1) # [T, B, *] obs_pred = self.obs_decoder(post_feat) # [T, B, C, H, W] or [T, B, *] reward_pred = self.reward_predictor(post_feat) # [T, B, 1], s_ => r # compute loss for observation and reward obs_loss = -th.mean(obs_pred.log_prob(obs_)) # [T, B] => 1 # [T, B, 1]=>1 reward_loss = -th.mean( reward_pred.log_prob(BATCH.reward).unsqueeze(-1)) # add all losses and update model parameters with gradient descent model_loss = self.kl_scale * kl_loss + obs_loss + self.reward_scale * reward_loss # 1 if self.use_pcont: pcont_pred = self.pcont_decoder(post_feat) # [T, B, 1], s_ => done # https://github.com/danijar/dreamer/issues/2#issuecomment-605392659 pcont_target = self.gamma * (1. - BATCH.done) # [T, B, 1]=>1 pcont_loss = -th.mean( pcont_pred.log_prob(pcont_target).unsqueeze(-1)) model_loss += self.pcont_scale * pcont_loss self.model_oplr.optimize(model_loss) # remove gradients from previously calculated tensors with th.no_grad(): # [T, B, S] => [T*B, S] flatten_states = th.cat(states, 0).detach() # [T, B, D] => [T*B, D] flatten_rnn_hiddens = th.cat(rnn_hiddens, 0).detach() with FreezeParameters(self.model_oplr.parameters): # compute target values imaginated_states = [] imaginated_rnn_hiddens = [] log_probs = [] entropies = [] for h in range(self.imagination_horizon): imaginated_states.append(flatten_states) # [T*B, S] imaginated_rnn_hiddens.append(flatten_rnn_hiddens) # [T*B, D] flatten_feat = th.cat([flatten_states, flatten_rnn_hiddens], -1).detach() action_dist = self.actor(flatten_feat) actions = action_dist.rsample() # [T*B, A] log_probs.append( action_dist.log_prob( actions.detach()).unsqueeze(-1)) # [T*B, 1] entropies.append( action_dist.entropy().unsqueeze(-1)) # [T*B, 1] flatten_states_prior, flatten_rnn_hiddens = self.rssm.prior( flatten_states, actions, flatten_rnn_hiddens) flatten_states = flatten_states_prior.rsample() # [T*B, S] imaginated_states = th.stack(imaginated_states, 0) # [H, T*B, S] imaginated_rnn_hiddens = th.stack(imaginated_rnn_hiddens, 0) # [H, T*B, D] log_probs = th.stack(log_probs, 0) # [H, T*B, 1] entropies = th.stack(entropies, 0) # [H, T*B, 1] imaginated_feats = th.cat([imaginated_states, imaginated_rnn_hiddens], -1) # [H, T*B, *] with FreezeParameters(self.model_oplr.parameters + self.critic_oplr.parameters): imaginated_rewards = self.reward_predictor( imaginated_feats).mean # [H, T*B, 1] imaginated_values = self._dreamer_target_img_value( imaginated_feats) # [H, T*B, 1]] # Compute the exponential discounted sum of rewards if self.use_pcont: with FreezeParameters(self.pcont_decoder.parameters()): discount_arr = self.pcont_decoder( imaginated_feats).mean # [H, T*B, 1] else: discount_arr = self.gamma * th.ones_like( imaginated_rewards) # [H, T*B, 1] returns = compute_return(imaginated_rewards[:-1], imaginated_values[:-1], discount_arr[:-1], bootstrap=imaginated_values[-1], lambda_=self.lambda_) # [H-1, T*B, 1] # Make the top row 1 so the cumulative product starts with discount^0 discount_arr = th.cat( [th.ones_like(discount_arr[:1]), discount_arr[:-1]], 0) # [H, 
T*B, 1]
        discount = th.cumprod(discount_arr, 0).detach()[:-1]  # [H-1, T*B, 1]
        # discount_arr = th.cat([th.ones_like(discount_arr[:1]), discount_arr[1:]])
        # discount = th.cumprod(discount_arr[:-1], 0)
        actor_loss = self._dreamer_build_actor_loss(imaginated_feats, log_probs,
                                                    entropies, discount,
                                                    returns)  # 1

        # Don't let gradients pass through to prevent overwriting gradients.
        # Value Loss
        with th.no_grad():
            value_feat = imaginated_feats[:-1].detach()  # [H-1, T*B, 1]
            value_target = returns.detach()  # [H-1, T*B, 1]

        value_pred = self.critic(value_feat)  # [H-1, T*B, 1]
        log_prob = value_pred.log_prob(value_target).unsqueeze(-1)  # [H-1, T*B, 1]
        critic_loss = -th.mean(discount * log_prob)  # 1

        self.actor_oplr.zero_grad()
        self.critic_oplr.zero_grad()
        self.actor_oplr.backward(actor_loss)
        self.critic_oplr.backward(critic_loss)
        self.actor_oplr.step()
        self.critic_oplr.step()

        td_error = (value_pred.mean - value_target).mean(0).detach()  # [T*B,]
        td_error = td_error.view(T, B, 1)

        summaries = {
            'LEARNING_RATE/model_lr': self.model_oplr.lr,
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/model_loss': model_loss,
            'LOSS/kl_loss': kl_loss,
            'LOSS/obs_loss': obs_loss,
            'LOSS/reward_loss': reward_loss,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss
        }
        if self.use_pcont:
            summaries.update({'LOSS/pcont_loss': pcont_loss})
        return td_error, summaries

    def _initial_rnncs(self, batch: int) -> Dict[str, np.ndarray]:
        return {'hx': np.zeros((batch, self.deter_dim))}

    def _kl_loss(self, prior_dist, post_dist):
        # 1
        return td.kl_divergence(prior_dist,
                                post_dist).clamp(min=self.kl_free_nats).mean()

    def _dreamer_target_img_value(self, imaginated_feats):
        imaginated_values = self.critic(imaginated_feats).mean  # [H, T*B, 1]
        return imaginated_values

    def _dreamer_build_actor_loss(self, imaginated_feats, log_probs, entropies,
                                  discount, returns):
        actor_loss = -th.mean(discount * returns)  # [H-1, T*B, 1] => 1
        return actor_loss
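
# --- Illustrative sketch (not part of the repository code) ------------------
# The lambda-return recursion that a `compute_return`-style helper is assumed
# to implement for the imagined trajectory in DreamerV1._train:
# V_lambda(t) = r_t + d_t * ((1 - lambda) * v_{t+1} + lambda * V_lambda(t+1)),
# bootstrapped from the last imagined value. Toy tensors; names hypothetical.
import torch as th


def _sketch_lambda_return(lambda_: float = 0.95, H: int = 14, N: int = 6):
    rewards, values = th.randn(H, N, 1), th.randn(H, N, 1)
    discounts = 0.99 * th.ones(H, N, 1)
    bootstrap = th.randn(N, 1)
    next_values = th.cat([values[1:], bootstrap.unsqueeze(0)], 0)
    last, returns = bootstrap, []
    for t in reversed(range(H)):
        last = rewards[t] + discounts[t] * ((1 - lambda_) * next_values[t] + lambda_ * last)
        returns.insert(0, last)
    return th.stack(returns, 0)        # [H, N, 1]
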
class VDN(MultiAgentOffPolicy): """ Value-Decomposition Networks For Cooperative Multi-Agent Learning, http://arxiv.org/abs/1706.05296 QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning, http://arxiv.org/abs/1803.11485 Qatten: A General Framework for Cooperative Multiagent Reinforcement Learning, http://arxiv.org/abs/2002.03939 """ policy_mode = 'off-policy' def __init__(self, mixer='vdn', mixer_settings={}, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, use_double=True, init2mid_annealing_step=1000, assign_interval=1000, network_settings={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): super().__init__(**kwargs) assert not any(list(self.is_continuouss.values()) ), 'VDN only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self._use_double = use_double self._mixer_type = mixer self._mixer_settings = mixer_settings self.q_nets = {} for id in set(self.model_ids): self.q_nets[id] = TargetTwin( CriticDueling(self.obs_specs[id], rep_net_params=self._rep_net_params, output_shape=self.a_dims[id], network_settings=network_settings)).to( self.device) self.mixer = self._build_mixer() self.oplr = OPLR( tuple(self.q_nets.values()) + (self.mixer, ), lr, **self._oplr_params) self._trainer_modules.update( {f"model_{id}": self.q_nets[id] for id in set(self.model_ids)}) self._trainer_modules.update(mixer=self.mixer, oplr=self.oplr) def _build_mixer(self): assert self._mixer_type in [ 'vdn', 'qmix', 'qatten' ], "assert self._mixer_type in ['vdn', 'qmix', 'qatten']" if self._mixer_type in ['qmix', 'qatten']: assert self._has_global_state, 'assert self._has_global_state' return TargetTwin(Mixer_REGISTER[self._mixer_type]( n_agents=self.n_agents_percopy, state_spec=self.state_spec, rep_net_params=self._rep_net_params, **self._mixer_settings)).to(self.device) @iton # TODO: optimization def select_action(self, obs): acts_info = {} actions = {} for aid, mid in zip(self.agent_ids, self.model_ids): q_values = self.q_nets[mid](obs[aid], rnncs=self.rnncs[aid]) # [B, A] self.rnncs_[aid] = self.q_nets[mid].get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): action = np.random.randint(0, self.a_dims[aid], self.n_copies) else: action = q_values.argmax(-1) # [B,] actions[aid] = action acts_info[aid] = Data(action=action) return actions, acts_info @iton def _train(self, BATCH_DICT): summaries = {} reward = BATCH_DICT[self.agent_ids[0]].reward # [T, B, 1] done = 0. 
q_evals = [] q_target_next_choose_maxs = [] for aid, mid in zip(self.agent_ids, self.model_ids): done += BATCH_DICT[aid].done # [T, B, 1] q = self.q_nets[mid]( BATCH_DICT[aid].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] q_eval = (q * BATCH_DICT[aid].action).sum( -1, keepdim=True) # [T, B, 1] q_evals.append(q_eval) # N * [T, B, 1] q_target = self.q_nets[mid].t( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] if self._use_double: next_q = self.q_nets[mid]( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] next_max_action = next_q.argmax(-1) # [T, B] next_max_action_one_hot = F.one_hot( next_max_action, self.a_dims[aid]).float() # [T, B, A] q_target_next_max = (q_target * next_max_action_one_hot).sum( -1, keepdim=True) # [T, B, 1] else: # [T, B, 1] q_target_next_max = q_target.max(-1, keepdim=True)[0] q_target_next_choose_maxs.append( q_target_next_max) # N * [T, B, 1] q_evals = th.stack(q_evals, -1) # [T, B, 1, N] q_target_next_choose_maxs = th.stack(q_target_next_choose_maxs, -1) # [T, B, 1, N] q_eval_tot = self.mixer( q_evals, BATCH_DICT['global'].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1] q_target_next_max_tot = self.mixer.t( q_target_next_choose_maxs, BATCH_DICT['global'].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1] q_target_tot = n_step_return( reward, self.gamma, (done > 0.).float(), q_target_next_max_tot, BATCH_DICT['global'].begin_mask).detach() # [T, B, 1] td_error = q_target_tot - q_eval_tot # [T, B, 1] q_loss = td_error.square().mean() # 1 self.oplr.optimize(q_loss) summaries['model'] = { 'LOSS/q_loss': q_loss, 'Statistics/q_max': q_eval_tot.max(), 'Statistics/q_min': q_eval_tot.min(), 'Statistics/q_mean': q_eval_tot.mean() } return td_error, summaries def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: for q_net in self.q_nets.values(): q_net.sync() self.mixer.sync()
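
# --- Illustrative sketch (not part of the repository code) ------------------
# The mixing rule assumed behind Mixer_REGISTER['vdn']: the joint value is the
# plain sum of the per-agent chosen Q-values, Q_tot(s, a) = sum_i Q_i(o_i, a_i).
# QMIX/Qatten replace this sum with a state-conditioned monotonic network.
import torch as th


def _sketch_vdn_mix(T: int = 5, B: int = 3, N: int = 2):
    q_chosen = th.randn(T, B, 1, N)    # per-agent Q(o_i, a_i), stacked as in `_train`
    return q_chosen.sum(-1)            # Q_tot, [T, B, 1]
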
class DDDQN(SarlOffPolicy): """ Dueling Double DQN, https://arxiv.org/abs/1511.06581 """ policy_mode = 'off-policy' def __init__(self, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=2, network_settings={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'dueling double dqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin(CriticDueling(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: actions = q_values.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] next_q = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_target = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] next_max_action = next_q.argmax(-1) # [T, B] next_max_action_one_hot = F.one_hot(next_max_action.squeeze(), self.a_dim).float() # [T, B, A] q_target_next_max = (q_target * next_max_action_one_hot).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target_next_max, BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
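
# --- Illustrative sketch (not part of the repository code) ------------------
# The two ideas combined in DDDQN: the dueling decomposition
# Q(s, a) = V(s) + A(s, a) - mean_a A(s, a), which CriticDueling is assumed to
# implement, and the double-DQN target that picks the argmax with the online
# net but evaluates it with the target net. Toy tensors; names hypothetical.
import torch as th


def _sketch_dueling_double(B: int = 4, A: int = 3):
    v, adv = th.randn(B, 1), th.randn(B, A)
    q = v + adv - adv.mean(-1, keepdim=True)                 # dueling head, [B, A]
    q_online_next, q_target_next = th.randn(B, A), th.randn(B, A)
    a_star = q_online_next.argmax(-1, keepdim=True)          # chosen by the online net
    double_q_next = q_target_next.gather(-1, a_star)         # evaluated by the target net
    return q, double_q_next
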
class NPG(SarlOnPolicy): """ Natural Policy Gradient, NPG https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf """ policy_mode = 'on-policy' def __init__( self, agent_spec, actor_step_size=0.5, beta=1.0e-3, lambda_=0.95, cg_iters=10, damping_coeff=0.1, epsilon=0.2, critic_lr=1e-3, train_critic_iters=10, network_settings={ 'actor_continuous': { 'hidden_units': [64, 64], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32], 'critic': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) self.actor_step_size = actor_step_size self.beta = beta self.lambda_ = lambda_ self._epsilon = epsilon self._cg_iters = cg_iters self._damping_coeff = damping_coeff self._train_critic_iters = train_critic_iters if self.is_continuous: self.actor = ActorMuLogstd( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) self.critic = CriticValue( self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic_oplr=self.critic_oplr) @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() value = self.critic(obs, rnncs=self.rnncs) # [B, 1] if self.is_continuous: mu, log_std = output # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] log_prob = dist.log_prob(action).unsqueeze(-1) # [B, 1] else: logits = output # [B, A] logp_all = logits.log_softmax(-1) # [B, A] norm_dist = td.Categorical(logits=logp_all) action = norm_dist.sample() # [B,] log_prob = norm_dist.log_prob(action).unsqueeze(-1) # [B, 1] acts_info = Data(action=action, value=value, log_prob=log_prob + th.finfo().eps) if self.use_rnn: acts_info.update(rnncs=self.rnncs) if self.is_continuous: acts_info.update(mu=mu, log_std=log_std) else: acts_info.update(logp_all=logp_all) return action, acts_info @iton def _get_value(self, obs, rnncs=None): value = self.critic(obs, rnncs=rnncs) # [B, 1] return value def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=value) td_error = calculate_td_error( BATCH.reward, self.gamma, BATCH.done, value=BATCH.value, next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0)) BATCH.gae_adv = discounted_sum(td_error, self.lambda_ * self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH @iton def _train(self, BATCH): output = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] if self.is_continuous: mu, log_std = output # [T, B, A], [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = output # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] new_log_prob = (BATCH.action * logp_all).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 ratio = 
(new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        actor_loss = -(ratio * BATCH.gae_adv).mean()  # 1
        flat_grads = grads_flatten(actor_loss, self.actor,
                                   retain_graph=True).detach()  # [1,]

        if self.is_continuous:
            kl = td.kl_divergence(
                td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1),
                td.Independent(td.Normal(mu, log_std.exp()), 1)).mean()
        else:
            kl = (BATCH.logp_all.exp() *
                  (BATCH.logp_all - logp_all)).sum(-1).mean()  # 1
        flat_kl_grad = grads_flatten(kl, self.actor, create_graph=True)
        search_direction = -self._conjugate_gradients(
            flat_grads, flat_kl_grad, cg_iters=self._cg_iters)  # [1,]

        with th.no_grad():
            flat_params = th.cat(
                [param.data.view(-1) for param in self.actor.parameters()])
            new_flat_params = flat_params + self.actor_step_size * search_direction
            set_from_flat_params(self.actor, new_flat_params)

        for _ in range(self._train_critic_iters):
            value = self.critic(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]
            td_error = BATCH.discounted_reward - value  # [T, B, 1]
            critic_loss = td_error.square().mean()  # 1
            self.critic_oplr.optimize(critic_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }

    def _conjugate_gradients(self,
                             flat_grads,
                             flat_kl_grad,
                             cg_iters: int = 10,
                             residual_tol: float = 1e-10):
        """
        Conjugate gradient algorithm
        (see https://en.wikipedia.org/wiki/Conjugate_gradient_method)
        """
        x = th.zeros_like(flat_grads)
        r, p = flat_grads.clone(), flat_grads.clone()
        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
        # Change if doing warm start.
        rdotr = r.dot(r)
        for i in range(cg_iters):
            z = self._MVP(p, flat_kl_grad)
            alpha = rdotr / (p.dot(z) + th.finfo().eps)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + new_rdotr / rdotr * p
            rdotr = new_rdotr
        return x

    def _MVP(self, v, flat_kl_grad):
        """Matrix vector product."""
        # calculate the second-order gradient of the KL with respect to theta
        kl_v = (flat_kl_grad * v).sum()
        mvp = grads_flatten(kl_v, self.actor, retain_graph=True).detach()
        mvp += max(0, self._damping_coeff) * v
        return mvp
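
# --- Illustrative sketch (not part of the repository code) ------------------
# The natural-gradient idea behind `_conjugate_gradients` / `_MVP`: solve
# F x = g with conjugate gradients so the update direction is F^{-1} g without
# ever materialising the Fisher matrix F. Here F is a toy SPD matrix standing
# in for the KL Hessian-vector product; all names are hypothetical.
import torch as th


def _sketch_natural_gradient_cg(iters: int = 10, tol: float = 1e-10):
    A = th.randn(6, 6)
    F_mat = A @ A.t() + 0.1 * th.eye(6)     # stand-in for the Fisher matrix (SPD)
    g = th.randn(6)                         # stand-in for the policy gradient
    x = th.zeros_like(g)
    r, p = g.clone(), g.clone()
    rdotr = r.dot(r)
    for _ in range(iters):
        z = F_mat @ p                       # in NPG: Hessian-vector product of the KL
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x                                # approximately F^{-1} g
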
class BootstrappedDQN(SarlOffPolicy): """ Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621 """ policy_mode = 'off-policy' def __init__(self, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, head_num=4, network_settings=[32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'Bootstrapped DQN only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.head_num = head_num self._probs = th.FloatTensor([1. / head_num for _ in range(head_num)]) self.now_head = 0 self.q_net = TargetTwin( CriticQvalueBootstrap(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, head_num=self.head_num, network_settings=network_settings)).to( self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) def episode_reset(self): super().episode_reset() self.now_head = np.random.randint(self.head_num) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [H, B, A] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: # [H, B, A] => [B, A] => [B, ] actions = q_values[self.now_head].argmax(-1) return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask).mean( 0) # [H, T, B, A] => [T, B, A] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask).mean( 0) # [H, T, B, A] => [T, B, A] # [T, B, A] * [T, B, A] => [T, B, 1] q_eval = (q * BATCH.action).sum(-1, keepdim=True) q_target = n_step_return( BATCH.reward, self.gamma, BATCH.done, # [T, B, A] => [T, B, 1] q_next.max(-1, keepdim=True)[0], BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 # mask_dist = td.Bernoulli(probs=self._probs) # TODO: # mask = mask_dist.sample([batch_size]).T # [H, B] self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
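
# --- Illustrative sketch (not part of the repository code) ------------------
# The per-head bootstrap masking that the commented-out Bernoulli code in
# BootstrappedDQN._train alludes to: each transition updates only a random
# subset of heads, which keeps the heads' value estimates diverse. Toy tensors.
import torch as th
import torch.distributions as td


def _sketch_bootstrap_mask(head_num: int = 4, batch: int = 8, p: float = 0.5):
    mask = td.Bernoulli(probs=p).sample((head_num, batch, 1))     # [H, B, 1]
    per_head_td_error = th.randn(head_num, batch, 1)
    return (mask * per_head_td_error.square()).sum(0).mean()      # masked heads contribute 0
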
class PG(SarlOnPolicy): """ Vanilla Policy Gradient (REINFORCE) """ policy_mode = 'on-policy' def __init__( self, agent_spec, lr=5.0e-4, network_settings={ 'actor_continuous': { 'hidden_units': [32, 32], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) if self.is_continuous: self.net = ActorMuLogstd( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.net = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) @iton def select_action(self, obs): output = self.net(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.net.get_rnncs() if self.is_continuous: mu, log_std = output # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] else: logits = output # [B, A] norm_dist = td.Categorical(logits=logits) action = norm_dist.sample() # [B,] acts_info = Data(action=action) if self.use_rnn: acts_info.update(rnncs=self.rnncs) return action, acts_info def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH @iton def _train(self, BATCH): # [T, B, *] output = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] if self.is_continuous: mu, log_std = output # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_act_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: logits = output # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] log_act_prob = (logp_all * BATCH.action).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum( -1, keepdim=True) # [T, B, 1] loss = -(log_act_prob * BATCH.discounted_reward).mean() self.oplr.optimize(loss) return { 'LOSS/loss': loss, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/lr': self.oplr.lr }
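# A minimal sketch of the normalized discounted return that `discounted_sum` is assumed to
# produce for this REINFORCE-style update. It ignores the done/begin_mask handling of the
# real helper; the names below are illustrative only.
import torch as th


def discounted_returns(rewards, gamma=0.99, normalize=True):
    ret, out = 0.0, []
    for r in reversed(rewards.tolist()):
        ret = r + gamma * ret      # accumulate return backwards in time
        out.append(ret)
    out = th.tensor(list(reversed(out)))
    if normalize:
        out = (out - out.mean()) / (out.std() + 1e-8)
    return out


print(discounted_returns(th.tensor([1., 0., 1.])))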
class MASAC(MultiAgentOffPolicy): policy_mode = 'off-policy' def __init__( self, alpha=0.2, annealing=True, last_alpha=0.01, polyak=0.995, discrete_tau=1.0, network_settings={ 'actor_continuous': { 'share': [128, 128], 'mu': [64], 'log_std': [64], 'soft_clip': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [64, 32], 'q': [128, 128] }, auto_adaption=True, actor_lr=5.0e-4, critic_lr=1.0e-3, alpha_lr=5.0e-4, **kwargs): """ TODO: Annotation """ super().__init__(**kwargs) self.polyak = polyak self.discrete_tau = discrete_tau self.auto_adaption = auto_adaption self.annealing = annealing self.target_entropy = 0.98 for id in self.agent_ids: if self.is_continuouss[id]: self.target_entropy *= (-self.a_dims[id]) else: self.target_entropy *= np.log(self.a_dims[id]) self.actors, self.critics, self.critics2 = {}, {}, {} for id in set(self.model_ids): if self.is_continuouss[id]: self.actors[id] = ActorCts( self.obs_specs[id], rep_net_params=self._rep_net_params, output_shape=self.a_dims[id], network_settings=network_settings['actor_continuous']).to( self.device) else: self.actors[id] = ActorDct( self.obs_specs[id], rep_net_params=self._rep_net_params, output_shape=self.a_dims[id], network_settings=network_settings['actor_discrete']).to( self.device) self.critics[id] = TargetTwin( MACriticQvalueOne(list(self.obs_specs.values()), rep_net_params=self._rep_net_params, action_dim=sum(self.a_dims.values()), network_settings=network_settings['q']), self.polyak).to(self.device) self.critics2[id] = deepcopy(self.critics[id]) self.actor_oplr = OPLR(list(self.actors.values()), actor_lr, **self._oplr_params) self.critic_oplr = OPLR( list(self.critics.values()) + list(self.critics2.values()), critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) if self.annealing: self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6)) self._trainer_modules.update( {f"actor_{id}": self.actors[id] for id in set(self.model_ids)}) self._trainer_modules.update( {f"critic_{id}": self.critics[id] for id in set(self.model_ids)}) self._trainer_modules.update( {f"critic2_{id}": self.critics2[id] for id in set(self.model_ids)}) self._trainer_modules.update(log_alpha=self.log_alpha, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs: Dict): acts_info = {} actions = {} for aid, mid in zip(self.agent_ids, self.model_ids): output = self.actors[mid](obs[aid], rnncs=self.rnncs[aid]) # [B, A] self.rnncs_[aid] = self.actors[mid].get_rnncs() if self.is_continuouss[aid]: mu, log_std = output # [B, A] pi = td.Normal(mu, log_std.exp()).sample().tanh() mu.tanh_() # squash mu # [B, A] else: logits = output # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] action = pi if self._is_train_mode else mu acts_info[aid] = Data(action=action) actions[aid] = action return actions, acts_info @iton def _train(self, BATCH_DICT): """ TODO: Annotation """ summaries = defaultdict(dict) target_actions = {} target_log_pis = 1. 
for aid, mid in zip(self.agent_ids, self.model_ids): if self.is_continuouss[aid]: target_mu, target_log_std = self.actors[mid]( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] dist = td.Independent( td.Normal(target_mu, target_log_std.exp()), 1) target_pi = dist.sample() # [T, B, A] target_pi, target_log_pi = squash_action( target_pi, dist.log_prob(target_pi).unsqueeze( -1)) # [T, B, A], [T, B, 1] else: target_logits = self.actors[mid]( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze( -1) # [T, B, 1] target_pi = F.one_hot(target_pi, self.a_dims[aid]).float() # [T, B, A] target_actions[aid] = target_pi target_log_pis *= target_log_pi target_log_pis += th.finfo().eps target_actions = th.cat(list(target_actions.values()), -1) # [T, B, N*A] qs1, qs2, q_targets1, q_targets2 = {}, {}, {}, {} for mid in self.model_ids: qs1[mid] = self.critics[mid]( [BATCH_DICT[id].obs for id in self.agent_ids], th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1)) # [T, B, 1] qs2[mid] = self.critics2[mid]( [BATCH_DICT[id].obs for id in self.agent_ids], th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1)) # [T, B, 1] q_targets1[mid] = self.critics[mid].t( [BATCH_DICT[id].obs_ for id in self.agent_ids], target_actions) # [T, B, 1] q_targets2[mid] = self.critics2[mid].t( [BATCH_DICT[id].obs_ for id in self.agent_ids], target_actions) # [T, B, 1] q_loss = {} td_errors = 0. for aid, mid in zip(self.agent_ids, self.model_ids): q_target = th.minimum(q_targets1[mid], q_targets2[mid]) # [T, B, 1] dc_r = n_step_return( BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done, q_target - self.alpha * target_log_pis, BATCH_DICT['global'].begin_mask).detach() # [T, B, 1] td_error1 = qs1[mid] - dc_r # [T, B, 1] td_error2 = qs2[mid] - dc_r # [T, B, 1] td_errors += (td_error1 + td_error2) / 2 q1_loss = td_error1.square().mean() # 1 q2_loss = td_error2.square().mean() # 1 q_loss[aid] = 0.5 * q1_loss + 0.5 * q2_loss summaries[aid].update({ 'Statistics/q_min': qs1[mid].min(), 'Statistics/q_mean': qs1[mid].mean(), 'Statistics/q_max': qs1[mid].max() }) self.critic_oplr.optimize(sum(q_loss.values())) log_pi_actions = {} log_pis = {} sample_pis = {} for aid, mid in zip(self.agent_ids, self.model_ids): if self.is_continuouss[aid]: mu, log_std = self.actors[mid]( BATCH_DICT[aid].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action( pi, dist.log_prob(pi).unsqueeze(-1)) # [T, B, A], [T, B, 1] pi_action = BATCH_DICT[aid].action.arctanh() _, log_pi_action = squash_action( pi_action, dist.log_prob(pi_action).unsqueeze( -1)) # [T, B, A], [T, B, 1] else: logits = self.actors[mid]( BATCH_DICT[aid].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot( _pi.argmax(-1), self.a_dims[aid]).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] log_pi_action = (logp_all * BATCH_DICT[aid].action).sum( -1, keepdim=True) # [T, B, 1] log_pi_actions[aid] = 
log_pi_action log_pis[aid] = log_pi sample_pis[aid] = pi actor_loss = {} for aid, mid in zip(self.agent_ids, self.model_ids): all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids} all_actions[aid] = sample_pis[aid] all_log_pis = {id: log_pi_actions[id] for id in self.agent_ids} all_log_pis[aid] = log_pis[aid] q_s_pi = th.minimum( self.critics[mid]( [BATCH_DICT[id].obs for id in self.agent_ids], th.cat(list(all_actions.values()), -1), begin_mask=BATCH_DICT['global'].begin_mask), self.critics2[mid]( [BATCH_DICT[id].obs for id in self.agent_ids], th.cat(list(all_actions.values()), -1), begin_mask=BATCH_DICT['global'].begin_mask)) # [T, B, 1] _log_pis = 1. for _log_pi in all_log_pis.values(): _log_pis *= _log_pi _log_pis += th.finfo().eps actor_loss[aid] = -(q_s_pi - self.alpha * _log_pis).mean() # 1 self.actor_oplr.optimize(sum(actor_loss.values())) for aid in self.agent_ids: summaries[aid].update({ 'LOSS/actor_loss': actor_loss[aid], 'LOSS/critic_loss': q_loss[aid] }) summaries['model'].update({ 'LOSS/actor_loss': sum(actor_loss.values()), 'LOSS/critic_loss': sum(q_loss.values()) }) if self.auto_adaption: _log_pis = 1. for _log_pi in log_pis.values(): _log_pis *= _log_pi _log_pis += th.finfo().eps alpha_loss = -( self.alpha * (_log_pis + self.target_entropy).detach()).mean() # 1 self.alpha_oplr.optimize(alpha_loss) summaries['model'].update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return td_errors / self.n_agents_percopy, summaries def _after_train(self): super()._after_train() if self.annealing and not self.auto_adaption: self.log_alpha.copy_( self.alpha_annealing(self._cur_train_step).log()) for critic in self.critics.values(): critic.sync() for critic2 in self.critics2.values(): critic2.sync()
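# Hedged sketch of the entropy-temperature update pattern used above: alpha is trained so
# that the policy entropy is pushed toward `target_entropy`. All tensors below are toy
# values, not taken from the trainer.
import torch as th

log_alpha = th.zeros(1, requires_grad=True)
alpha_opt = th.optim.Adam([log_alpha], lr=5e-4)

log_pi = th.full((16, 1), -1.3)   # pretend per-sample log pi(a|s)
target_entropy = -2.0             # e.g. proportional to -|A| for a 2-d continuous action
alpha_loss = -(log_alpha.exp() * (log_pi + target_entropy).detach()).mean()
alpha_opt.zero_grad()
alpha_loss.backward()
alpha_opt.step()
print(log_alpha.exp().item())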
class MVE(DDPG): """ Model-Based Value Estimation for Efficient Model-Free Reinforcement Learning, http://arxiv.org/abs/1803.00101 """ policy_mode = 'off-policy' def __init__(self, wm_lr=1e-3, roll_out_horizon=15, **kwargs): super().__init__(**kwargs) network_settings = kwargs.get('network_settings', {}) assert not self.obs_spec.has_visual_observation, "assert not self.obs_spec.has_visual_observation" assert self.obs_spec.has_vector_observation, "assert self.obs_spec.has_vector_observation" self._wm_lr = wm_lr self._roll_out_horizon = roll_out_horizon self._forward_dynamic_model = VectorSA2S( self.obs_spec.vector_dims[0], self.a_dim, hidden_units=network_settings['forward_model']) self._reward_model = VectorSA2R( self.obs_spec.vector_dims[0], self.a_dim, hidden_units=network_settings['reward_model']) self._done_model = VectorSA2D( self.obs_spec.vector_dims[0], self.a_dim, hidden_units=network_settings['done_model']) self._wm_oplr = OPLR([ self._forward_dynamic_model, self._reward_model, self._done_model ], self._wm_lr, **self._oplr_params) self._trainer_modules.update( _forward_dynamic_model=self._forward_dynamic_model, _reward_model=self._reward_model, _done_model=self._done_model, _wm_oplr=self._wm_oplr) @iton def _train(self, BATCH): obs = get_first_vector(BATCH.obs) # [T, B, S] obs_ = get_first_vector(BATCH.obs_) # [T, B, S] _timestep = obs.shape[0] _batchsize = obs.shape[1] predicted_obs_ = self._forward_dynamic_model(obs, BATCH.action) # [T, B, S] predicted_reward = self._reward_model(obs, BATCH.action) # [T, B, 1] predicted_done_dist = self._done_model(obs, BATCH.action) # [T, B, 1] _obs_loss = F.mse_loss(obs_, predicted_obs_) # todo _reward_loss = F.mse_loss(BATCH.reward, predicted_reward) _done_loss = -predicted_done_dist.log_prob(BATCH.done).mean() wm_loss = _obs_loss + _reward_loss + _done_loss self._wm_oplr.optimize(wm_loss) obs = th.reshape(obs, (_timestep * _batchsize, -1)) # [T*B, S] obs_ = th.reshape(obs_, (_timestep * _batchsize, -1)) # [T*B, S] actions = th.reshape(BATCH.action, (_timestep * _batchsize, -1)) # [T*B, A] rewards = th.reshape(BATCH.reward, (_timestep * _batchsize, -1)) # [T*B, 1] dones = th.reshape(BATCH.done, (_timestep * _batchsize, -1)) # [T*B, 1] rollout_rewards = [rewards] rollout_dones = [dones] r_obs_ = obs_ _r_obs = deepcopy(BATCH.obs_) r_done = (1. - dones) for _ in range(self._roll_out_horizon): r_obs = r_obs_ _r_obs.vector.vector_0 = r_obs if self.is_continuous: action_target = self.actor.t(_r_obs) # [T*B, A] if self.use_target_action_noise: r_action = self.target_noised_action( action_target) # [T*B, A] else: target_logits = self.actor.t(_r_obs) # [T*B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T*B,] r_action = F.one_hot(target_pi, self.a_dim).float() # [T*B, A] r_obs_ = self._forward_dynamic_model(r_obs, r_action) # [T*B, S] r_reward = self._reward_model(r_obs, r_action) # [T*B, 1] r_done = r_done * (1. 
- self._done_model(r_obs, r_action).sample() ) # [T*B, 1] rollout_rewards.append(r_reward) # [H+1, T*B, 1] rollout_dones.append(r_done) # [H+1, T*B, 1] _r_obs.vector.vector_0 = obs q = self.critic(_r_obs, actions) # [T*B, 1] _r_obs.vector.vector_0 = r_obs_ q_target = self.critic.t(_r_obs, r_action) # [T*B, 1] dc_r = rewards for t in range(1, self._roll_out_horizon): dc_r += (self.gamma**t) * (rollout_rewards[t] * rollout_dones[t]) dc_r += (self.gamma**self._roll_out_horizon) * rollout_dones[ self._roll_out_horizon] * q_target # [T*B, 1] td_error = dc_r - q # [T*B, 1] q_loss = td_error.square().mean() # 1 self.critic_oplr.optimize(q_loss) # train actor if self.is_continuous: mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] mu = _pi_diff + _pi # [T, B, A] q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -q_actor.mean() # 1 self.actor_oplr.optimize(actor_loss) return th.ones_like(BATCH.reward), { 'LEARNING_RATE/wm_lr': self._wm_oplr.lr, 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/wm_loss': wm_loss, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': q_loss, 'Statistics/q_min': q.min(), 'Statistics/q_mean': q.mean(), 'Statistics/q_max': q.max() }
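# Hedged toy illustration of the H-step model-value-expansion target assembled in
# MVE._train above: discounted imagined rewards plus a bootstrapped Q value at the
# horizon, using the same list indexing as the class. All tensors are toy stand-ins.
import torch as th

gamma, H = 0.99, 3
rollout_rewards = [th.ones(5, 1) for _ in range(H + 1)]    # index 0 holds the real batch reward
rollout_not_dones = [th.ones(5, 1) for _ in range(H + 1)]  # cumulative (1 - done) products
q_horizon = th.full((5, 1), 10.0)                          # target critic value at the horizon

target = rollout_rewards[0].clone()
for t in range(1, H):
    target += (gamma ** t) * rollout_rewards[t] * rollout_not_dones[t]
target += (gamma ** H) * rollout_not_dones[H] * q_horizon
print(target[0])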
class PPO(SarlOnPolicy): """ Proximal Policy Optimization, https://arxiv.org/abs/1707.06347 Emergence of Locomotion Behaviours in Rich Environments, http://arxiv.org/abs/1707.02286, DPPO """ policy_mode = 'on-policy' def __init__(self, agent_spec, ent_coef: float = 1.0e-2, vf_coef: float = 0.5, lr: float = 5.0e-4, lambda_: float = 0.95, epsilon: float = 0.2, use_duel_clip: bool = False, duel_epsilon: float = 0., use_vclip: bool = False, value_epsilon: float = 0.2, share_net: bool = True, actor_lr: float = 3e-4, critic_lr: float = 1e-3, kl_reverse: bool = False, kl_target: float = 0.02, kl_target_cutoff: float = 2, kl_target_earlystop: float = 4, kl_beta: List[float] = [0.7, 1.3], kl_alpha: float = 1.5, kl_coef: float = 1.0, extra_coef: float = 1000.0, use_kl_loss: bool = False, use_extra_loss: bool = False, use_early_stop: bool = False, network_settings: Dict = { 'share': { 'continuous': { 'condition_sigma': False, 'log_std_bound': [-20, 2], 'share': [32, 32], 'mu': [32, 32], 'v': [32, 32] }, 'discrete': { 'share': [32, 32], 'logits': [32, 32], 'v': [32, 32] } }, 'actor_continuous': { 'hidden_units': [64, 64], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32], 'critic': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) self._ent_coef = ent_coef self.lambda_ = lambda_ assert 0.0 <= lambda_ <= 1.0, "GAE lambda should be in [0, 1]." self._epsilon = epsilon self._use_vclip = use_vclip self._value_epsilon = value_epsilon self._share_net = share_net self._kl_reverse = kl_reverse self._kl_target = kl_target self._kl_alpha = kl_alpha self._kl_coef = kl_coef self._extra_coef = extra_coef self._vf_coef = vf_coef self._use_duel_clip = use_duel_clip self._duel_epsilon = duel_epsilon if self._use_duel_clip: assert - \ self._epsilon < self._duel_epsilon < self._epsilon, "duel_epsilon should be set in the range of (-epsilon, epsilon)." 
self._kl_cutoff = kl_target * kl_target_cutoff self._kl_stop = kl_target * kl_target_earlystop self._kl_low = kl_target * kl_beta[0] self._kl_high = kl_target * kl_beta[-1] self._use_kl_loss = use_kl_loss self._use_extra_loss = use_extra_loss self._use_early_stop = use_early_stop if self._share_net: if self.is_continuous: self.net = ActorCriticValueCts(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['share']['continuous']).to(self.device) else: self.net = ActorCriticValueDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['share']['discrete']).to(self.device) self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) else: if self.is_continuous: self.actor = ActorMuLogstd(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to(self.device) else: self.actor = ActorDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) self.critic = CriticValue(self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @iton def select_action(self, obs): if self.is_continuous: if self._share_net: mu, log_std, value = self.net(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.net.get_rnncs() else: mu, log_std = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() value = self.critic(obs, rnncs=self.rnncs) # [B, 1] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] log_prob = dist.log_prob(action).unsqueeze(-1) # [B, 1] else: if self._share_net: logits, value = self.net(obs, rnncs=self.rnncs) # [B, A], [B, 1] self.rnncs_ = self.net.get_rnncs() else: logits = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() value = self.critic(obs, rnncs=self.rnncs) # [B, 1] norm_dist = td.Categorical(logits=logits) action = norm_dist.sample() # [B,] log_prob = norm_dist.log_prob(action).unsqueeze(-1) # [B, 1] acts_info = Data(action=action, value=value, log_prob=log_prob + th.finfo().eps) if self.use_rnn: acts_info.update(rnncs=self.rnncs) return action, acts_info @iton def _get_value(self, obs, rnncs=None): if self._share_net: if self.is_continuous: _, _, value = self.net(obs, rnncs=rnncs) # [B, 1] else: _, value = self.net(obs, rnncs=rnncs) # [B, 1] else: value = self.critic(obs, rnncs=rnncs) # [B, 1] return value def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=value) td_error = calculate_td_error(BATCH.reward, self.gamma, BATCH.done, value=BATCH.value, next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0)) BATCH.gae_adv = discounted_sum(td_error, self.lambda_ * self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH def learn(self, BATCH: Data): BATCH = self._preprocess_BATCH(BATCH) # [T, B, *] for _ in 
range(self._epochs): kls = [] for _BATCH in BATCH.sample(self._chunk_length, self.batch_size, repeat=self._sample_allow_repeat): _BATCH = self._before_train(_BATCH) summaries, kl = self._train(_BATCH) kls.append(kl) self.summaries.update(summaries) self._after_train() if self._use_early_stop and sum(kls) / len(kls) > self._kl_stop: break def _train(self, BATCH): if self._share_net: summaries, kl = self.train_share(BATCH) else: summaries = dict() actor_summaries, kl = self.train_actor(BATCH) critic_summaries = self.train_critic(BATCH) summaries.update(actor_summaries) summaries.update(critic_summaries) if self._use_kl_loss: # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L93 if kl > self._kl_high: self._kl_coef *= self._kl_alpha elif kl < self._kl_low: self._kl_coef /= self._kl_alpha summaries.update({ 'Statistics/kl_coef': self._kl_coef }) return summaries, kl @iton def train_share(self, BATCH): if self.is_continuous: # [T, B, A], [T, B, A], [T, B, 1] mu, log_std, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: # [T, B, A], [T, B, 1] logits, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) logp_all = logits.log_softmax(-1) # [T, B, 1] new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True) # [T, B, 1] ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1] surrogate = ratio * BATCH.gae_adv # [T, B, 1] clipped_surrogate = th.minimum( surrogate, ratio.clamp(1.0 - self._epsilon, 1.0 + self._epsilon) * BATCH.gae_adv ) # [T, B, 1] # ref: https://github.com/thu-ml/tianshou/blob/c97aa4065ee8464bd5897bb86f1f81abd8e2cff9/tianshou/policy/modelfree/ppo.py#L159 if self._use_duel_clip: clipped_surrogate2 = th.maximum( clipped_surrogate, (1.0 + self._duel_epsilon) * BATCH.gae_adv ) # [T, B, 1] clipped_surrogate = th.where(BATCH.gae_adv < 0, clipped_surrogate2, clipped_surrogate) # [T, B, 1] actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean() # 1 # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L40 # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L185 if self._kl_reverse: # TODO: kl = .5 * (new_log_prob - BATCH.log_prob).square().mean() # 1 else: # a sample estimate for KL-divergence, easy to compute kl = .5 * (BATCH.log_prob - new_log_prob).square().mean() if self._use_kl_loss: kl_loss = self._kl_coef * kl # 1 actor_loss += kl_loss if self._use_extra_loss: extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean() # 1 actor_loss += extra_loss td_error = BATCH.discounted_reward - value # [T, B, 1] if self._use_vclip: # ref: https://github.com/llSourcell/OpenAI_Five_vs_Dota2_Explained/blob/c5def7e57aa70785c2394ea2eeb3e5f66ad59a53/train.py#L154 # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L172 value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon, self._value_epsilon) # [T, B, 1] td_error_clip = BATCH.discounted_reward - value_clip # [T, B, 1] td_square = th.maximum(td_error.square(), td_error_clip.square()) # [T, B, 1] else: td_square = td_error.square() # [T, B, 1] critic_loss = 
0.5 * td_square.mean() # 1 loss = actor_loss + self._vf_coef * critic_loss # 1 self.oplr.optimize(loss) return { 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/kl': kl, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/lr': self.oplr.lr }, kl @iton def train_actor(self, BATCH): if self.is_continuous: # [T, B, A], [T, B, A] mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True) # [T, B, 1] ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1] kl = (BATCH.log_prob - new_log_prob).square().mean() # 1 surrogate = ratio * BATCH.gae_adv # [T, B, 1] clipped_surrogate = th.minimum( surrogate, th.where(BATCH.gae_adv > 0, (1 + self._epsilon) * BATCH.gae_adv, (1 - self._epsilon) * BATCH.gae_adv) ) # [T, B, 1] if self._use_duel_clip: clipped_surrogate = th.maximum( clipped_surrogate, (1.0 + self._duel_epsilon) * BATCH.gae_adv ) # [T, B, 1] actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean() # 1 if self._use_kl_loss: kl_loss = self._kl_coef * kl # 1 actor_loss += kl_loss if self._use_extra_loss: extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean() # 1 actor_loss += extra_loss self.actor_oplr.optimize(actor_loss) return { 'LOSS/actor_loss': actor_loss, 'Statistics/kl': kl, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/actor_lr': self.actor_oplr.lr }, kl @iton def train_critic(self, BATCH): value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - value # [T, B, 1] if self._use_vclip: value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon, self._value_epsilon) # [T, B, 1] td_error_clip = BATCH.discounted_reward - value_clip # [T, B, 1] td_square = th.maximum(td_error.square(), td_error_clip.square()) # [T, B, 1] else: td_square = td_error.square() # [T, B, 1] critic_loss = 0.5 * td_square.mean() # 1 self.critic_oplr.optimize(critic_loss) return { 'LOSS/critic_loss': critic_loss, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr }
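# Minimal sketch of the clipped surrogate objective used in `train_share` and
# `train_actor`, evaluated on toy tensors; epsilon and the advantages are illustrative
# values only.
import torch as th

eps = 0.2
new_log_prob = th.tensor([[-0.9], [-1.2], [-0.4]])
old_log_prob = th.tensor([[-1.0], [-1.0], [-1.0]])
adv = th.tensor([[1.0], [-0.5], [2.0]])

ratio = (new_log_prob - old_log_prob).exp()
surrogate = ratio * adv
clipped_surrogate = th.minimum(surrogate, ratio.clamp(1 - eps, 1 + eps) * adv)
actor_loss = -clipped_surrogate.mean()
print(ratio.squeeze(-1), actor_loss)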
class AveragedDQN(SarlOffPolicy): """ Averaged-DQN, http://arxiv.org/abs/1611.01929 """ policy_mode = 'off-policy' def __init__(self, target_k: int = 4, lr: float = 5.0e-4, eps_init: float = 1, eps_mid: float = 0.2, eps_final: float = 0.01, init2mid_annealing_step: int = 1000, assign_interval: int = 1000, network_settings: List[int] = [32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'Averaged-DQN only supports discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.target_k = target_k assert self.target_k > 0, "assert self.target_k > 0" self.current_target_idx = 0 self.q_net = CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings).to( self.device) self.target_nets = [] for i in range(self.target_k): target_q_net = deepcopy(self.q_net) target_q_net.eval() sync_params(target_q_net, self.q_net) self.target_nets.append(target_q_net) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, *] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: for i in range(self.target_k): target_q_values = self.target_nets[i](obs, rnncs=self.rnncs) q_values += target_q_values actions = q_values.argmax(-1) # summing without averaging also works here # [B, ] return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, *] q_next = 0 for i in range(self.target_k): q_next += self.target_nets[i](BATCH.obs_, begin_mask=BATCH.begin_mask) q_next /= self.target_k # [T, B, *] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_next.max(-1, keepdim=True)[0], BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: sync_params(self.target_nets[self.current_target_idx], self.q_net) self.current_target_idx = (self.current_target_idx + 1) % self.target_k
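# Hedged sketch of the target averaging above, with plain tensors standing in for the K
# target networks' Q outputs on the next observation.
import torch as th

target_qs = [th.randn(4, 3) for _ in range(5)]   # K = 5 copies of Q(s', .), each [B, A]
q_next = sum(target_qs) / len(target_qs)         # averaged target Q values [B, A]
bootstrap = q_next.max(-1, keepdim=True)[0]      # max over actions of the averaged Q [B, 1]
print(bootstrap.shape)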
class SAC_V(SarlOffPolicy): """ Soft Actor Critic with Value neural network. https://arxiv.org/abs/1812.05905 Soft Actor-Critic for Discrete Action Settings. https://arxiv.org/abs/1910.07207 """ policy_mode = 'off-policy' def __init__( self, alpha=0.2, annealing=True, last_alpha=0.01, polyak=0.995, use_gumbel=True, discrete_tau=1.0, network_settings={ 'actor_continuous': { 'share': [128, 128], 'mu': [64], 'log_std': [64], 'soft_clip': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [64, 32], 'q': [128, 128], 'v': [128, 128] }, actor_lr=5.0e-4, critic_lr=1.0e-3, alpha_lr=5.0e-4, auto_adaption=True, **kwargs): super().__init__(**kwargs) self.polyak = polyak self.use_gumbel = use_gumbel self.discrete_tau = discrete_tau self.auto_adaption = auto_adaption self.annealing = annealing self.v_net = TargetTwin( CriticValue(self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['v']), self.polyak).to(self.device) if self.is_continuous: self.actor = ActorCts( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) # entropy = -log(1/|A|) = log |A| self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else np.log(self.a_dim)) if self.is_continuous or self.use_gumbel: self.q_net = CriticQvalueOne( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']).to(self.device) else: self.q_net = CriticQvalueAll( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['q']).to(self.device) self.q_net2 = deepcopy(self.q_net) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR([self.q_net, self.q_net2, self.v_net], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) if self.annealing: self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6)) self._trainer_modules.update(actor=self.actor, v_net=self.v_net, q_net=self.q_net, q_net2=self.q_net2, log_alpha=self.log_alpha, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor(obs, rnncs=self.rnncs) # [B, A] pi = td.Normal(mu, log_std.exp()).sample().tanh() # [B, A] mu.tanh_() # squash mu # [B, A] else: logits = self.actor(obs, rnncs=self.rnncs) # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] self.rnncs_ = self.actor.get_rnncs() actions = pi if self._is_train_mode else mu return actions, Data(action=actions) def _train(self, BATCH): if self.is_continuous or self.use_gumbel: td_error, summaries = self._train_continuous(BATCH) else: td_error, summaries = self._train_discrete(BATCH) return td_error, summaries @iton def _train_continuous(self, BATCH): v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, 1] if self.is_continuous: mu, log_std = 
self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action( pi, dist.log_prob(pi).unsqueeze(-1)) # [T, B, A], [T, B, 1] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] q1 = self.q_net(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q2 = self.q_net2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask) # [T, B, 1] q2_pi = self.q_net2(BATCH.obs, pi, begin_mask=BATCH.begin_mask) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target, BATCH.begin_mask).detach() # [T, B, 1] v_from_q_stop = (th.minimum(q1_pi, q2_pi) - self.alpha * log_pi).detach() # [T, B, 1] td_v = v - v_from_q_stop # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action( pi, dist.log_prob(pi).unsqueeze(-1)) # [T, B, A], [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -(q1_pi - self.alpha * log_pi).mean() # 1 self.actor_oplr.optimize(actor_loss) summaries = { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/q1_loss': q1_loss, 'LOSS/q2_loss': q2_loss, 'LOSS/v_loss': v_loss_stop, 'LOSS/critic_loss': critic_loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/entropy': entropy, 'Statistics/q_min': th.minimum(q1, q2).min(), 'Statistics/q_mean': th.minimum(q1, q2).mean(), 'Statistics/q_max': th.maximum(q1, q2).max(), 'Statistics/v_mean': v.mean() } if self.auto_adaption: alpha_loss = -(self.alpha * (log_pi.detach() + self.target_entropy)).mean() self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries @iton def _train_discrete(self, BATCH): v = self.v_net(BATCH.obs, 
begin_mask=BATCH.begin_mask) # [T, B, 1] v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, 1] q1_all = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2_all = self.q_net2(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q1 = (q1_all * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q2 = (q2_all * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target, BATCH.begin_mask).detach() # [T, B, 1] td_v = v - (th.minimum((logp_all.exp() * q1_all).sum(-1, keepdim=True), (logp_all.exp() * q2_all).sum( -1, keepdim=True))).detach() # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop self.critic_oplr.optimize(critic_loss) q1_all = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2_all = self.q_net2(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True) # [T, B, 1] q_all = th.minimum(q1_all, q2_all) # [T, B, A] actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum( -1) # [T, B, A] => [T, B] actor_loss = actor_loss.mean() # 1 self.actor_oplr.optimize(actor_loss) summaries = { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/q1_loss': q1_loss, 'LOSS/q2_loss': q2_loss, 'LOSS/v_loss': v_loss_stop, 'LOSS/critic_loss': critic_loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/entropy': entropy.mean(), 'Statistics/v_mean': v.mean() } if self.auto_adaption: corr = (self.target_entropy - entropy).detach() # [T, B, 1] # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach() alpha_loss = -(self.alpha * corr) # [T, B, 1] alpha_loss = alpha_loss.mean() # 1 self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries def _after_train(self): super()._after_train() if self.annealing and not self.auto_adaption: self.log_alpha.copy_( self.alpha_annealing(self._cur_train_step).log()) self.v_net.sync()
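# Hedged sketch of the tanh-squashing log-prob correction that `squash_action` is assumed
# to implement: log pi(a) = log pi(u) - sum_i log(1 - tanh(u_i)^2), with a small epsilon
# for numerical stability. Toy values only.
import torch as th
import torch.distributions as td

mu, log_std = th.zeros(3), th.zeros(3)
dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
u = dist.rsample()                                         # pre-squash sample
a = th.tanh(u)                                             # squashed action in (-1, 1)
log_pi = dist.log_prob(u) - th.log(1. - a.pow(2) + 1e-6).sum(-1)
print(a, log_pi)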
class BCQ(SarlOffPolicy): """ Benchmarking Batch Deep Reinforcement Learning Algorithms, http://arxiv.org/abs/1910.01708 Off-Policy Deep Reinforcement Learning without Exploration, http://arxiv.org/abs/1812.02900 """ policy_mode = 'off-policy' def __init__(self, polyak=0.995, discrete=dict(threshold=0.3, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, network_settings=[32, 32]), continuous=dict(phi=0.05, lmbda=0.75, select_samples=100, train_samples=10, actor_lr=1e-3, critic_lr=1e-3, vae_lr=1e-3, network_settings=dict( actor=[32, 32], critic=[32, 32], vae=dict(encoder=[750, 750], decoder=[750, 750]))), **kwargs): super().__init__(**kwargs) self._polyak = polyak if self.is_continuous: self._lmbda = continuous['lmbda'] self._select_samples = continuous['select_samples'] self._train_samples = continuous['train_samples'] self.actor = TargetTwin(BCQ_Act_Cts( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, phi=continuous['phi'], network_settings=continuous['network_settings']['actor']), polyak=self._polyak).to(self.device) self.critic = TargetTwin(BCQ_CriticQvalueOne( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=continuous['network_settings']['critic']), polyak=self._polyak).to(self.device) self.vae = VAE(self.obs_spec, rep_net_params=self._rep_net_params, a_dim=self.a_dim, z_dim=self.a_dim * 2, hiddens=continuous['network_settings']['vae']).to( self.device) self.actor_oplr = OPLR(self.actor, continuous['actor_lr'], **self._oplr_params) self.critic_oplr = OPLR(self.critic, continuous['critic_lr'], **self._oplr_params) self.vae_oplr = OPLR(self.vae, continuous['vae_lr'], **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, vae=self.vae, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr, vae_oplr=self.vae_oplr) else: self.expl_expt_mng = ExplorationExploitationClass( eps_init=discrete['eps_init'], eps_mid=discrete['eps_mid'], eps_final=discrete['eps_final'], init2mid_annealing_step=discrete['init2mid_annealing_step'], max_step=self._max_train_step) self.assign_interval = discrete['assign_interval'] self._threshold = discrete['threshold'] self.q_net = TargetTwin(BCQ_DCT( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=discrete['network_settings']), polyak=self._polyak).to(self.device) self.oplr = OPLR(self.q_net, discrete['lr'], **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): if self.is_continuous: _actions = [] for _ in range(self._select_samples): _actions.append( self.actor(obs, self.vae.decode(obs), rnncs=self.rnncs)) # [B, A] self.rnncs_ = self.actor.get_rnncs( ) # TODO: calculate corrected hidden state _actions = th.stack(_actions, dim=0) # [N, B, A] q1s = [] for i in range(self._select_samples): q1s.append(self.critic(obs, _actions[i])[0]) q1s = th.stack(q1s, dim=0) # [N, B, 1] max_idxs = q1s.argmax(dim=0, keepdim=True)[-1] # [1, B, 1] actions = _actions[ max_idxs, th.arange(self.n_copies).reshape(self.n_copies, 1), th.arange(self.a_dim)] else: q_values, i_values = self.q_net(obs, rnncs=self.rnncs) # [B, *] q_values = q_values - q_values.min(dim=-1, keepdim=True)[0] # [B, *] i_values = F.log_softmax(i_values, dim=-1) # [B, *] i_values = i_values.exp() # [B, *] i_values = (i_values / i_values.max(-1, keepdim=True)[0] > self._threshold).float() # [B, *] self.rnncs_ = self.q_net.get_rnncs() if 
self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: actions = (i_values * q_values).argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: # Variational Auto-Encoder Training recon, mean, std = self.vae(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) recon_loss = F.mse_loss(recon, BATCH.action) KL_loss = -0.5 * (1 + th.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() vae_loss = recon_loss + 0.5 * KL_loss self.vae_oplr.optimize(vae_loss) target_Qs = [] for _ in range(self._train_samples): # Compute value of perturbed actions sampled from the VAE _vae_actions = self.vae.decode(BATCH.obs_, begin_mask=BATCH.begin_mask) _actor_actions = self.actor.t(BATCH.obs_, _vae_actions, begin_mask=BATCH.begin_mask) target_Q1, target_Q2 = self.critic.t( BATCH.obs_, _actor_actions, begin_mask=BATCH.begin_mask) # Soft Clipped Double Q-learning target_Q = self._lmbda * th.min(target_Q1, target_Q2) + \ (1. - self._lmbda) * th.max(target_Q1, target_Q2) target_Qs.append(target_Q) target_Qs = th.stack(target_Qs, dim=0) # [N, T, B, 1] # Take max over each BATCH.action sampled from the VAE target_Q = target_Qs.max(dim=0)[0] # [T, B, 1] target_Q = n_step_return(BATCH.reward, self.gamma, BATCH.done, target_Q, BATCH.begin_mask).detach() # [T, B, 1] current_Q1, current_Q2 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) td_error = ((current_Q1 - target_Q) + (current_Q2 - target_Q)) / 2 critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( current_Q2, target_Q) self.critic_oplr.optimize(critic_loss) # Pertubation Model / Action Training sampled_actions = self.vae.decode(BATCH.obs, begin_mask=BATCH.begin_mask) perturbed_actions = self.actor(BATCH.obs, sampled_actions, begin_mask=BATCH.begin_mask) # Update through DPG q1, _ = self.critic(BATCH.obs, perturbed_actions, begin_mask=BATCH.begin_mask) actor_loss = -q1.mean() self.actor_oplr.optimize(actor_loss) return td_error, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LEARNING_RATE/vae_lr': self.vae_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'LOSS/vae_loss': vae_loss, 'Statistics/q_min': q1.min(), 'Statistics/q_mean': q1.mean(), 'Statistics/q_max': q1.max() } else: q_next, i_next = self.q_net( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = q_next - q_next.min(dim=-1, keepdim=True)[0] # [B, *] i_next = F.log_softmax(i_next, dim=-1) # [T, B, A] i_next = i_next.exp() # [T, B, A] i_next = (i_next / i_next.max(-1, keepdim=True)[0] > self._threshold).float() # [T, B, A] q_next = i_next * q_next # [T, B, A] next_max_action = q_next.argmax(-1) # [T, B] next_max_action_one_hot = F.one_hot( next_max_action.squeeze(), self.a_dim).float() # [T, B, A] q_target_next, _ = self.q_net.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_target_next_max = (q_target_next * next_max_action_one_hot).sum( -1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target_next_max, BATCH.begin_mask).detach() # [T, B, 1] q, i = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 imt = F.log_softmax(i, dim=-1) # [T, B, A] imt = imt.reshape(-1, self.a_dim) # [T*B, A] action = BATCH.action.reshape(-1, self.a_dim) # 
[T*B, A] i_loss = F.nll_loss(imt, action.argmax(-1)) # 1 loss = q_loss + i_loss + 1e-2 * i.pow(2).mean() self.oplr.optimize(loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/q_loss': q_loss, 'LOSS/i_loss': i_loss, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self.is_continuous: self.actor.sync() self.critic.sync() else: if self._polyak != 0: self.q_net.sync() else: if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
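# Hedged sketch of the soft clipped double-Q target used by the continuous branch of BCQ
# above; q1/q2 are toy stand-ins for the twin target critics.
import torch as th

lmbda = 0.75
q1 = th.tensor([[1.0], [2.0]])
q2 = th.tensor([[1.5], [1.0]])
target_q = lmbda * th.min(q1, q2) + (1. - lmbda) * th.max(q1, q2)
print(target_q)  # weighted toward the pessimistic (min) estimate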
class SQL(SarlOffPolicy): """ Soft Q-Learning. ref: https://github.com/Bigpig4396/PyTorch-Soft-Q-Learning/blob/master/SoftQ.py NOTE: this is not the original algorithm from the paper; no SVGD is used. Reinforcement Learning with Deep Energy-Based Policies: https://arxiv.org/abs/1702.08165 """ policy_mode = 'off-policy' def __init__(self, lr=5.0e-4, alpha=2, polyak=0.995, network_settings=[32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'SQL only supports discrete action space' self.alpha = alpha self.polyak = polyak self.q_net = TargetTwin(CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings), self.polyak).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.q_net.get_rnncs() probs = ((q_values - self._get_v(q_values)) / self.alpha).exp() # > 0 # [B, A] probs /= probs.sum(-1, keepdim=True) # [B, A] cate_dist = td.Categorical(probs=probs) actions = cate_dist.sample() # [B,] return actions, Data(action=actions) def _get_v(self, q): v = self.alpha * (q / self.alpha).exp().mean(-1, keepdim=True).log() # [B, 1] or [T, B, 1] return v @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] v_next = self._get_v(q_next) # [T, B, 1] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_next, BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() self.q_net.sync()
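# Hedged sketch of the soft value V = alpha * log mean exp(Q / alpha) and the resulting
# Boltzmann policy used above, on toy Q-values.
import torch as th

alpha = 2.0
q = th.tensor([[1.0, 2.0, 0.5]])                              # [B, A]
v = alpha * (q / alpha).exp().mean(-1, keepdim=True).log()    # soft value [B, 1]
probs = ((q - v) / alpha).exp()
probs = probs / probs.sum(-1, keepdim=True)                   # normalized action distribution
print(v, probs)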
class QRDQN(SarlOffPolicy): """ Quantile Regression DQN Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044 No double, no dueling, no noisy net. """ policy_mode = 'off-policy' def __init__(self, nums=20, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, network_settings=[128, 128], **kwargs): assert nums > 0, 'assert nums > 0' super().__init__(**kwargs) assert not self.is_continuous, 'qrdqn only support discrete action space' self.nums = nums self.huber_delta = huber_delta self.quantiles = th.tensor((2 * np.arange(self.nums) + 1) / (2.0 * self.nums)).float().to(self.device) # [N,] self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin(QrdqnDistributional(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, nums=self.nums, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, A, N] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: q = q_values.mean(-1) # [B, A, N] => [B, A] actions = q.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2) # [T, B, A, N] => [T, B, N] target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A, N] target_q = target_q_dist.mean(-1) # [T, B, A, N] => [T, B, A] _a = target_q.argmax(-1) # [T, B] next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1) # [T, B, A, 1] # [T, B, A, N] => [T, B, N] target_q_dist = (target_q_dist * next_max_action).sum(-2) target = n_step_return(BATCH.reward.repeat(1, 1, self.nums), self.gamma, BATCH.done.repeat(1, 1, self.nums), target_q_dist, BATCH.begin_mask.repeat(1, 1, self.nums)).detach() # [T, B, N] q_eval = q_dist.mean(-1, keepdim=True) # [T, B, 1] q_target = target.mean(-1, keepdim=True) # [T, B, 1] td_error = q_target - q_eval # [T, B, 1], used for PER target = target.unsqueeze(-2) # [T, B, 1, N] q_dist = q_dist.unsqueeze(-1) # [T, B, N, 1] # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N] quantile_error = target - q_dist huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta) # [T, B, N, N] # [N,] - [T, B, N, N] => [T, B, N, N] huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs() loss = (huber_abs * huber).mean(-1) # [T, B, N, N] => [T, B, N] loss = loss.sum(-1, keepdim=True) # [T, B, N] => [T, B, 1] loss = (loss * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
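# Hedged, standalone sketch of the standard quantile-regression Huber loss that the
# `_train` above assembles; the quantile pairing here follows the textbook formulation and
# all tensors are toy stand-ins.
import torch as th
import torch.nn.functional as F

N = 5
quantiles = (2 * th.arange(N).float() + 1) / (2.0 * N)        # quantile midpoints tau_i, [N]
pred = th.randn(2, N)                                         # predicted quantiles, [B, N]
target = th.randn(2, N)                                       # Bellman target quantiles, [B, N]

u = target.unsqueeze(-2) - pred.unsqueeze(-1)                 # pairwise errors u[b, i, j], [B, N, N]
huber = F.huber_loss(pred.unsqueeze(-1).expand(-1, N, N),
                     target.unsqueeze(-2).expand(-1, N, N),
                     reduction='none', delta=1.0)             # [B, N, N]
weight = (quantiles.view(1, N, 1) - u.detach().le(0.).float()).abs()  # |tau_i - 1{u < 0}|
loss = (weight * huber).mean(-1).sum(-1).mean()               # mean over targets, sum over quantiles
print(loss)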
class TD3(SarlOffPolicy): """ Twin Delayed Deep Deterministic Policy Gradient, https://arxiv.org/abs/1802.09477 """ policy_mode = 'off-policy' def __init__(self, polyak=0.995, delay_num=2, noise_action='clip_normal', noise_params={ 'sigma': 0.2, 'noise_bound': 0.2 }, actor_lr=5.0e-4, critic_lr=1.0e-3, discrete_tau=1.0, network_settings={ 'actor_continuous': [32, 32], 'actor_discrete': [32, 32], 'q': [32, 32] }, **kwargs): super().__init__(**kwargs) self.polyak = polyak self.delay_num = delay_num self.discrete_tau = discrete_tau if self.is_continuous: actor = ActorDPG( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']) self.noised_action = self.target_noised_action = Noise_action_REGISTER[ noise_action](**noise_params) else: actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']) self.actor = TargetTwin(actor, self.polyak).to(self.device) self.critic = TargetTwin( CriticQvalueOne(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']), self.polyak).to(self.device) self.critic2 = deepcopy(self.critic) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic2=self.critic2, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) def episode_reset(self): super().episode_reset() if self.is_continuous: self.noised_action.reset() @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() if self.is_continuous: mu = output # [B, A] pi = self.noised_action(mu) # [B, A] else: logits = output # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] actions = pi if self._is_train_mode else mu return actions, Data(action=actions) @iton def _train(self, BATCH): for _ in range(self.delay_num): if self.is_continuous: action_target = self.target_noised_action( self.actor.t(BATCH.obs_, begin_mask=BATCH.begin_mask)) # [T, B, A] else: target_logits = self.actor.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] action_target = F.one_hot(target_pi, self.a_dim).float() # [T, B, A] q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q_target = th.minimum( self.critic.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask), self.critic2.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask)) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target, BATCH.begin_mask).detach() # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * (q1_loss + q2_loss) self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all
+ gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] mu = _pi_diff + _pi # [T, B, A] q1_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -q1_actor.mean() # 1 self.actor_oplr.optimize(actor_loss) return (td_error1 + td_error2) / 2, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/q_min': th.minimum(q1, q2).min(), 'Statistics/q_mean': th.minimum(q1, q2).mean(), 'Statistics/q_max': th.maximum(q1, q2).max() } def _after_train(self): super()._after_train() self.actor.sync() self.critic.sync() self.critic2.sync()
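# Hedged sketch of TD3-style target policy smoothing: clipped Gaussian noise is added to
# the target actor's action before the twin target critics evaluate it. sigma/noise_bound
# and the tensors are illustrative values.
import torch as th

sigma, noise_bound = 0.2, 0.5
action_target = th.zeros(4, 2).uniform_(-1, 1)    # stand-in for actor.t(obs_)
noise = (th.randn_like(action_target) * sigma).clamp(-noise_bound, noise_bound)
smoothed_action = (action_target + noise).clamp(-1, 1)  # keep actions in the valid range
print(smoothed_action)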
class PlaNet(SarlOffPolicy): """ Learning Latent Dynamics for Planning from Pixels, http://arxiv.org/abs/1811.04551 """ policy_mode = 'off-policy' def __init__(self, stoch_dim=30, deter_dim=200, model_lr=6e-4, kl_free_nats=3, kl_scale=1.0, reward_scale=1.0, cem_horizon=12, cem_iter_nums=10, cem_candidates=1000, cem_tops=100, action_sigma=0.3, network_settings=dict(), **kwargs): super().__init__(**kwargs) assert self.is_continuous == True, 'assert self.is_continuous == True' self.cem_horizon = cem_horizon self.cem_iter_nums = cem_iter_nums self.cem_candidates = cem_candidates self.cem_tops = cem_tops assert self.use_rnn == False, 'assert self.use_rnn == False' if self.obs_spec.has_visual_observation \ and len(self.obs_spec.visual_dims) == 1 \ and not self.obs_spec.has_vector_observation: visual_dim = self.obs_spec.visual_dims[0] # TODO: optimize this assert visual_dim[0] == visual_dim[1] == 64, 'visual dimension must be [64, 64, *]' self._is_visual = True elif self.obs_spec.has_vector_observation \ and len(self.obs_spec.vector_dims) == 1 \ and not self.obs_spec.has_visual_observation: self._is_visual = False else: raise ValueError("please check the observation type") self.stoch_dim = stoch_dim self.deter_dim = deter_dim self.kl_free_nats = kl_free_nats self.kl_scale = kl_scale self.reward_scale = reward_scale self._action_sigma = action_sigma self._network_settings = network_settings if self.obs_spec.has_visual_observation: from rls.nn.dreamer import VisualDecoder, VisualEncoder self.obs_encoder = VisualEncoder(self.obs_spec.visual_dims[0], **network_settings['obs_encoder']['visual']).to(self.device) self.obs_decoder = VisualDecoder(self.decoder_input_dim, self.obs_spec.visual_dims[0], **network_settings['obs_decoder']['visual']).to(self.device) else: from rls.nn.dreamer import VectorEncoder self.obs_encoder = VectorEncoder(self.obs_spec.vector_dims[0], **network_settings['obs_encoder']['vector']).to(self.device) self.obs_decoder = DenseModel(self.decoder_input_dim, self.obs_spec.vector_dims[0], **network_settings['obs_decoder']['vector']).to(self.device) self.rssm = self._dreamer_build_rssm() """ p(r_t | s_t, h_t) Reward model to predict reward from state and rnn hidden state """ self.reward_predictor = DenseModel(self.decoder_input_dim, 1, **network_settings['reward']).to(self.device) self.model_oplr = OPLR([self.obs_encoder, self.rssm, self.obs_decoder, self.reward_predictor], model_lr, **self._oplr_params) self._trainer_modules.update(obs_encoder=self.obs_encoder, obs_decoder=self.obs_decoder, reward_predictor=self.reward_predictor, rssm=self.rssm, model_oplr=self.model_oplr) @property def decoder_input_dim(self): return self.stoch_dim + self.deter_dim def _dreamer_build_rssm(self): return RecurrentStateSpaceModel(self.stoch_dim, self.deter_dim, self.a_dim, self.obs_encoder.h_dim, **self._network_settings['rssm']).to(self.device) @iton def select_action(self, obs): if self._is_visual: obs = get_first_visual(obs) else: obs = get_first_vector(obs) # Compute starting state for planning # while taking information from current observation (posterior) embedded_obs = self.obs_encoder(obs) # [B, *] state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs) # dist # [B, *] # Initialize action distribution mean = th.zeros((self.cem_horizon, 1, self.n_copies, self.a_dim)) # [H, 1, B, A] stddev = th.ones((self.cem_horizon, 1, self.n_copies, self.a_dim)) # [H, 1, B, A] # Iteratively improve action distribution with CEM for itr in range(self.cem_iter_nums): action_candidates = mean + 
stddev * \ th.randn(self.cem_horizon, self.cem_candidates, self.n_copies, self.a_dim) # [H, N, B, A] action_candidates = action_candidates.reshape(self.cem_horizon, -1, self.a_dim) # [H, N*B, A] # Initialize reward, state, and rnn hidden state # These are for parallel exploration total_predicted_reward = th.zeros((self.cem_candidates * self.n_copies, 1)) # [N*B, 1] state = state_posterior.sample((self.cem_candidates,)) # [N, B, *] state = state.view(-1, state.shape[-1]) # [N*B, *] rnn_hidden = self.rnncs['hx'].repeat((self.cem_candidates, 1)) # [B, *] => [N*B, *] # Compute total predicted reward by open-loop prediction using pri for t in range(self.cem_horizon): next_state_prior, rnn_hidden = self.rssm.prior(state, th.tanh(action_candidates[t]), rnn_hidden) state = next_state_prior.sample() # [N*B, *] post_feat = th.cat([state, rnn_hidden], -1) # [N*B, *] total_predicted_reward += self.reward_predictor(post_feat).mean # [N*B, 1] # update action distribution using top-k samples total_predicted_reward = total_predicted_reward.view(self.cem_candidates, self.n_copies, 1) # [N, B, 1] _, top_indexes = total_predicted_reward.topk(self.cem_tops, dim=0, largest=True, sorted=False) # [N', B, 1] action_candidates = action_candidates.view(self.cem_horizon, self.cem_candidates, self.n_copies, -1) # [H, N, B, A] top_action_candidates = action_candidates[:, top_indexes, th.arange(self.n_copies).reshape(self.n_copies, 1), th.arange(self.a_dim)] # [H, N', B, A] mean = top_action_candidates.mean(dim=1, keepdim=True) # [H, 1, B, A] stddev = top_action_candidates.std(dim=1, unbiased=False, keepdim=True) # [H, 1, B, A] # Return only first action (replan each state based on new observation) actions = th.tanh(mean[0].squeeze(0)) # [B, A] actions = self._exploration(actions) _, self.rnncs_['hx'] = self.rssm.prior(state_posterior.sample(), actions, self.rnncs['hx']) return actions, Data(action=actions) def _exploration(self, action: th.Tensor) -> th.Tensor: """ :param action: action to take, shape (1,) (if categorical), or (action dim,) (if continuous) :return: action of the same shape passed in, augmented with some noise """ sigma = self._action_sigma if self._is_train_mode else 0. noise = th.randn(*action.shape) * sigma return th.clamp(action + noise, -1, 1) @iton def _train(self, BATCH): T, B = BATCH.action.shape[:2] if self._is_visual: obs_ = get_first_visual(BATCH.obs_) else: obs_ = get_first_vector(BATCH.obs_) # embed observations with CNN embedded_observations = self.obs_encoder(obs_) # [T, B, *] # initialize state and rnn hidden state with 0 vector state, rnn_hidden = self.rssm.init_state(shape=B) # [B, S], [B, D] # compute state and rnn hidden sequences and kl loss kl_loss = 0 states, rnn_hiddens = [], [] for l in range(T): # if the begin of this episode, then reset to 0. # No matther whether last episode is beened truncated of not. state = state * (1. - BATCH.begin_mask[l]) # [B, S] rnn_hidden = rnn_hidden * (1. 
- BATCH.begin_mask[l]) # [B, D] next_state_prior, next_state_posterior, rnn_hidden = self.rssm(state, BATCH.action[l], rnn_hidden, embedded_observations[l]) # a, s_ state = next_state_posterior.rsample() # [B, S] posterior of s_ states.append(state) # [B, S] rnn_hiddens.append(rnn_hidden) # [B, D] kl_loss += self._kl_loss(next_state_prior, next_state_posterior) kl_loss /= T # 1 # compute reconstructed observations and predicted rewards post_feat = th.cat([th.stack(states, 0), th.stack(rnn_hiddens, 0)], -1) # [T, B, *] obs_pred = self.obs_decoder(post_feat) # [T, B, C, H, W] or [T, B, *] reward_pred = self.reward_predictor(post_feat) # [T, B, 1], s_ => r # compute loss for observation and reward obs_loss = -th.mean(obs_pred.log_prob(obs_)) # [T, B] => 1 # [T, B, 1]=>1 reward_loss = -th.mean(reward_pred.log_prob(BATCH.reward).unsqueeze(-1)) # add all losses and update model parameters with gradient descent model_loss = self.kl_scale * kl_loss + obs_loss + self.reward_scale * reward_loss # 1 self.model_oplr.optimize(model_loss) summaries = { 'LEARNING_RATE/model_lr': self.model_oplr.lr, 'LOSS/model_loss': model_loss, 'LOSS/kl_loss': kl_loss, 'LOSS/obs_loss': obs_loss, 'LOSS/reward_loss': reward_loss } return th.ones_like(BATCH.reward), summaries def _initial_rnncs(self, batch: int) -> Dict[str, np.ndarray]: return {'hx': np.zeros((batch, self.deter_dim))} def _kl_loss(self, prior_dist, post_dist): # 1 return td.kl_divergence(prior_dist, post_dist).clamp(min=self.kl_free_nats).mean()
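The planner in select_action is plain cross-entropy-method (CEM) search over open-loop action sequences. The stripped-down sketch below shows the same refine-mean-and-std loop against a toy reward function, without the RSSM; all names and shapes are illustrative.

import torch as th

def cem_plan(reward_fn, horizon=12, iters=10, candidates=1000, tops=100, a_dim=2):
    """Cross-entropy method over action sequences; returns only the first action."""
    mean = th.zeros(horizon, a_dim)
    std = th.ones(horizon, a_dim)
    for _ in range(iters):
        seqs = mean + std * th.randn(candidates, horizon, a_dim)   # [N, H, A]
        returns = reward_fn(th.tanh(seqs))                         # [N]
        elite = seqs[returns.topk(tops).indices]                   # top-k sequences, [N', H, A]
        mean, std = elite.mean(0), elite.std(0, unbiased=False)    # refit the proposal
    return th.tanh(mean[0])                                        # replan after every real step

# toy reward: prefer action sequences close to +1 in every dimension
first_action = cem_plan(lambda a: -(a - 1.).pow(2).sum((-1, -2)))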
class DQN(SarlOffPolicy): """ Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) DQN + LSTM, https://arxiv.org/abs/1507.06527 """ policy_mode = 'off-policy' def __init__(self, lr: float = 5.0e-4, eps_init: float = 1, eps_mid: float = 0.2, eps_final: float = 0.01, init2mid_annealing_step: int = 1000, assign_interval: int = 1000, network_settings: List[int] = [32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'dqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin( CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net) self._trainer_modules.update(oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, *] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: actions = q_values.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return( BATCH.reward, self.gamma, BATCH.done, q_next.max(-1, keepdim=True)[0], BATCH.begin_mask, nstep=self._n_step_value).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
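The exact annealing law of ExplorationExploitationClass is not shown in this file; the sketch below is one plausible piecewise-linear eps_init -> eps_mid -> eps_final schedule together with the epsilon-greedy selection it drives. All names here are placeholders, not the repo's API.

import numpy as np

def linear_epsilon(step, eps_init=1.0, eps_mid=0.2, eps_final=0.01,
                   init2mid_steps=1000, max_step=10000):
    """Anneal eps_init -> eps_mid over init2mid_steps, then eps_mid -> eps_final."""
    if step <= init2mid_steps:
        return eps_init + (step / init2mid_steps) * (eps_mid - eps_init)
    frac = min(1.0, (step - init2mid_steps) / (max_step - init2mid_steps))
    return eps_mid + frac * (eps_final - eps_mid)

def act(q_values, step, a_dim, rng=np.random):
    """Epsilon-greedy over a [B, A] array of Q-values."""
    if rng.random() < linear_epsilon(step):
        return rng.randint(0, a_dim, size=q_values.shape[0])
    return q_values.argmax(-1)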
class A2C(SarlOnPolicy): """ Synchronous Advantage Actor-Critic, A2C, http://arxiv.org/abs/1602.01783 """ policy_mode = 'on-policy' def __init__( self, agent_spec, beta=1.0e-3, actor_lr=5.0e-4, critic_lr=1.0e-3, network_settings={ 'actor_continuous': { 'hidden_units': [64, 64], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32], 'critic': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) self.beta = beta if self.is_continuous: self.actor = ActorMuLogstd( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) self.critic = CriticValue( self.obs_spec, rep_net_params=self._rep_net_params, network_settings=network_settings['critic']).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() if self.is_continuous: mu, log_std = output # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] else: logits = output # [B, A] norm_dist = td.Categorical(logits=logits) action = norm_dist.sample() # [B,] acts_info = Data(action=action) if self.use_rnn: acts_info.update(rnncs=self.rnncs) return action, acts_info @iton def _get_value(self, obs, rnncs=None): value = self.critic(obs, rnncs=self.rnncs) return value def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=value) td_error = calculate_td_error( BATCH.reward, self.gamma, BATCH.done, value=BATCH.value, next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0)) BATCH.gae_adv = discounted_sum(td_error, self.lambda_ * self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH @iton def _train(self, BATCH): v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, 1] td_error = BATCH.discounted_reward - v # [T, B, 1] critic_loss = td_error.square().mean() # 1 self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_act_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] log_act_prob = (BATCH.action * logp_all).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum( -1, keepdim=True) # [T, B, 1] # advantage = BATCH.discounted_reward - v.detach() # [T, B, 1] actor_loss = -(log_act_prob * BATCH.gae_adv + self.beta * entropy).mean() # 1 self.actor_oplr.optimize(actor_loss) return { 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/entropy': entropy.mean(), 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': 
self.critic_oplr.lr }
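_preprocess_BATCH builds the advantage with the repo's discounted_sum/calculate_td_error helpers, which also respect begin_mask. As a reference point, a minimal NumPy version of GAE without episode-boundary masking looks roughly like this (function and variable names are made up for the sketch).

import numpy as np

def gae(rewards, values, next_value, dones, gamma=0.99, lambda_=0.95):
    """GAE over a [T, B] rollout; episode-boundary masks omitted for brevity."""
    values = np.concatenate([values, next_value[None]], 0)   # [T+1, B]
    adv = np.zeros_like(rewards)
    last = 0.
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * (1. - dones[t]) * values[t + 1] - values[t]
        last = delta + gamma * lambda_ * (1. - dones[t]) * last
        adv[t] = last
    return (adv - adv.mean()) / (adv.std() + 1e-8)           # normalized, as in the repo

adv = gae(np.random.rand(5, 2), np.random.rand(5, 2), np.random.rand(2), np.zeros((5, 2)))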
class TAC(SarlOffPolicy): """Tsallis Actor Critic, TAC with V neural Network. https://arxiv.org/abs/1902.00137 """ policy_mode = 'off-policy' def __init__(self, alpha=0.2, annealing=True, last_alpha=0.01, polyak=0.995, entropic_index=1.5, discrete_tau=1.0, network_settings={ 'actor_continuous': { 'share': [128, 128], 'mu': [64], 'log_std': [64], 'soft_clip': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [64, 32], 'q': [128, 128] }, auto_adaption=True, actor_lr=5.0e-4, critic_lr=1.0e-3, alpha_lr=5.0e-4, **kwargs): super().__init__(**kwargs) self.polyak = polyak self.discrete_tau = discrete_tau self.entropic_index = 2 - entropic_index self.auto_adaption = auto_adaption self.annealing = annealing self.critic = TargetTwin(CriticQvalueOne(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']), self.polyak).to(self.device) self.critic2 = deepcopy(self.critic) if self.is_continuous: self.actor = ActorCts(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to(self.device) else: self.actor = ActorDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) # entropy = -log(1/|A|) = log |A| self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else np.log(self.a_dim)) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) if self.annealing: self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6)) self._trainer_modules.update(actor=self.actor, critic=self.critic, critic2=self.critic2, log_alpha=self.log_alpha, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs): if self.is_continuous: mu, log_std = self.actor(obs, rnncs=self.rnncs) # [B, A] pi = td.Normal(mu, log_std.exp()).sample().tanh() # [B, A] mu.tanh_() # squash mu # [B, A] else: logits = self.actor(obs, rnncs=self.rnncs) # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] self.rnncs_ = self.actor.get_rnncs() actions = pi if self._is_train_mode else mu return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1) target_pi = dist.sample() # [T, B, A] target_pi, target_log_pi = squash_action(target_pi, dist.log_prob( target_pi).unsqueeze(-1), is_independent=False) # [T, B, A] target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index) # [T, B, 1] else: target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1) # [T, B, 1] target_pi = F.one_hot(target_pi, self.a_dim).float() # [T, B, A] q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q2 = 
self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask) # [T, B, 1] q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask) # [T, B, 1] q_target = th.minimum(q1_target, q2_target) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, (q_target - self.alpha * target_log_pi), BATCH.begin_mask).detach() # [T, B, 1] td_error1 = q1 - dc_r # [T, B, 1] td_error2 = q2 - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 critic_loss = 0.5 * q1_loss + 0.5 * q2_loss self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) pi = dist.rsample() # [T, B, A] pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False) # [T, B, A] log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index) # [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1) # [T, B, A] _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] pi = _pi_diff + _pi # [T, B, A] log_pi = (logp_all * pi).sum(-1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask), self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask)) # [T, B, 1] actor_loss = -(q_s_pi - self.alpha * log_pi).mean() # 1 self.actor_oplr.optimize(actor_loss) summaries = { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/q1_loss': q1_loss, 'LOSS/q2_loss': q2_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/entropy': entropy, 'Statistics/q_min': th.minimum(q1, q2).min(), 'Statistics/q_mean': th.minimum(q1, q2).mean(), 'Statistics/q_max': th.maximum(q1, q2).max() } if self.auto_adaption: alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean() # 1 self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries def _after_train(self): super()._after_train() self.critic.sync() self.critic2.sync() if self.annealing and not self.auto_adaption: self.log_alpha.copy_(self.alpha_annealing(self._cur_train_step).log())
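TAC replaces the Shannon log-probability in the SAC-style targets with a Tsallis q-logarithm (note that __init__ stores 2 - entropic_index). The repo's tsallis_entropy_log_q helper is not reproduced here; the sketch below is simply the standard q-logarithm applied to exp(log_pi), and its parameterization may differ from the repo's.

import torch as th

def log_q(log_pi, q):
    """Tsallis q-logarithm of pi: ln_q(x) = (x**(1-q) - 1) / (1 - q), with ln_1(x) = ln(x)."""
    if abs(q - 1.0) < 1e-6:
        return log_pi                        # q -> 1 recovers the ordinary log-probability
    pi = log_pi.exp()
    return (pi.pow(1. - q) - 1.) / (1. - q)

print(log_q(th.tensor([0.1, 0.5, 0.9]).log(), q=1.5))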
class DPG(SarlOffPolicy): """ Deterministic Policy Gradient, https://hal.inria.fr/file/index/docid/938992/filename/dpg-icml2014.pdf """ policy_mode = 'off-policy' def __init__(self, actor_lr=5.0e-4, critic_lr=1.0e-3, use_target_action_noise=False, noise_action='ou', noise_params={ 'sigma': 0.2 }, discrete_tau=1.0, network_settings={ 'actor_continuous': [32, 32], 'actor_discrete': [32, 32], 'q': [32, 32] }, **kwargs): super().__init__(**kwargs) self.discrete_tau = discrete_tau self.use_target_action_noise = use_target_action_noise if self.is_continuous: self.target_noised_action = ClippedNormalNoisedAction(sigma=0.2, noise_bound=0.2) self.noised_action = Noise_action_REGISTER[noise_action](**noise_params) self.actor = ActorDPG(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to(self.device) else: self.actor = ActorDct(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to(self.device) self.critic = CriticQvalueOne(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['q']).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) def episode_reset(self): super().episode_reset() if self.is_continuous: self.noised_action.reset() @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.actor.get_rnncs() if self.is_continuous: mu = output # [B, A] pi = self.noised_action(mu) # [B, A] else: logits = output # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] actions = pi if self._is_train_mode else mu return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: action_target = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] if self.use_target_action_noise: action_target = self.target_noised_action(action_target) # [T, B, A] else: target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] action_target = F.one_hot(target_pi, self.a_dim).float() # [T, B, A] q_target = self.critic(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask) # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target, BATCH.begin_mask).detach() # [T, B, 1] q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, A] td_error = dc_r - q # [T, B, A] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.critic_oplr.optimize(q_loss) if self.is_continuous: mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] _pi = logits.softmax(-1) # [T, B, A] _pi_true_one_hot = F.one_hot( logits.argmax(-1), self.a_dim).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] mu = _pi_diff + _pi # [T, B, A] q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask) # [T, B, 1] actor_loss = -q_actor.mean() # 1 self.actor_oplr.optimize(actor_loss) return td_error, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 
'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': q_loss, 'Statistics/q_min': q.min(), 'Statistics/q_mean': q.mean(), 'Statistics/q_max': q.max() }
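The actor update above is the deterministic policy gradient: push mu(s) in the direction that increases Q(s, mu(s)). A toy, self-contained version with nn.Linear stand-ins (not the repo's ActorDPG/CriticQvalueOne):

import torch as th
import torch.nn as nn

actor = nn.Linear(3, 1)                     # stand-in for the actor network
critic = nn.Linear(3 + 1, 1)                # stand-in for the Q network
actor_opt = th.optim.Adam(actor.parameters(), lr=5e-4)

obs = th.randn(8, 3)
mu = th.tanh(actor(obs))                    # deterministic action
q = critic(th.cat([obs, mu], -1))           # critic evaluated at the actor's own action
actor_loss = -q.mean()                      # ascend Q by descending -Q

actor_opt.zero_grad()
actor_loss.backward()                       # gradients reach the actor through mu
actor_opt.step()                            # only the actor's parameters are updated here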
class MADDPG(MultiAgentOffPolicy): """ Multi-Agent Deep Deterministic Policy Gradient, https://arxiv.org/abs/1706.02275 """ policy_mode = 'off-policy' def __init__(self, polyak=0.995, noise_action='ou', noise_params={'sigma': 0.2}, actor_lr=5.0e-4, critic_lr=1.0e-3, discrete_tau=1.0, network_settings={ 'actor_continuous': [32, 32], 'actor_discrete': [32, 32], 'q': [32, 32] }, **kwargs): """ TODO: Annotation """ super().__init__(**kwargs) self.polyak = polyak self.discrete_tau = discrete_tau self.actors, self.critics = {}, {} for id in set(self.model_ids): if self.is_continuouss[id]: self.actors[id] = TargetTwin( ActorDPG( self.obs_specs[id], rep_net_params=self._rep_net_params, output_shape=self.a_dims[id], network_settings=network_settings['actor_continuous']), self.polyak).to(self.device) else: self.actors[id] = TargetTwin( ActorDct( self.obs_specs[id], rep_net_params=self._rep_net_params, output_shape=self.a_dims[id], network_settings=network_settings['actor_discrete']), self.polyak).to(self.device) self.critics[id] = TargetTwin( MACriticQvalueOne(list(self.obs_specs.values()), rep_net_params=self._rep_net_params, action_dim=sum(self.a_dims.values()), network_settings=network_settings['q']), self.polyak).to(self.device) self.actor_oplr = OPLR(list(self.actors.values()), actor_lr, **self._oplr_params) self.critic_oplr = OPLR(list(self.critics.values()), critic_lr, **self._oplr_params) # TODO: 添加动作类型判断 self.noised_actions = { id: Noise_action_REGISTER[noise_action](**noise_params) for id in set(self.model_ids) if self.is_continuouss[id] } self._trainer_modules.update( {f"actor_{id}": self.actors[id] for id in set(self.model_ids)}) self._trainer_modules.update( {f"critic_{id}": self.critics[id] for id in set(self.model_ids)}) self._trainer_modules.update(actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) def episode_reset(self): super().episode_reset() for noised_action in self.noised_actions.values(): noised_action.reset() @iton def select_action(self, obs: Dict): acts_info = {} actions = {} for aid, mid in zip(self.agent_ids, self.model_ids): output = self.actors[mid](obs[aid], rnncs=self.rnncs[aid]) # [B, A] self.rnncs_[aid] = self.actors[mid].get_rnncs() if self.is_continuouss[aid]: mu = output # [B, A] pi = self.noised_actions[mid](mu) # [B, A] else: logits = output # [B, A] mu = logits.argmax(-1) # [B,] cate_dist = td.Categorical(logits=logits) pi = cate_dist.sample() # [B,] action = pi if self._is_train_mode else mu acts_info[aid] = Data(action=action) actions[aid] = action return actions, acts_info @iton def _train(self, BATCH_DICT): """ TODO: Annotation """ summaries = defaultdict(dict) target_actions = {} for aid, mid in zip(self.agent_ids, self.model_ids): if self.is_continuouss[aid]: target_actions[aid] = self.actors[mid].t( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] else: target_logits = self.actors[mid].t( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] target_cate_dist = td.Categorical(logits=target_logits) target_pi = target_cate_dist.sample() # [T, B] action_target = F.one_hot( target_pi, self.a_dims[aid]).float() # [T, B, A] target_actions[aid] = action_target # [T, B, A] target_actions = th.cat(list(target_actions.values()), -1) # [T, B, N*A] qs, q_targets = {}, {} for mid in self.model_ids: qs[mid] = self.critics[mid]( [BATCH_DICT[id].obs for id in self.agent_ids], th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1)) # [T, B, 1] q_targets[mid] = self.critics[mid].t( [BATCH_DICT[id].obs_ for 
id in self.agent_ids], target_actions) # [T, B, 1] q_loss = {} td_errors = 0. for aid, mid in zip(self.agent_ids, self.model_ids): dc_r = n_step_return( BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done, q_targets[mid], BATCH_DICT['global'].begin_mask).detach() # [T, B, 1] td_error = dc_r - qs[mid] # [T, B, 1] td_errors += td_error q_loss[aid] = 0.5 * td_error.square().mean() # 1 summaries[aid].update({ 'Statistics/q_min': qs[mid].min(), 'Statistics/q_mean': qs[mid].mean(), 'Statistics/q_max': qs[mid].max() }) self.critic_oplr.optimize(sum(q_loss.values())) actor_loss = {} for aid, mid in zip(self.agent_ids, self.model_ids): if self.is_continuouss[aid]: mu = self.actors[mid]( BATCH_DICT[aid].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] else: logits = self.actors[mid]( BATCH_DICT[aid].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape) # [T, B, A] _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax( -1) # [T, B, A] _pi_true_one_hot = F.one_hot( _pi.argmax(-1), self.a_dims[aid]).float() # [T, B, A] _pi_diff = (_pi_true_one_hot - _pi).detach() # [T, B, A] mu = _pi_diff + _pi # [T, B, A] all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids} all_actions[aid] = mu q_actor = self.critics[mid]( [BATCH_DICT[id].obs for id in self.agent_ids], th.cat(list(all_actions.values()), -1), begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1] actor_loss[aid] = -q_actor.mean() # 1 self.actor_oplr.optimize(sum(actor_loss.values())) for aid in self.agent_ids: summaries[aid].update({ 'LOSS/actor_loss': actor_loss[aid], 'LOSS/critic_loss': q_loss[aid] }) summaries['model'].update({ 'LOSS/actor_loss': sum(actor_loss.values()), 'LOSS/critic_loss': sum(q_loss.values()) }) return td_errors / self.n_agents_percopy, summaries def _after_train(self): super()._after_train() for actor in self.actors.values(): actor.sync() for critic in self.critics.values(): critic.sync()
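The centralized critics take every agent's observation plus the concatenation of every agent's action, and in the actor loop above each agent swaps its own differentiable output into that joint action. A tiny illustration (agent ids and shapes are made up):

import torch as th

actions = {aid: th.rand(4, 2) for aid in ('a0', 'a1', 'a2')}   # replayed actions, [B, A] each
mu_a1 = th.rand(4, 2, requires_grad=True)                      # agent a1's current policy output

joint = dict(actions)
joint['a1'] = mu_a1                                            # only a1's slot is differentiable
critic_action_input = th.cat(list(joint.values()), -1)         # [B, N*A], fed to a1's critic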
class AOC(SarlOnPolicy): """ Asynchronous Advantage Option-Critic with Deliberation Cost, A2OC When Waiting is not an Option : Learning Options with a Deliberation Cost, A2OC, http://arxiv.org/abs/1709.04571 """ policy_mode = 'on-policy' def __init__( self, agent_spec, options_num=4, dc=0.01, terminal_mask=False, eps=0.1, pi_beta=1.0e-3, lr=5.0e-4, lambda_=0.95, epsilon=0.2, value_epsilon=0.2, kl_reverse=False, kl_target=0.02, kl_target_cutoff=2, kl_target_earlystop=4, kl_beta=[0.7, 1.3], kl_alpha=1.5, kl_coef=1.0, network_settings={ 'share': [32, 32], 'q': [32, 32], 'intra_option': [32, 32], 'termination': [32, 32] }, **kwargs): super().__init__(agent_spec=agent_spec, **kwargs) self.pi_beta = pi_beta self.lambda_ = lambda_ self._epsilon = epsilon self._value_epsilon = value_epsilon self._kl_reverse = kl_reverse self._kl_target = kl_target self._kl_alpha = kl_alpha self._kl_coef = kl_coef self._kl_cutoff = kl_target * kl_target_cutoff self._kl_stop = kl_target * kl_target_earlystop self._kl_low = kl_target * kl_beta[0] self._kl_high = kl_target * kl_beta[-1] self.options_num = options_num self.dc = dc self.terminal_mask = terminal_mask self.eps = eps self.net = AocShare(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, options_num=self.options_num, network_settings=network_settings, is_continuous=self.is_continuous).to(self.device) if self.is_continuous: self.log_std = th.as_tensor( np.full((self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] self.oplr = OPLR([self.net, self.log_std], lr, **self._oplr_params) else: self.oplr = OPLR(self.net, lr, **self._oplr_params) self._trainer_modules.update(model=self.net, oplr=self.oplr) self.oc_mask = th.tensor(np.zeros(self.n_copies)).to(self.device) self.options = th.tensor( np.random.randint(0, self.options_num, self.n_copies)).to(self.device) def episode_reset(self): super().episode_reset() self._done_mask = th.tensor(np.full(self.n_copies, True)).to(self.device) def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray): super().episode_step(obs, env_rets, begin_mask) self._done_mask = th.tensor(env_rets.done).to(self.device) self.options = self.new_options self.oc_mask = th.zeros_like(self.oc_mask) @iton def select_action(self, obs): # [B, P], [B, P, A], [B, P] (q, pi, beta) = self.net(obs, rnncs=self.rnncs) self.rnncs_ = self.net.get_rnncs() options_onehot = F.one_hot(self.options, self.options_num).float() # [B, P] options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1] pi = (pi * options_onehot_expanded).sum(-2) # [B, A] if self.is_continuous: mu = pi # [B, A] log_std = self.log_std[self.options] # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, A] log_prob = dist.log_prob(action).unsqueeze(-1) # [B, 1] else: logits = pi # [B, A] norm_dist = td.Categorical(logits=logits) action = norm_dist.sample() # [B,] log_prob = norm_dist.log_prob(action).unsqueeze(-1) # [B, 1] value = q_o = (q * options_onehot).sum(-1, keepdim=True) # [B, 1] beta_adv = q_o - ((1 - self.eps) * q.max(-1, keepdim=True)[0] + self.eps * q.mean(-1, keepdim=True)) # [B, 1] max_options = q.argmax(-1) # [B, P] => [B, ] beta_probs = (beta * options_onehot).sum(-1) # [B, P] => [B,] beta_dist = td.Bernoulli(probs=beta_probs) # <1 则不改变op, =1 则改变op new_options = th.where(beta_dist.sample() < 1, self.options, max_options) self.new_options = th.where(self._done_mask, max_options, new_options) self.oc_mask = (self.new_options == self.options).float() 
acts_info = Data( action=action, value=value, log_prob=log_prob + th.finfo().eps, beta_advantage=beta_adv + self.dc, last_options=self.options, options=self.new_options, reward_offset=-((1 - self.oc_mask) * self.dc).unsqueeze(-1)) if self.use_rnn: acts_info.update(rnncs=self.rnncs) return action, acts_info @iton def _get_value(self, obs, options, rnncs=None): (q, _, _) = self.net(obs, rnncs=rnncs) # [B, P] value = (q * options).sum(-1, keepdim=True) # [B, 1] return value def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) BATCH.reward += BATCH.reward_offset BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num) BATCH.options = int2one_hot(BATCH.options, self.options_num) value = self._get_value(BATCH.obs_[-1], BATCH.options[-1], rnncs=self.rnncs) BATCH.discounted_reward = discounted_sum(BATCH.reward, self.gamma, BATCH.done, BATCH.begin_mask, init_value=value) td_error = calculate_td_error( BATCH.reward, self.gamma, BATCH.done, value=BATCH.value, next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0)) BATCH.gae_adv = discounted_sum(td_error, self.lambda_ * self.gamma, BATCH.done, BATCH.begin_mask, init_value=0., normalize=True) return BATCH def learn(self, BATCH: Data): BATCH = self._preprocess_BATCH(BATCH) # [T, B, *] for _ in range(self._epochs): kls = [] for _BATCH in BATCH.sample(self._chunk_length, self.batch_size, repeat=self._sample_allow_repeat): _BATCH = self._before_train(_BATCH) summaries, kl = self._train(_BATCH) kls.append(kl) self.summaries.update(summaries) self._after_train() if sum(kls) / len(kls) > self._kl_stop: break @iton def _train(self, BATCH): # [T, B, P], [T, B, P, A], [T, B, P] (q, pi, beta) = self.net(BATCH.obs, begin_mask=BATCH.begin_mask) options_onehot_expanded = BATCH.options.unsqueeze(-1) # [T, B, P, 1] # [T, B, P, A] => [T, B, A] pi = (pi * options_onehot_expanded).sum(-2) value = (q * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] if self.is_continuous: mu = pi # [T, B, A] log_std = self.log_std[BATCH.options.argmax(-1)] # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) new_log_prob = dist.log_prob(BATCH.action).unsqueeze( -1) # [T, B, 1] entropy = dist.entropy().mean() # 1 else: logits = pi # [T, B, A] logp_all = logits.log_softmax(-1) # [T, B, A] new_log_prob = (BATCH.action * logp_all).sum( -1, keepdim=True) # [T, B, 1] entropy = -(logp_all.exp() * logp_all).sum( -1, keepdim=True).mean() # 1 ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1] if self._kl_reverse: kl = (new_log_prob - BATCH.log_prob).mean() # 1 else: # a sample estimate for KL-divergence, easy to compute kl = (BATCH.log_prob - new_log_prob).mean() surrogate = ratio * BATCH.gae_adv # [T, B, 1] value_clip = BATCH.value + (value - BATCH.value).clamp( -self._value_epsilon, self._value_epsilon) # [T, B, 1] td_error = BATCH.discounted_reward - value # [T, B, 1] td_error_clip = BATCH.discounted_reward - value_clip # [T, B, 1] td_square = th.maximum(td_error.square(), td_error_clip.square()) # [T, B, 1] pi_loss = -th.minimum( surrogate, ratio.clamp(1.0 - self._epsilon, 1.0 + self._epsilon) * BATCH.gae_adv).mean() # [T, B, 1] kl_loss = self._kl_coef * kl extra_loss = 1000.0 * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean() pi_loss = pi_loss + kl_loss + extra_loss # 1 q_loss = 0.5 * td_square.mean() # 1 beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True) # [T, B, 1] beta_loss = (beta_s * BATCH.beta_advantage) # [T, B, 1] if self.terminal_mask: beta_loss *= (1 - BATCH.done) # [T, 
B, 1] beta_loss = beta_loss.mean() # 1 loss = pi_loss + 1.0 * q_loss + beta_loss - self.pi_beta * entropy self.oplr.optimize(loss) if kl > self._kl_high: self._kl_coef *= self._kl_alpha elif kl < self._kl_low: self._kl_coef /= self._kl_alpha return { 'LOSS/loss': loss, 'LOSS/pi_loss': pi_loss, 'LOSS/q_loss': q_loss, 'LOSS/beta_loss': beta_loss, 'Statistics/kl': kl, 'Statistics/kl_coef': self._kl_coef, 'Statistics/entropy': entropy, 'LEARNING_RATE/lr': self.oplr.lr }, kl
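Stripped of the adaptive-KL penalty and cutoff terms handled above, the intra-option policy update is the familiar PPO clipped surrogate; in isolation it is just:

import torch as th

def clipped_surrogate(new_log_prob, old_log_prob, adv, epsilon=0.2):
    """Clipped objective to minimize, matching the clipped part of pi_loss above."""
    ratio = (new_log_prob - old_log_prob).exp()
    return -th.minimum(ratio * adv, ratio.clamp(1. - epsilon, 1. + epsilon) * adv).mean()

loss = clipped_surrogate(th.randn(5, 2, 1), th.randn(5, 2, 1), th.randn(5, 2, 1))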
class MAXSQN(SarlOffPolicy): """ https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py """ policy_mode = 'off-policy' def __init__(self, alpha=0.2, beta=0.1, polyak=0.995, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, use_epsilon=False, q_lr=5.0e-4, alpha_lr=5.0e-4, auto_adaption=True, network_settings=[32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'maxsqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.use_epsilon = use_epsilon self.polyak = polyak self.auto_adaption = auto_adaption self.target_entropy = beta * np.log(self.a_dim) self.critic = TargetTwin(CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings), self.polyak).to(self.device) self.critic2 = deepcopy(self.critic) self.critic_oplr = OPLR([self.critic, self.critic2], q_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) self._trainer_modules.update(critic=self.critic, critic2=self.critic2, log_alpha=self.log_alpha, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs): q = self.critic(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.critic.get_rnncs() if self.use_epsilon and self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: cate_dist = td.Categorical(logits=(q / self.alpha)) mu = q.argmax(-1) # [B,] actions = pi = cate_dist.sample() # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1) # [T, B, A] q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean() # 1 q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q1_target_max = q1_target.max(-1, keepdim=True)[0] # [T, B, 1] q1_target_log_probs = (q1_target / (self.alpha + th.finfo().eps)).log_softmax(-1) # [T, B, A] q1_target_entropy = -(q1_target_log_probs.exp() * q1_target_log_probs).sum(-1, keepdim=True) # [T, B, 1] q2_target_max = q2_target.max(-1, keepdim=True)[0] # [T, B, 1] # q2_target_log_probs = q2_target.log_softmax(-1) # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0] q_target = th.minimum(q1_target_max, q2_target_max) + self.alpha * q1_target_entropy # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target, BATCH.begin_mask).detach() # [T, B, 1] td_error1 = q1_eval - dc_r # [T, B, 1] td_error2 = q2_eval - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 loss = 0.5 * (q1_loss + q2_loss) self.critic_oplr.optimize(loss) summaries = { 
'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/loss': loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/q1_entropy': q1_entropy, 'Statistics/q_min': th.minimum(q1, q2).min(), 'Statistics/q_mean': q1.mean(), 'Statistics/q_max': th.maximum(q1, q2).max() } if self.auto_adaption: alpha_loss = -(self.alpha * (self.target_entropy - q1_entropy).detach()).mean() self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries def _after_train(self): super()._after_train() self.critic.sync() self.critic2.sync()
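The target in _train treats softmax(Q1'/alpha) as the policy and adds its entropy to a clipped double-Q max. Isolated, with made-up tensor shapes:

import torch as th

def soft_q_target(q1_t, q2_t, alpha):
    """min over the two target heads' max-Q plus alpha * entropy of softmax(Q1'/alpha)."""
    logp = (q1_t / alpha).log_softmax(-1)                    # [T, B, A]
    entropy = -(logp.exp() * logp).sum(-1, keepdim=True)     # [T, B, 1]
    q_max = th.minimum(q1_t.max(-1, keepdim=True)[0],
                       q2_t.max(-1, keepdim=True)[0])        # [T, B, 1]
    return q_max + alpha * entropy

tgt = soft_q_target(th.randn(3, 2, 4), th.randn(3, 2, 4), alpha=0.2)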
class AC(SarlOffPolicy): policy_mode = 'off-policy' # off-policy actor-critic def __init__( self, actor_lr=5.0e-4, critic_lr=1.0e-3, network_settings={ 'actor_continuous': { 'hidden_units': [64, 64], 'condition_sigma': False, 'log_std_bound': [-20, 2] }, 'actor_discrete': [32, 32], 'critic': [32, 32] }, **kwargs): super().__init__(**kwargs) if self.is_continuous: self.actor = ActorMuLogstd( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_continuous']).to( self.device) else: self.actor = ActorDct( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings['actor_discrete']).to( self.device) self.critic = CriticQvalueOne( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=network_settings['critic']).to(self.device) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) @iton def select_action(self, obs): output = self.actor(obs, rnncs=self.rnncs) # [B, *] self.rnncs_ = self.actor.get_rnncs() if self.is_continuous: mu, log_std = output # [B, *] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) action = dist.sample().clamp(-1, 1) # [B, *] log_prob = dist.log_prob(action) # [B,] else: logits = output # [B, *] norm_dist = td.Categorical(logits=logits) action = norm_dist.sample() # [B,] log_prob = norm_dist.log_prob(action) # [B,] return action, Data(action=action, log_prob=log_prob) def random_action(self): actions = super().random_action() if self.is_continuous: self._acts_info.update(log_prob=np.full(self.n_copies, np.log(0.5))) # [B,] else: self._acts_info.update(log_prob=np.full(self.n_copies, np.log(1. / self.a_dim))) # [B,] return actions @iton def _train(self, BATCH): q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) # [T, B, 1] if self.is_continuous: next_mu, _ = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, *] max_q_next = self.critic( BATCH.obs_, next_mu, begin_mask=BATCH.begin_mask).detach() # [T, B, 1] else: logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, *] max_a = logits.argmax(-1) # [T, B] max_a_one_hot = F.one_hot(max_a, self.a_dim).float() # [T, B, N] max_q_next = self.critic(BATCH.obs_, max_a_one_hot).detach() # [T, B, 1] td_error = q - n_step_return(BATCH.reward, self.gamma, BATCH.done, max_q_next, BATCH.begin_mask).detach() # [T, B, 1] critic_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.critic_oplr.optimize(critic_loss) if self.is_continuous: mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, *] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_prob = dist.log_prob(BATCH.action) # [T, B] entropy = dist.entropy().mean() # 1 else: logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, *] logp_all = logits.log_softmax(-1) # [T, B, *] log_prob = (logp_all * BATCH.action).sum(-1) # [T, B] entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1 ratio = (log_prob - BATCH.log_prob).exp().detach() # [T, B] actor_loss = -(ratio * log_prob * q.squeeze(-1).detach()).mean() # [T, B] => 1 self.actor_oplr.optimize(actor_loss) return td_error, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'Statistics/q_max': q.max(), 'Statistics/q_min': q.min(), 'Statistics/q_mean': q.mean(), 'Statistics/ratio': ratio.mean(), 'Statistics/entropy': entropy }
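A small check of why the continuous branch wraps the Normal in td.Independent(..., 1): the action dimension becomes part of the event, so log_prob returns one value per sample rather than one per action component. Shapes below are arbitrary.

import torch as th
import torch.distributions as td

mu, log_std = th.zeros(4, 3), th.zeros(4, 3)             # [B, A]
dist = td.Independent(td.Normal(mu, log_std.exp()), 1)   # reinterpret A as the event dim
a = dist.sample()                                        # [B, A]
print(dist.log_prob(a).shape)                            # torch.Size([4]) -- summed over A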