class MAXSQN(SarlOffPolicy): """ https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py """ policy_mode = 'off-policy' def __init__(self, alpha=0.2, beta=0.1, polyak=0.995, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, use_epsilon=False, q_lr=5.0e-4, alpha_lr=5.0e-4, auto_adaption=True, network_settings=[32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'maxsqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.use_epsilon = use_epsilon self.polyak = polyak self.auto_adaption = auto_adaption self.target_entropy = beta * np.log(self.a_dim) self.critic = TargetTwin(CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings), self.polyak).to(self.device) self.critic2 = deepcopy(self.critic) self.critic_oplr = OPLR([self.critic, self.critic2], q_lr, **self._oplr_params) if self.auto_adaption: self.log_alpha = th.tensor(0., requires_grad=True).to(self.device) self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params) self._trainer_modules.update(alpha_oplr=self.alpha_oplr) else: self.log_alpha = th.tensor(alpha).log().to(self.device) self._trainer_modules.update(critic=self.critic, critic2=self.critic2, log_alpha=self.log_alpha, critic_oplr=self.critic_oplr) @property def alpha(self): return self.log_alpha.exp() @iton def select_action(self, obs): q = self.critic(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.critic.get_rnncs() if self.use_epsilon and self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: cate_dist = td.Categorical(logits=(q / self.alpha)) mu = q.argmax(-1) # [B,] actions = pi = cate_dist.sample() # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1) # [T, B, A] q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean() # 1 q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q1_target_max = q1_target.max(-1, keepdim=True)[0] # [T, B, 1] q1_target_log_probs = (q1_target / (self.alpha + th.finfo().eps)).log_softmax(-1) # [T, B, A] q1_target_entropy = -(q1_target_log_probs.exp() * q1_target_log_probs).sum(-1, keepdim=True) # [T, B, 1] q2_target_max = q2_target.max(-1, keepdim=True)[0] # [T, B, 1] # q2_target_log_probs = q2_target.log_softmax(-1) # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0] q_target = th.minimum(q1_target_max, q2_target_max) + self.alpha * q1_target_entropy # [T, B, 1] dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target, BATCH.begin_mask).detach() # [T, B, 1] td_error1 = q1_eval - dc_r # [T, B, 1] td_error2 = q2_eval - dc_r # [T, B, 1] q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean() # 1 q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean() # 1 loss = 0.5 * (q1_loss + q2_loss) self.critic_oplr.optimize(loss) summaries = { 
'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/loss': loss, 'Statistics/log_alpha': self.log_alpha, 'Statistics/alpha': self.alpha, 'Statistics/q1_entropy': q1_entropy, 'Statistics/q_min': th.minimum(q1, q2).mean(), 'Statistics/q_mean': q1.mean(), 'Statistics/q_max': th.maximum(q1, q2).mean() } if self.auto_adaption: alpha_loss = -(self.alpha * (self.target_entropy - q1_entropy).detach()).mean() self.alpha_oplr.optimize(alpha_loss) summaries.update({ 'LOSS/alpha_loss': alpha_loss, 'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr }) return (td_error1 + td_error2) / 2, summaries def _after_train(self): super()._after_train() self.critic.sync() self.critic2.sync()
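# A minimal standalone sketch of the entropy-augmented bootstrap used in
# MAXSQN._train above, written without the repo helpers (TargetTwin, OPLR,
# n_step_return). Tensor names and shapes are illustrative assumptions.
import torch as th

def maxsqn_soft_target(q1_t, q2_t, reward, done, alpha=0.2, gamma=0.99):
    # q1_t, q2_t: [B, A] target-head Q-values for the next observation
    log_probs = (q1_t / alpha).log_softmax(-1)                        # [B, A]
    entropy = -(log_probs.exp() * log_probs).sum(-1, keepdim=True)    # [B, 1]
    q_next = th.minimum(q1_t.max(-1, keepdim=True)[0],
                        q2_t.max(-1, keepdim=True)[0])                # [B, 1]
    return reward + gamma * (1. - done) * (q_next + alpha * entropy)  # [B, 1]

# usage: maxsqn_soft_target(th.randn(4, 3), th.randn(4, 3), th.zeros(4, 1), th.zeros(4, 1))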
class OC(SarlOffPolicy): """ The Option-Critic Architecture. http://arxiv.org/abs/1609.05140 """ policy_mode = 'off-policy' def __init__(self, q_lr=5.0e-3, intra_option_lr=5.0e-4, termination_lr=5.0e-4, use_eps_greedy=False, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, boltzmann_temperature=1.0, options_num=4, ent_coff=0.01, double_q=False, use_baseline=True, terminal_mask=True, termination_regularizer=0.01, assign_interval=1000, network_settings={ 'q': [32, 32], 'intra_option': [32, 32], 'termination': [32, 32] }, **kwargs): super().__init__(**kwargs) self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.options_num = options_num self.termination_regularizer = termination_regularizer self.ent_coff = ent_coff self.use_baseline = use_baseline self.terminal_mask = terminal_mask self.double_q = double_q self.boltzmann_temperature = boltzmann_temperature self.use_eps_greedy = use_eps_greedy self.q_net = TargetTwin( CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['q'])).to( self.device) self.intra_option_net = OcIntraOption( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, options_num=self.options_num, network_settings=network_settings['intra_option']).to(self.device) self.termination_net = CriticQvalueAll( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.options_num, network_settings=network_settings['termination'], out_act='sigmoid').to(self.device) if self.is_continuous: # https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751 # https://blog.csdn.net/nkhgl/article/details/100047276 self.log_std = th.as_tensor( np.full((self.options_num, self.a_dim), -0.5)).requires_grad_().to(self.device) # [P, A] self.intra_option_oplr = OPLR( [self.intra_option_net, self.log_std], intra_option_lr, **self._oplr_params) else: self.intra_option_oplr = OPLR(self.intra_option_net, intra_option_lr, **self._oplr_params) self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params) self.termination_oplr = OPLR(self.termination_net, termination_lr, **self._oplr_params) self._trainer_modules.update(q_net=self.q_net, intra_option_net=self.intra_option_net, termination_net=self.termination_net, q_oplr=self.q_oplr, intra_option_oplr=self.intra_option_oplr, termination_oplr=self.termination_oplr) self.options = self.new_options = self._generate_random_options() def _generate_random_options(self): # [B,] return th.tensor(np.random.randint(0, self.options_num, self.n_copies)).to(self.device) def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray): super().episode_step(obs, env_rets, begin_mask) self.options = self.new_options @iton def select_action(self, obs): q = self.q_net(obs, rnncs=self.rnncs) # [B, P] self.rnncs_ = self.q_net.get_rnncs() pi = self.intra_option_net(obs, rnncs=self.rnncs) # [B, P, A] beta = self.termination_net(obs, rnncs=self.rnncs) # [B, P] options_onehot = F.one_hot(self.options, self.options_num).float() # [B, P] options_onehot_expanded = options_onehot.unsqueeze(-1) # [B, P, 1] pi = (pi * options_onehot_expanded).sum(-2) # [B, A] if self.is_continuous: mu = pi.tanh() # [B, A] log_std = self.log_std[self.options] # [B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) actions = dist.sample().clamp(-1, 1) # [B, A] else: 
pi = pi / self.boltzmann_temperature # [B, A] dist = td.Categorical(logits=pi) actions = dist.sample() # [B, ] max_options = q.argmax(-1).long() # [B, P] => [B, ] if self.use_eps_greedy: # epsilon greedy if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): self.new_options = self._generate_random_options() else: self.new_options = max_options else: beta_probs = (beta * options_onehot).sum(-1) # [B, P] => [B,] beta_dist = td.Bernoulli(probs=beta_probs) self.new_options = th.where(beta_dist.sample() < 1, self.options, max_options) return actions, Data(action=actions, last_options=self.options, options=self.new_options) def random_action(self): actions = super().random_action() self._acts_info.update( last_options=np.random.randint(0, self.options_num, self.n_copies), options=np.random.randint(0, self.options_num, self.n_copies)) return actions def _preprocess_BATCH(self, BATCH): # [T, B, *] BATCH = super()._preprocess_BATCH(BATCH) BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num) BATCH.options = int2one_hot(BATCH.options, self.options_num) return BATCH @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] beta_next = self.termination_net( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] qu_eval = (q * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] beta_s_ = (beta_next * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True) # [T, B, 1] # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94 if self.double_q: q_ = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, P] # [T, B, P] => [T, B] => [T, B, P] max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float() q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True) # [T, B, 1] else: q_s_max = q_next.max(-1, keepdim=True)[0] # [T, B, 1] u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max # [T, B, 1] qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, u_target, BATCH.begin_mask).detach() # [T, B, 1] td_error = qu_target - qu_eval # gradient : q [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # [T, B, 1] => 1 self.q_oplr.optimize(q_loss) q_s = qu_eval.detach() # [T, B, 1] # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130 if self.use_baseline: adv = (qu_target - q_s).detach() # [T, B, 1] else: adv = qu_target.detach() # [T, B, 1] # [T, B, P] => [T, B, P, 1] options_onehot_expanded = BATCH.options.unsqueeze(-1) pi = self.intra_option_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P, A] # [T, B, P, A] => [T, B, A] pi = (pi * options_onehot_expanded).sum(-2) if self.is_continuous: mu = pi.tanh() # [T, B, A] log_std = self.log_std[BATCH.options.argmax(-1)] # [T, B, A] dist = td.Independent(td.Normal(mu, log_std.exp()), 1) log_p = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1] entropy = dist.entropy().unsqueeze(-1) # [T, B, 1] else: pi = pi / self.boltzmann_temperature # [T, B, A] log_pi = pi.log_softmax(-1) # [T, B, A] entropy = -(log_pi.exp() * log_pi).sum(-1, keepdim=True) # [T, B, 1] log_p = (BATCH.action * log_pi).sum(-1, keepdim=True) # [T, B, 1] pi_loss = -(log_p * adv + self.ent_coff * entropy).mean() # 1 beta = self.termination_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, P] beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True) # [T, B, 1] if 
self.use_eps_greedy: v_s = q.max( -1, keepdim=True)[0] - self.termination_regularizer # [T, B, 1] else: v_s = (1 - beta_s) * q_s + beta_s * q.max( -1, keepdim=True)[0] # [T, B, 1] # v_s = q.mean(-1, keepdim=True) # [T, B, 1] beta_loss = beta_s * (q_s - v_s).detach() # [T, B, 1] # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238 if self.terminal_mask: beta_loss *= (1 - BATCH.done) # [T, B, 1] beta_loss = beta_loss.mean() # 1 self.intra_option_oplr.optimize(pi_loss) self.termination_oplr.optimize(beta_loss) return td_error, { 'LEARNING_RATE/q_lr': self.q_oplr.lr, 'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr, 'LEARNING_RATE/termination_lr': self.termination_oplr.lr, # 'Statistics/option': self.options[0], 'LOSS/q_loss': q_loss, 'LOSS/pi_loss': pi_loss, 'LOSS/beta_loss': beta_loss, 'Statistics/q_option_max': q_s.max(), 'Statistics/q_option_min': q_s.min(), 'Statistics/q_option_mean': q_s.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
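# A standalone sketch of the option-switching rule in OC.select_action above
# (the use_eps_greedy=False branch): the termination head's beta for the
# active option drives a Bernoulli draw; on termination the agent jumps to the
# greedy option argmax_w Q(s, w). Names below are illustrative assumptions.
import torch as th
import torch.distributions as td

def next_options(q, beta, options):
    # q, beta: [B, P]; options: [B,] long indices of the currently active option
    beta_probs = beta.gather(-1, options.unsqueeze(-1)).squeeze(-1)  # [B,]
    terminate = td.Bernoulli(probs=beta_probs).sample().bool()       # [B,]
    return th.where(terminate, q.argmax(-1), options)                # [B,]

# usage: next_options(th.randn(5, 4), th.rand(5, 4), th.randint(0, 4, (5,)))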
class C51(SarlOffPolicy): """ Category 51, https://arxiv.org/abs/1707.06887 No double, no dueling, no noisy net. """ policy_mode = 'off-policy' def __init__(self, v_min=-10, v_max=10, atoms=51, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, network_settings=[128, 128], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'c51 only support discrete action space' self._v_min = v_min self._v_max = v_max self._atoms = atoms self._delta_z = (self._v_max - self._v_min) / (self._atoms - 1) self._z = th.linspace(self._v_min, self._v_max, self._atoms).float().to(self.device) # [N,] self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin( C51Distributional(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, atoms=self._atoms, network_settings=network_settings)).to( self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): feat = self.q_net(obs, rnncs=self.rnncs) # [B, A, N] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: q = (self._z * feat).sum(-1) # [B, A, N] * [N,] => [B, A] actions = q.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N] q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2) q_eval = (q_dist * self._z).sum(-1) # [T, B, N] * [N,] => [T, B] target_q_dist = self.q_net.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A, N] # [T, B, A, N] * [1, N] => [T, B, A] target_q = (target_q_dist * self._z).sum(-1) a_ = target_q.argmax(-1) # [T, B] a_onehot = F.one_hot(a_, self.a_dim).float() # [T, B, A] # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N] target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2) target = n_step_return( BATCH.reward.repeat(1, 1, self._atoms), self.gamma, BATCH.done.repeat(1, 1, self._atoms), target_q_dist, BATCH.begin_mask.repeat(1, 1, self._atoms)).detach() # [T, B, N] target = target.clamp(self._v_min, self._v_max) # [T, B, N] # An amazing trick for calculating the projection gracefully. # ref: https://github.com/ShangtongZhang/DeepRL target_dist = ( 1 - (target.unsqueeze(-1) - self._z.view(1, 1, -1, 1)).abs() / self._delta_z).clamp(0, 1) * target_q_dist.unsqueeze( -1) # [T, B, N, 1] target_dist = target_dist.sum(-1) # [T, B, N] _cross_entropy = -(target_dist * th.log(q_dist + th.finfo().eps)).sum( -1, keepdim=True) # [T, B, 1] loss = (_cross_entropy * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(loss) return _cross_entropy, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
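# A standalone sketch of the projection trick referenced in C51._train above
# (the ShangtongZhang/DeepRL formulation): each clipped target atom Tz spreads
# its probability onto the fixed support z with weight
# clamp(1 - |Tz - z_j| / dz, 0, 1), i.e. linear interpolation between the two
# neighbouring atoms. The batch-only shapes below are illustrative assumptions.
import torch as th

def project_distribution(next_dist, reward, done, z, gamma=0.99):
    # next_dist: [B, N] probabilities over support z; reward, done: [B, 1]
    dz = (z[-1] - z[0]) / (len(z) - 1)
    tz = (reward + gamma * (1. - done) * z).clamp(z[0], z[-1])               # [B, N]
    w = (1. - (tz.unsqueeze(-1) - z.view(1, 1, -1)).abs() / dz).clamp(0, 1)  # [B, N, N]
    return (w * next_dist.unsqueeze(-1)).sum(1)                              # [B, N]

# sanity check: each projected row still sums to ~1, e.g.
# project_distribution(th.softmax(th.randn(4, 51), -1), th.ones(4, 1), th.zeros(4, 1), th.linspace(-10., 10., 51)).sum(-1)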
class DQN(SarlOffPolicy): """ Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) DQN + LSTM, https://arxiv.org/abs/1507.06527 """ policy_mode = 'off-policy' def __init__(self, lr: float = 5.0e-4, eps_init: float = 1, eps_mid: float = 0.2, eps_final: float = 0.01, init2mid_annealing_step: int = 1000, assign_interval: int = 1000, network_settings: List[int] = [32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'dqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin( CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net) self._trainer_modules.update(oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, *] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: actions = q_values.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return( BATCH.reward, self.gamma, BATCH.done, q_next.max(-1, keepdim=True)[0], BATCH.begin_mask, nstep=self._n_step_value).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
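# A standalone sketch of the bootstrapped target that n_step_return produces
# for DQN above, reduced to the 1-step case (the repo helper also handles
# multi-step returns and episode boundaries via begin_mask, omitted here).
# Names are illustrative assumptions.
import torch as th

def one_step_target(reward, done, q_next_max, gamma=0.99):
    # reward, done, q_next_max: [T, B, 1]
    return (reward + gamma * (1. - done) * q_next_max).detach()

# usage: one_step_target(th.ones(5, 2, 1), th.zeros(5, 2, 1), th.full((5, 2, 1), 3.0))[0, 0]  -> 3.97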
class IQN(SarlOffPolicy): """ Implicit Quantile Networks, https://arxiv.org/abs/1806.06923 Double DQN """ policy_mode = 'off-policy' def __init__(self, online_quantiles=8, target_quantiles=8, select_quantiles=32, quantiles_idx=64, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=2, network_settings={ 'q_net': [128, 64], 'quantile': [128, 64], 'tile': [64] }, **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'iqn only support discrete action space' self.online_quantiles = online_quantiles self.target_quantiles = target_quantiles self.select_quantiles = select_quantiles self.quantiles_idx = quantiles_idx self.huber_delta = huber_delta self.assign_interval = assign_interval self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.q_net = TargetTwin(IqnNet(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, quantiles_idx=self.quantiles_idx, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): _, select_quantiles_tiled = self._generate_quantiles( # [N*B, X] batch_size=self.n_copies, quantiles_num=self.select_quantiles ) q_values = self.q_net(obs, select_quantiles_tiled, rnncs=self.rnncs) # [N, B, A] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: # [N, B, A] => [B, A] => [B,] actions = q_values.mean(0).argmax(-1) return actions, Data(action=actions) def _generate_quantiles(self, batch_size, quantiles_num): _quantiles = th.rand([quantiles_num * batch_size, 1]) # [N*B, 1] _quantiles_tiled = _quantiles.repeat(1, self.quantiles_idx) # [N*B, 1] => [N*B, X] # pi * i * tau [N*B, X] * [X, ] => [N*B, X] _quantiles_tiled = th.arange(self.quantiles_idx) * np.pi * _quantiles_tiled _quantiles_tiled.cos_() # [N*B, X] _quantiles = _quantiles.view(batch_size, quantiles_num, 1) # [N*B, 1] => [B, N, 1] return _quantiles, _quantiles_tiled # [B, N, 1], [N*B, X] @iton def _train(self, BATCH): time_step = BATCH.reward.shape[0] batch_size = BATCH.reward.shape[1] quantiles, quantiles_tiled = self._generate_quantiles( # [T*B, N, 1], [N*T*B, X] batch_size=time_step * batch_size, quantiles_num=self.online_quantiles) # [T*B, N, 1] => [T, B, N, 1] quantiles = quantiles.view(time_step, batch_size, -1, 1) quantiles_tiled = quantiles_tiled.view(time_step, -1, self.quantiles_idx) # [N*T*B, X] => [T, N*B, X] quantiles_value = self.q_net(BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask) # [T, N, B, A] # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1] quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True) q_eval = quantiles_value.mean(0) # [N, T, B, 1] => [T, B, 1] _, select_quantiles_tiled = self._generate_quantiles( # [N*T*B, X] batch_size=time_step * batch_size, quantiles_num=self.select_quantiles) select_quantiles_tiled = select_quantiles_tiled.view( time_step, -1, self.quantiles_idx) # [N*T*B, X] => [T, N*B, X] q_values = self.q_net( BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask) # [T, N, B, A] q_values = q_values.mean(1) # [T, N, B, A] => [T, B, A] next_max_action = q_values.argmax(-1) # [T, B] next_max_action = F.one_hot( next_max_action, 
self.a_dim).float() # [T, B, A] _, target_quantiles_tiled = self._generate_quantiles( # [N'*T*B, X] batch_size=time_step * batch_size, quantiles_num=self.target_quantiles) target_quantiles_tiled = target_quantiles_tiled.view( time_step, -1, self.quantiles_idx) # [N'*T*B, X] => [T, N'*B, X] target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled, begin_mask=BATCH.begin_mask) # [T, N', B, A] target_quantiles_value = target_quantiles_value.swapaxes(0, 1) # [T, N', B, A] => [N', T, B, A] target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True) # [N', T, B, 1] target_q = target_quantiles_value.mean(0) # [T, B, 1] q_target = n_step_return(BATCH.reward, # [T, B, 1] self.gamma, BATCH.done, # [T, B, 1] target_q, # [T, B, 1] BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] # [N', T, B, 1] => [N', T, B] target_quantiles_value = target_quantiles_value.squeeze(-1) target_quantiles_value = target_quantiles_value.permute( 1, 2, 0) # [N', T, B] => [T, B, N'] quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles), self.gamma, BATCH.done.repeat(1, 1, self.target_quantiles), target_quantiles_value, BATCH.begin_mask.repeat(1, 1, self.target_quantiles)).detach() # [T, B, N'] # [T, B, N'] => [T, B, 1, N'] quantiles_value_target = quantiles_value_target.unsqueeze(-2) quantiles_value_online = quantiles_value.permute(1, 2, 0, 3) # [N, T, B, 1] => [T, B, N, 1] # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N'] quantile_error = quantiles_value_online - quantiles_value_target huber = F.huber_loss(quantiles_value_online, quantiles_value_target, reduction="none", delta=self.huber_delta) # [T, B, N, N] # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N'] huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs() loss = (huber_abs * huber).mean(-1) # [T, B, N, N'] => [T, B, N] loss = loss.sum(-1, keepdim=True) # [T, B, N] => [T, B, 1] loss = (loss * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
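# A standalone sketch of the quantile embedding built by
# IQN._generate_quantiles above: each sampled tau in (0, 1) is expanded to
# cos(pi * i * tau), i = 0..X-1, before IqnNet mixes it with the state
# features. Shapes and names are illustrative assumptions.
import numpy as np
import torch as th

def quantile_embedding(n_quantiles, batch_size, embed_dim):
    tau = th.rand(n_quantiles * batch_size, 1)        # [N*B, 1]
    i = th.arange(embed_dim, dtype=th.float32)        # [X,]
    phi = th.cos(np.pi * i * tau)                     # [N*B, X]
    return tau.view(batch_size, n_quantiles, 1), phi  # [B, N, 1], [N*B, X]

# usage: tau, phi = quantile_embedding(n_quantiles=8, batch_size=4, embed_dim=64)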
class QRDQN(SarlOffPolicy): """ Quantile Regression DQN Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044 No double, no dueling, no noisy net. """ policy_mode = 'off-policy' def __init__(self, nums=20, huber_delta=1., lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, network_settings=[128, 128], **kwargs): assert nums > 0, 'assert nums > 0' super().__init__(**kwargs) assert not self.is_continuous, 'qrdqn only support discrete action space' self.nums = nums self.huber_delta = huber_delta self.quantiles = th.tensor((2 * np.arange(self.nums) + 1) / (2.0 * self.nums)).float().to(self.device) # [N,] self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin(QrdqnDistributional(self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, nums=self.nums, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, A, N] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: q = q_values.mean(-1) # [B, A, N] => [B, A] actions = q.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A, N] q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2) # [T, B, A, N] => [T, B, N] target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A, N] target_q = target_q_dist.mean(-1) # [T, B, A, N] => [T, B, A] _a = target_q.argmax(-1) # [T, B] next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1) # [T, B, A, 1] # [T, B, A, N] => [T, B, N] target_q_dist = (target_q_dist * next_max_action).sum(-2) target = n_step_return(BATCH.reward.repeat(1, 1, self.nums), self.gamma, BATCH.done.repeat(1, 1, self.nums), target_q_dist, BATCH.begin_mask.repeat(1, 1, self.nums)).detach() # [T, B, N] q_eval = q_dist.mean(-1, keepdim=True) # [T, B, 1] q_target = target.mean(-1, keepdim=True) # [T, B, 1] td_error = q_target - q_eval # [T, B, 1], used for PER target = target.unsqueeze(-2) # [T, B, 1, N] q_dist = q_dist.unsqueeze(-1) # [T, B, N, 1] # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N] quantile_error = target - q_dist huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta) # [T, B, N, N] # [N,] - [T, B, N, N] => [T, B, N, N] huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs() loss = (huber_abs * huber).mean(-1) # [T, B, N, N] => [T, B, N] loss = loss.sum(-1, keepdim=True) # [T, B, N] => [T, B, 1] loss = (loss * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
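# A standalone sketch of the quantile-regression Huber loss assembled in
# QRDQN._train above, written in the standard form: pairwise TD errors
# u_ij = target_j - pred_i are weighted by |tau_i - 1{u_ij < 0}|, averaged
# over target quantiles and summed over predicted quantiles. Names are
# illustrative assumptions.
import torch as th
import torch.nn.functional as F

def quantile_huber_loss(pred, target, taus, delta=1.0):
    # pred: [B, N] predicted quantiles; target: [B, N'] targets; taus: [N,]
    u = target.unsqueeze(-2) - pred.unsqueeze(-1)              # [B, N, N']
    huber = F.huber_loss(pred.unsqueeze(-1).expand_as(u),
                         target.unsqueeze(-2).expand_as(u),
                         reduction='none', delta=delta)        # [B, N, N']
    weight = (taus.view(1, -1, 1) - u.detach().le(0.).float()).abs()
    return (weight * huber).mean(-1).sum(-1).mean()            # scalar

# usage: quantile_huber_loss(th.randn(4, 20), th.randn(4, 20), (2 * th.arange(20) + 1) / 40.)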
class DreamerV1(SarlOffPolicy): """ Dream to Control: Learning Behaviors by Latent Imagination, http://arxiv.org/abs/1912.01603 """ policy_mode = 'off-policy' def __init__(self, eps_init: float = 1, eps_mid: float = 0.2, eps_final: float = 0.01, init2mid_annealing_step: int = 1000, stoch_dim=30, deter_dim=200, model_lr=6e-4, actor_lr=8e-5, critic_lr=8e-5, kl_free_nats=3, action_sigma=0.3, imagination_horizon=15, lambda_=0.95, kl_scale=1.0, reward_scale=1.0, use_pcont=False, pcont_scale=10.0, network_settings=dict(), **kwargs): super().__init__(**kwargs) assert self.use_rnn == False, 'assert self.use_rnn == False' if self.obs_spec.has_visual_observation \ and len(self.obs_spec.visual_dims) == 1 \ and not self.obs_spec.has_vector_observation: visual_dim = self.obs_spec.visual_dims[0] # TODO: optimize this assert visual_dim[0] == visual_dim[ 1] == 64, 'visual dimension must be [64, 64, *]' self._is_visual = True elif self.obs_spec.has_vector_observation \ and len(self.obs_spec.vector_dims) == 1 \ and not self.obs_spec.has_visual_observation: self._is_visual = False else: raise ValueError("please check the observation type") self.stoch_dim = stoch_dim self.deter_dim = deter_dim self.kl_free_nats = kl_free_nats self.imagination_horizon = imagination_horizon self.lambda_ = lambda_ self.kl_scale = kl_scale self.reward_scale = reward_scale # https://github.com/danijar/dreamer/issues/2 self.use_pcont = use_pcont # probability of continuing self.pcont_scale = pcont_scale self._action_sigma = action_sigma self._network_settings = network_settings if not self.is_continuous: self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) if self.obs_spec.has_visual_observation: from rls.nn.dreamer import VisualDecoder, VisualEncoder self.obs_encoder = VisualEncoder( self.obs_spec.visual_dims[0], **network_settings['obs_encoder']['visual']).to(self.device) self.obs_decoder = VisualDecoder( self.decoder_input_dim, self.obs_spec.visual_dims[0], **network_settings['obs_decoder']['visual']).to(self.device) else: from rls.nn.dreamer import VectorEncoder self.obs_encoder = VectorEncoder( self.obs_spec.vector_dims[0], **network_settings['obs_encoder']['vector']).to(self.device) self.obs_decoder = DenseModel( self.decoder_input_dim, self.obs_spec.vector_dims[0], **network_settings['obs_decoder']['vector']).to(self.device) self.rssm = self._dreamer_build_rssm() """ p(r_t | s_t, h_t) Reward model to predict reward from state and rnn hidden state """ self.reward_predictor = DenseModel(self.decoder_input_dim, 1, **network_settings['reward']).to( self.device) self.actor = ActionDecoder(self.a_dim, self.decoder_input_dim, dist=self._action_dist, **network_settings['actor']).to(self.device) self.critic = self._dreamer_build_critic() _modules = [ self.obs_encoder, self.rssm, self.obs_decoder, self.reward_predictor ] if self.use_pcont: self.pcont_decoder = DenseModel(self.decoder_input_dim, 1, **network_settings['pcont']).to( self.device) _modules.append(self.pcont_decoder) self.model_oplr = OPLR(_modules, model_lr, **self._oplr_params) self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params) self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params) self._trainer_modules.update(obs_encoder=self.obs_encoder, obs_decoder=self.obs_decoder, reward_predictor=self.reward_predictor, rssm=self.rssm, actor=self.actor, critic=self.critic, model_oplr=self.model_oplr, 
actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr) if self.use_pcont: self._trainer_modules.update(pcont_decoder=self.pcont_decoder) @property def _action_dist(self): return 'tanh_normal' if self.is_continuous else 'one_hot' # 'relaxed_one_hot' @property def decoder_input_dim(self): return self.stoch_dim + self.deter_dim def _dreamer_build_rssm(self): return RecurrentStateSpaceModel( self.stoch_dim, self.deter_dim, self.a_dim, self.obs_encoder.h_dim, **self._network_settings['rssm']).to(self.device) def _dreamer_build_critic(self): return DenseModel(self.decoder_input_dim, 1, **self._network_settings['critic']).to(self.device) @iton def select_action(self, obs): if self._is_visual: obs = get_first_visual(obs) else: obs = get_first_vector(obs) embedded_obs = self.obs_encoder(obs) # [B, *] state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs) state = state_posterior.sample() # [B, *] actions = self.actor.sample_actions(th.cat((state, self.rnncs['hx']), -1), is_train=self._is_train_mode) actions = self._exploration(actions) _, self.rnncs_['hx'] = self.rssm.prior(state, actions, self.rnncs['hx']) if not self.is_continuous: actions = actions.argmax(-1) # [B,] return actions, Data(action=actions) def _exploration(self, action: th.Tensor) -> th.Tensor: """ :param action: action to take, shape (1,) (if categorical), or (action dim,) (if continuous) :return: action of the same shape passed in, augmented with some noise """ if self.is_continuous: sigma = self._action_sigma if self._is_train_mode else 0. noise = th.randn(*action.shape) * sigma return th.clamp(action + noise, -1, 1) else: if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): index = th.randint(0, self.a_dim, (self.n_copies, )) action = th.zeros_like(action) action[..., index] = 1 return action @iton def _train(self, BATCH): T, B = BATCH.action.shape[:2] if self._is_visual: obs_ = get_first_visual(BATCH.obs_) else: obs_ = get_first_vector(BATCH.obs_) # embed observations with CNN embedded_observations = self.obs_encoder(obs_) # [T, B, *] # initialize state and rnn hidden state with 0 vector state, rnn_hidden = self.rssm.init_state(shape=B) # [B, S], [B, D] # compute state and rnn hidden sequences and kl loss kl_loss = 0 states, rnn_hiddens = [], [] for l in range(T): # if the begin of this episode, then reset to 0. # No matther whether last episode is beened truncated of not. state = state * (1. - BATCH.begin_mask[l]) # [B, S] rnn_hidden = rnn_hidden * (1. 
- BATCH.begin_mask[l]) # [B, D] next_state_prior, next_state_posterior, rnn_hidden = self.rssm( state, BATCH.action[l], rnn_hidden, embedded_observations[l]) # a, s_ state = next_state_posterior.rsample() # [B, S] posterior of s_ states.append(state) # [B, S] rnn_hiddens.append(rnn_hidden) # [B, D] kl_loss += self._kl_loss(next_state_prior, next_state_posterior) kl_loss /= T # 1 # compute reconstructed observations and predicted rewards post_feat = th.cat([th.stack(states, 0), th.stack(rnn_hiddens, 0)], -1) # [T, B, *] obs_pred = self.obs_decoder(post_feat) # [T, B, C, H, W] or [T, B, *] reward_pred = self.reward_predictor(post_feat) # [T, B, 1], s_ => r # compute loss for observation and reward obs_loss = -th.mean(obs_pred.log_prob(obs_)) # [T, B] => 1 # [T, B, 1]=>1 reward_loss = -th.mean( reward_pred.log_prob(BATCH.reward).unsqueeze(-1)) # add all losses and update model parameters with gradient descent model_loss = self.kl_scale * kl_loss + obs_loss + self.reward_scale * reward_loss # 1 if self.use_pcont: pcont_pred = self.pcont_decoder(post_feat) # [T, B, 1], s_ => done # https://github.com/danijar/dreamer/issues/2#issuecomment-605392659 pcont_target = self.gamma * (1. - BATCH.done) # [T, B, 1]=>1 pcont_loss = -th.mean( pcont_pred.log_prob(pcont_target).unsqueeze(-1)) model_loss += self.pcont_scale * pcont_loss self.model_oplr.optimize(model_loss) # remove gradients from previously calculated tensors with th.no_grad(): # [T, B, S] => [T*B, S] flatten_states = th.cat(states, 0).detach() # [T, B, D] => [T*B, D] flatten_rnn_hiddens = th.cat(rnn_hiddens, 0).detach() with FreezeParameters(self.model_oplr.parameters): # compute target values imaginated_states = [] imaginated_rnn_hiddens = [] log_probs = [] entropies = [] for h in range(self.imagination_horizon): imaginated_states.append(flatten_states) # [T*B, S] imaginated_rnn_hiddens.append(flatten_rnn_hiddens) # [T*B, D] flatten_feat = th.cat([flatten_states, flatten_rnn_hiddens], -1).detach() action_dist = self.actor(flatten_feat) actions = action_dist.rsample() # [T*B, A] log_probs.append( action_dist.log_prob( actions.detach()).unsqueeze(-1)) # [T*B, 1] entropies.append( action_dist.entropy().unsqueeze(-1)) # [T*B, 1] flatten_states_prior, flatten_rnn_hiddens = self.rssm.prior( flatten_states, actions, flatten_rnn_hiddens) flatten_states = flatten_states_prior.rsample() # [T*B, S] imaginated_states = th.stack(imaginated_states, 0) # [H, T*B, S] imaginated_rnn_hiddens = th.stack(imaginated_rnn_hiddens, 0) # [H, T*B, D] log_probs = th.stack(log_probs, 0) # [H, T*B, 1] entropies = th.stack(entropies, 0) # [H, T*B, 1] imaginated_feats = th.cat([imaginated_states, imaginated_rnn_hiddens], -1) # [H, T*B, *] with FreezeParameters(self.model_oplr.parameters + self.critic_oplr.parameters): imaginated_rewards = self.reward_predictor( imaginated_feats).mean # [H, T*B, 1] imaginated_values = self._dreamer_target_img_value( imaginated_feats) # [H, T*B, 1]] # Compute the exponential discounted sum of rewards if self.use_pcont: with FreezeParameters(self.pcont_decoder.parameters()): discount_arr = self.pcont_decoder( imaginated_feats).mean # [H, T*B, 1] else: discount_arr = self.gamma * th.ones_like( imaginated_rewards) # [H, T*B, 1] returns = compute_return(imaginated_rewards[:-1], imaginated_values[:-1], discount_arr[:-1], bootstrap=imaginated_values[-1], lambda_=self.lambda_) # [H-1, T*B, 1] # Make the top row 1 so the cumulative product starts with discount^0 discount_arr = th.cat( [th.ones_like(discount_arr[:1]), discount_arr[:-1]], 0) # [H, 
T*B, 1] discount = th.cumprod(discount_arr, 0).detach()[:-1] # [H-1, T*B, 1] # discount_arr = th.cat([th.ones_like(discount_arr[:1]), discount_arr[1:]]) # discount = th.cumprod(discount_arr[:-1], 0) actor_loss = self._dreamer_build_actor_loss(imaginated_feats, log_probs, entropies, discount, returns) # 1 # Don't let gradients pass through to prevent overwriting gradients. # Value Loss with th.no_grad(): value_feat = imaginated_feats[:-1].detach() # [H-1, T*B, 1] value_target = returns.detach() # [H-1, T*B, 1] value_pred = self.critic(value_feat) # [H-1, T*B, 1] log_prob = value_pred.log_prob(value_target).unsqueeze( -1) # [H-1, T*B, 1] critic_loss = -th.mean(discount * log_prob) # 1 self.actor_oplr.zero_grad() self.critic_oplr.zero_grad() self.actor_oplr.backward(actor_loss) self.critic_oplr.backward(critic_loss) self.actor_oplr.step() self.critic_oplr.step() td_error = (value_pred.mean - value_target).mean(0).detach() # [T*B,] td_error = td_error.view(T, B, 1) summaries = { 'LEARNING_RATE/model_lr': self.model_oplr.lr, 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LOSS/model_loss': model_loss, 'LOSS/kl_loss': kl_loss, 'LOSS/obs_loss': obs_loss, 'LOSS/reward_loss': reward_loss, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss } if self.use_pcont: summaries.update({'LOSS/pcont_loss': pcont_loss}) return td_error, summaries def _initial_rnncs(self, batch: int) -> Dict[str, np.ndarray]: return {'hx': np.zeros((batch, self.deter_dim))} def _kl_loss(self, prior_dist, post_dist): # 1 return td.kl_divergence(prior_dist, post_dist).clamp(min=self.kl_free_nats).mean() def _dreamer_target_img_value(self, imaginated_feats): imaginated_values = self.critic(imaginated_feats).mean # [H, T*B, 1] return imaginated_values def _dreamer_build_actor_loss(self, imaginated_feats, log_probs, entropies, discount, returns): actor_loss = -th.mean(discount * returns) # [H-1, T*B, 1] => 1 return actor_loss
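# A standalone sketch of the lambda-return that the compute_return helper is
# assumed to produce for the imagined rollout above: a backward recursion
# G_t = r_t + d_t * ((1 - lambda) * V_{t+1} + lambda * G_{t+1}), bootstrapped
# from the last imagined value. Shapes and names are illustrative assumptions.
import torch as th

def lambda_return(rewards, values, discounts, bootstrap, lambda_=0.95):
    # rewards, values, discounts: [H, B, 1]; bootstrap: [B, 1]
    next_values = th.cat([values[1:], bootstrap.unsqueeze(0)], 0)  # [H, B, 1]
    outs, last = [], bootstrap
    for t in reversed(range(rewards.shape[0])):
        last = rewards[t] + discounts[t] * ((1 - lambda_) * next_values[t] + lambda_ * last)
        outs.append(last)
    return th.stack(outs[::-1], 0)                                 # [H, B, 1]

# usage: lambda_return(th.randn(14, 6, 1), th.randn(14, 6, 1), 0.99 * th.ones(14, 6, 1), th.randn(6, 1))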
class BCQ(SarlOffPolicy): """ Benchmarking Batch Deep Reinforcement Learning Algorithms, http://arxiv.org/abs/1910.01708 Off-Policy Deep Reinforcement Learning without Exploration, http://arxiv.org/abs/1812.02900 """ policy_mode = 'off-policy' def __init__(self, polyak=0.995, discrete=dict(threshold=0.3, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, network_settings=[32, 32]), continuous=dict(phi=0.05, lmbda=0.75, select_samples=100, train_samples=10, actor_lr=1e-3, critic_lr=1e-3, vae_lr=1e-3, network_settings=dict( actor=[32, 32], critic=[32, 32], vae=dict(encoder=[750, 750], decoder=[750, 750]))), **kwargs): super().__init__(**kwargs) self._polyak = polyak if self.is_continuous: self._lmbda = continuous['lmbda'] self._select_samples = continuous['select_samples'] self._train_samples = continuous['train_samples'] self.actor = TargetTwin(BCQ_Act_Cts( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, phi=continuous['phi'], network_settings=continuous['network_settings']['actor']), polyak=self._polyak).to(self.device) self.critic = TargetTwin(BCQ_CriticQvalueOne( self.obs_spec, rep_net_params=self._rep_net_params, action_dim=self.a_dim, network_settings=continuous['network_settings']['critic']), polyak=self._polyak).to(self.device) self.vae = VAE(self.obs_spec, rep_net_params=self._rep_net_params, a_dim=self.a_dim, z_dim=self.a_dim * 2, hiddens=continuous['network_settings']['vae']).to( self.device) self.actor_oplr = OPLR(self.actor, continuous['actor_lr'], **self._oplr_params) self.critic_oplr = OPLR(self.critic, continuous['critic_lr'], **self._oplr_params) self.vae_oplr = OPLR(self.vae, continuous['vae_lr'], **self._oplr_params) self._trainer_modules.update(actor=self.actor, critic=self.critic, vae=self.vae, actor_oplr=self.actor_oplr, critic_oplr=self.critic_oplr, vae_oplr=self.vae_oplr) else: self.expl_expt_mng = ExplorationExploitationClass( eps_init=discrete['eps_init'], eps_mid=discrete['eps_mid'], eps_final=discrete['eps_final'], init2mid_annealing_step=discrete['init2mid_annealing_step'], max_step=self._max_train_step) self.assign_interval = discrete['assign_interval'] self._threshold = discrete['threshold'] self.q_net = TargetTwin(BCQ_DCT( self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=discrete['network_settings']), polyak=self._polyak).to(self.device) self.oplr = OPLR(self.q_net, discrete['lr'], **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): if self.is_continuous: _actions = [] for _ in range(self._select_samples): _actions.append( self.actor(obs, self.vae.decode(obs), rnncs=self.rnncs)) # [B, A] self.rnncs_ = self.actor.get_rnncs( ) # TODO: calculate corrected hidden state _actions = th.stack(_actions, dim=0) # [N, B, A] q1s = [] for i in range(self._select_samples): q1s.append(self.critic(obs, _actions[i])[0]) q1s = th.stack(q1s, dim=0) # [N, B, 1] max_idxs = q1s.argmax(dim=0, keepdim=True)[-1] # [1, B, 1] actions = _actions[ max_idxs, th.arange(self.n_copies).reshape(self.n_copies, 1), th.arange(self.a_dim)] else: q_values, i_values = self.q_net(obs, rnncs=self.rnncs) # [B, *] q_values = q_values - q_values.min(dim=-1, keepdim=True)[0] # [B, *] i_values = F.log_softmax(i_values, dim=-1) # [B, *] i_values = i_values.exp() # [B, *] i_values = (i_values / i_values.max(-1, keepdim=True)[0] > self._threshold).float() # [B, *] self.rnncs_ = self.q_net.get_rnncs() if 
self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: actions = (i_values * q_values).argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): if self.is_continuous: # Variational Auto-Encoder Training recon, mean, std = self.vae(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) recon_loss = F.mse_loss(recon, BATCH.action) KL_loss = -0.5 * (1 + th.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() vae_loss = recon_loss + 0.5 * KL_loss self.vae_oplr.optimize(vae_loss) target_Qs = [] for _ in range(self._train_samples): # Compute value of perturbed actions sampled from the VAE _vae_actions = self.vae.decode(BATCH.obs_, begin_mask=BATCH.begin_mask) _actor_actions = self.actor.t(BATCH.obs_, _vae_actions, begin_mask=BATCH.begin_mask) target_Q1, target_Q2 = self.critic.t( BATCH.obs_, _actor_actions, begin_mask=BATCH.begin_mask) # Soft Clipped Double Q-learning target_Q = self._lmbda * th.min(target_Q1, target_Q2) + \ (1. - self._lmbda) * th.max(target_Q1, target_Q2) target_Qs.append(target_Q) target_Qs = th.stack(target_Qs, dim=0) # [N, T, B, 1] # Take max over each BATCH.action sampled from the VAE target_Q = target_Qs.max(dim=0)[0] # [T, B, 1] target_Q = n_step_return(BATCH.reward, self.gamma, BATCH.done, target_Q, BATCH.begin_mask).detach() # [T, B, 1] current_Q1, current_Q2 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask) td_error = ((current_Q1 - target_Q) + (current_Q2 - target_Q)) / 2 critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( current_Q2, target_Q) self.critic_oplr.optimize(critic_loss) # Pertubation Model / Action Training sampled_actions = self.vae.decode(BATCH.obs, begin_mask=BATCH.begin_mask) perturbed_actions = self.actor(BATCH.obs, sampled_actions, begin_mask=BATCH.begin_mask) # Update through DPG q1, _ = self.critic(BATCH.obs, perturbed_actions, begin_mask=BATCH.begin_mask) actor_loss = -q1.mean() self.actor_oplr.optimize(actor_loss) return td_error, { 'LEARNING_RATE/actor_lr': self.actor_oplr.lr, 'LEARNING_RATE/critic_lr': self.critic_oplr.lr, 'LEARNING_RATE/vae_lr': self.vae_oplr.lr, 'LOSS/actor_loss': actor_loss, 'LOSS/critic_loss': critic_loss, 'LOSS/vae_loss': vae_loss, 'Statistics/q_min': q1.min(), 'Statistics/q_mean': q1.mean(), 'Statistics/q_max': q1.max() } else: q_next, i_next = self.q_net( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_next = q_next - q_next.min(dim=-1, keepdim=True)[0] # [B, *] i_next = F.log_softmax(i_next, dim=-1) # [T, B, A] i_next = i_next.exp() # [T, B, A] i_next = (i_next / i_next.max(-1, keepdim=True)[0] > self._threshold).float() # [T, B, A] q_next = i_next * q_next # [T, B, A] next_max_action = q_next.argmax(-1) # [T, B] next_max_action_one_hot = F.one_hot( next_max_action.squeeze(), self.a_dim).float() # [T, B, A] q_target_next, _ = self.q_net.t( BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_target_next_max = (q_target_next * next_max_action_one_hot).sum( -1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target_next_max, BATCH.begin_mask).detach() # [T, B, 1] q, i = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 imt = F.log_softmax(i, dim=-1) # [T, B, A] imt = imt.reshape(-1, self.a_dim) # [T*B, A] action = BATCH.action.reshape(-1, self.a_dim) # 
[T*B, A] i_loss = F.nll_loss(imt, action.argmax(-1)) # 1 loss = q_loss + i_loss + 1e-2 * i.pow(2).mean() self.oplr.optimize(loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/q_loss': q_loss, 'LOSS/i_loss': i_loss, 'LOSS/loss': loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self.is_continuous: self.actor.sync() self.critic.sync() else: if self._polyak != 0: self.q_net.sync() else: if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
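# A standalone sketch of the discrete-BCQ action filter used in select_action
# and _train above: actions whose imitation probability, normalised by the
# per-state maximum, falls below the threshold are masked out before the
# greedy choice. Names are illustrative assumptions.
import torch as th
import torch.nn.functional as F

def bcq_filtered_argmax(q, imt_logits, threshold=0.3):
    # q, imt_logits: [B, A]
    imt = F.softmax(imt_logits, -1)                                     # [B, A]
    allowed = (imt / imt.max(-1, keepdim=True)[0] > threshold).float()  # [B, A]
    q = q - q.min(-1, keepdim=True)[0]  # shift Q to be non-negative, as above
    return (allowed * q).argmax(-1)                                     # [B,]

# usage: bcq_filtered_argmax(th.randn(4, 6), th.randn(4, 6))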
class AveragedDQN(SarlOffPolicy): """ Averaged-DQN, http://arxiv.org/abs/1611.01929 """ policy_mode = 'off-policy' def __init__(self, target_k: int = 4, lr: float = 5.0e-4, eps_init: float = 1, eps_mid: float = 0.2, eps_final: float = 0.01, init2mid_annealing_step: int = 1000, assign_interval: int = 1000, network_settings: List[int] = [32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'dqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.target_k = target_k assert self.target_k > 0, "assert self.target_k > 0" self.current_target_idx = 0 self.q_net = CriticQvalueAll(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings).to( self.device) self.target_nets = [] for i in range(self.target_k): target_q_net = deepcopy(self.q_net) target_q_net.eval() sync_params(target_q_net, self.q_net) self.target_nets.append(target_q_net) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, *] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: for i in range(self.target_k): target_q_values = self.target_nets[i](obs, rnncs=self.rnncs) q_values += target_q_values actions = q_values.argmax(-1) # averaging is not required here; the argmax is unchanged [B, ] return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, *] q_next = 0 for i in range(self.target_k): q_next += self.target_nets[i](BATCH.obs_, begin_mask=BATCH.begin_mask) q_next /= self.target_k # [T, B, *] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_next.max(-1, keepdim=True)[0], BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: sync_params(self.target_nets[self.current_target_idx], self.q_net) self.current_target_idx = (self.current_target_idx + 1) % self.target_k
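# A standalone sketch of the Averaged-DQN bootstrap computed in _train above:
# the next-state value is the max over the mean of the K frozen target
# networks' Q-estimates, which lowers the variance of the over-estimated
# target. Names are illustrative assumptions.
import torch as th

def averaged_dqn_target(q_next_per_copy, reward, done, gamma=0.99):
    # q_next_per_copy: [K, B, A] Q-values from the K target copies
    q_avg = q_next_per_copy.mean(0)                                  # [B, A]
    return reward + gamma * (1. - done) * q_avg.max(-1, keepdim=True)[0]

# usage: averaged_dqn_target(th.randn(4, 8, 3), th.zeros(8, 1), th.zeros(8, 1))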
class DDDQN(SarlOffPolicy): """ Dueling Double DQN, https://arxiv.org/abs/1511.06581 """ policy_mode = 'off-policy' def __init__(self, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=2, network_settings={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'dueling double dqn only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.q_net = TargetTwin(CriticDueling(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, network_settings=network_settings)).to(self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [B, A] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: actions = q_values.argmax(-1) # [B,] return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask) # [T, B, A] next_q = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_target = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask) # [T, B, A] q_eval = (q * BATCH.action).sum(-1, keepdim=True) # [T, B, 1] next_max_action = next_q.argmax(-1) # [T, B] next_max_action_one_hot = F.one_hot(next_max_action.squeeze(), self.a_dim).float() # [T, B, A] q_target_next_max = (q_target * next_max_action_one_hot).sum(-1, keepdim=True) # [T, B, 1] q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target_next_max, BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
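# A standalone sketch of the two ideas DDDQN combines above: the dueling
# aggregation Q = V + A - mean(A) (inside CriticDueling) and the double-DQN
# bootstrap, which picks the next action with the online network but evaluates
# it with the target network. Names are illustrative assumptions.
import torch as th

def dueling_q(v, adv):
    # v: [B, 1]; adv: [B, A]
    return v + adv - adv.mean(-1, keepdim=True)          # [B, A]

def double_dqn_bootstrap(q_online_next, q_target_next):
    # both: [B, A]
    a_star = q_online_next.argmax(-1, keepdim=True)      # [B, 1]
    return q_target_next.gather(-1, a_star)              # [B, 1]

# usage: double_dqn_bootstrap(dueling_q(th.randn(4, 1), th.randn(4, 5)), dueling_q(th.randn(4, 1), th.randn(4, 5)))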
class VDN(MultiAgentOffPolicy): """ Value-Decomposition Networks For Cooperative Multi-Agent Learning, http://arxiv.org/abs/1706.05296 QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning, http://arxiv.org/abs/1803.11485 Qatten: A General Framework for Cooperative Multiagent Reinforcement Learning, http://arxiv.org/abs/2002.03939 """ policy_mode = 'off-policy' def __init__(self, mixer='vdn', mixer_settings={}, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, use_double=True, init2mid_annealing_step=1000, assign_interval=1000, network_settings={ 'share': [128], 'v': [128], 'adv': [128] }, **kwargs): super().__init__(**kwargs) assert not any(list(self.is_continuouss.values()) ), 'VDN only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self._use_double = use_double self._mixer_type = mixer self._mixer_settings = mixer_settings self.q_nets = {} for id in set(self.model_ids): self.q_nets[id] = TargetTwin( CriticDueling(self.obs_specs[id], rep_net_params=self._rep_net_params, output_shape=self.a_dims[id], network_settings=network_settings)).to( self.device) self.mixer = self._build_mixer() self.oplr = OPLR( tuple(self.q_nets.values()) + (self.mixer, ), lr, **self._oplr_params) self._trainer_modules.update( {f"model_{id}": self.q_nets[id] for id in set(self.model_ids)}) self._trainer_modules.update(mixer=self.mixer, oplr=self.oplr) def _build_mixer(self): assert self._mixer_type in [ 'vdn', 'qmix', 'qatten' ], "assert self._mixer_type in ['vdn', 'qmix', 'qatten']" if self._mixer_type in ['qmix', 'qatten']: assert self._has_global_state, 'assert self._has_global_state' return TargetTwin(Mixer_REGISTER[self._mixer_type]( n_agents=self.n_agents_percopy, state_spec=self.state_spec, rep_net_params=self._rep_net_params, **self._mixer_settings)).to(self.device) @iton # TODO: optimization def select_action(self, obs): acts_info = {} actions = {} for aid, mid in zip(self.agent_ids, self.model_ids): q_values = self.q_nets[mid](obs[aid], rnncs=self.rnncs[aid]) # [B, A] self.rnncs_[aid] = self.q_nets[mid].get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): action = np.random.randint(0, self.a_dims[aid], self.n_copies) else: action = q_values.argmax(-1) # [B,] actions[aid] = action acts_info[aid] = Data(action=action) return actions, acts_info @iton def _train(self, BATCH_DICT): summaries = {} reward = BATCH_DICT[self.agent_ids[0]].reward # [T, B, 1] done = 0. 
q_evals = [] q_target_next_choose_maxs = [] for aid, mid in zip(self.agent_ids, self.model_ids): done += BATCH_DICT[aid].done # [T, B, 1] q = self.q_nets[mid]( BATCH_DICT[aid].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] q_eval = (q * BATCH_DICT[aid].action).sum( -1, keepdim=True) # [T, B, 1] q_evals.append(q_eval) # N * [T, B, 1] q_target = self.q_nets[mid].t( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] if self._use_double: next_q = self.q_nets[mid]( BATCH_DICT[aid].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, A] next_max_action = next_q.argmax(-1) # [T, B] next_max_action_one_hot = F.one_hot( next_max_action, self.a_dims[aid]).float() # [T, B, A] q_target_next_max = (q_target * next_max_action_one_hot).sum( -1, keepdim=True) # [T, B, 1] else: # [T, B, 1] q_target_next_max = q_target.max(-1, keepdim=True)[0] q_target_next_choose_maxs.append( q_target_next_max) # N * [T, B, 1] q_evals = th.stack(q_evals, -1) # [T, B, 1, N] q_target_next_choose_maxs = th.stack(q_target_next_choose_maxs, -1) # [T, B, 1, N] q_eval_tot = self.mixer( q_evals, BATCH_DICT['global'].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1] q_target_next_max_tot = self.mixer.t( q_target_next_choose_maxs, BATCH_DICT['global'].obs_, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1] q_target_tot = n_step_return( reward, self.gamma, (done > 0.).float(), q_target_next_max_tot, BATCH_DICT['global'].begin_mask).detach() # [T, B, 1] td_error = q_target_tot - q_eval_tot # [T, B, 1] q_loss = td_error.square().mean() # 1 self.oplr.optimize(q_loss) summaries['model'] = { 'LOSS/q_loss': q_loss, 'Statistics/q_max': q_eval_tot.max(), 'Statistics/q_min': q_eval_tot.min(), 'Statistics/q_mean': q_eval_tot.mean() } return td_error, summaries def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: for q_net in self.q_nets.values(): q_net.sync() self.mixer.sync()
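# A standalone sketch of the additive mixing assumed by the 'vdn' mixer above:
# the joint value is the sum of the per-agent chosen Q-values, so a single
# shared TD error trains every agent network (QMIX/Qatten replace the sum with
# a state-conditioned monotonic mixing network). Names are illustrative
# assumptions.
import torch as th

def vdn_td_error(q_evals, q_target_next_maxs, reward, done, gamma=0.99):
    # q_evals, q_target_next_maxs: [B, N] one chosen/greedy Q per agent
    q_tot = q_evals.sum(-1, keepdim=True)                           # [B, 1]
    q_tot_next = q_target_next_maxs.sum(-1, keepdim=True)           # [B, 1]
    target = (reward + gamma * (1. - done) * q_tot_next).detach()   # [B, 1]
    return target - q_tot                                           # [B, 1]

# usage: vdn_td_error(th.randn(8, 3), th.randn(8, 3), th.zeros(8, 1), th.zeros(8, 1)).square().mean()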
class BootstrappedDQN(SarlOffPolicy): """ Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621 """ policy_mode = 'off-policy' def __init__(self, lr=5.0e-4, eps_init=1, eps_mid=0.2, eps_final=0.01, init2mid_annealing_step=1000, assign_interval=1000, head_num=4, network_settings=[32, 32], **kwargs): super().__init__(**kwargs) assert not self.is_continuous, 'Bootstrapped DQN only support discrete action space' self.expl_expt_mng = ExplorationExploitationClass( eps_init=eps_init, eps_mid=eps_mid, eps_final=eps_final, init2mid_annealing_step=init2mid_annealing_step, max_step=self._max_train_step) self.assign_interval = assign_interval self.head_num = head_num self._probs = th.FloatTensor([1. / head_num for _ in range(head_num)]) self.now_head = 0 self.q_net = TargetTwin( CriticQvalueBootstrap(self.obs_spec, rep_net_params=self._rep_net_params, output_shape=self.a_dim, head_num=self.head_num, network_settings=network_settings)).to( self.device) self.oplr = OPLR(self.q_net, lr, **self._oplr_params) self._trainer_modules.update(model=self.q_net, oplr=self.oplr) def episode_reset(self): super().episode_reset() self.now_head = np.random.randint(self.head_num) @iton def select_action(self, obs): q_values = self.q_net(obs, rnncs=self.rnncs) # [H, B, A] self.rnncs_ = self.q_net.get_rnncs() if self._is_train_mode and self.expl_expt_mng.is_random( self._cur_train_step): actions = np.random.randint(0, self.a_dim, self.n_copies) else: # [H, B, A] => [B, A] => [B, ] actions = q_values[self.now_head].argmax(-1) return actions, Data(action=actions) @iton def _train(self, BATCH): q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask).mean( 0) # [H, T, B, A] => [T, B, A] q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask).mean( 0) # [H, T, B, A] => [T, B, A] # [T, B, A] * [T, B, A] => [T, B, 1] q_eval = (q * BATCH.action).sum(-1, keepdim=True) q_target = n_step_return( BATCH.reward, self.gamma, BATCH.done, # [T, B, A] => [T, B, 1] q_next.max(-1, keepdim=True)[0], BATCH.begin_mask).detach() # [T, B, 1] td_error = q_target - q_eval # [T, B, 1] q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean() # 1 # mask_dist = td.Bernoulli(probs=self._probs) # TODO: # mask = mask_dist.sample([batch_size]).T # [H, B] self.oplr.optimize(q_loss) return td_error, { 'LEARNING_RATE/lr': self.oplr.lr, 'LOSS/loss': q_loss, 'Statistics/q_max': q_eval.max(), 'Statistics/q_min': q_eval.min(), 'Statistics/q_mean': q_eval.mean() } def _after_train(self): super()._after_train() if self._cur_train_step % self.assign_interval == 0: self.q_net.sync()
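# A standalone sketch of the bootstrapped-DQN exploration scheme used above:
# one of the K heads is drawn at the start of every episode (episode_reset)
# and acts greedily for that whole episode; the Bernoulli masks that decide
# which heads learn from each transition are only hinted at in _train (the
# commented-out mask_dist), so the mask below is an illustrative assumption.
import torch as th
import torch.distributions as td

head_num, batch_size, a_dim = 4, 8, 3
now_head = th.randint(0, head_num, ()).item()                  # redrawn each episode
q_heads = th.randn(head_num, batch_size, a_dim)                # [K, B, A]
greedy_actions = q_heads[now_head].argmax(-1)                  # [B,] act with one head
mask = td.Bernoulli(probs=0.5).sample((head_num, batch_size))  # [K, B] per-head learning mask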