class TD3Agent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        # entropy temperature placeholders; TD3 does not use entropy
        # regularization and these are never updated in this class
        self._log_alpha = [
            torch.zeros(1, requires_grad=True, device=config.device)
        ]
        self._alpha_optim = [
            optim.Adam([self._log_alpha[0]], lr=config.lr_actor)
        ]

        # build up networks (deterministic actor for TD3)
        self._actor = actor(self._config, self._ob_space, self._ac_space,
                            self._config.tanh_policy, deterministic=True)
        self._actor_target = actor(self._config, self._ob_space,
                                   self._ac_space, self._config.tanh_policy,
                                   deterministic=True)
        self._actor_target.load_state_dict(self._actor.state_dict())

        # twin critics and their target networks
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())
        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        self._buffer = ReplayBuffer(config, ob_space, ac_space)
        self._ounoise = OUNoise(action_size(ac_space))

        self._log_creation()

    def _log_creation(self):
        logger.info('Creating a TD3 agent')
        logger.info('The actor has %d parameters',
                    count_parameters(self._actor))
        logger.info('The critic1 has %d parameters',
                    count_parameters(self._critic1))
        logger.info('The critic2 has %d parameters',
                    count_parameters(self._critic2))

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def store_sample(self, rollouts):
        self._buffer.store_sample(rollouts)

    def state_dict(self):
        return {
            'actor_state_dict': self._actor.state_dict(),
            'critic1_state_dict': self._critic1.state_dict(),
            'critic2_state_dict': self._critic2.state_dict(),
            'actor_optim_state_dict': self._actor_optim.state_dict(),
            'critic1_optim_state_dict': self._critic1_optim.state_dict(),
            'critic2_optim_state_dict': self._critic2_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._actor.load_state_dict(ckpt['actor_state_dict'])
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic1.load_state_dict(ckpt['critic1_state_dict'])
        self._critic2.load_state_dict(ckpt['critic2_state_dict'])
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())
        self._network_cuda(self._config.device)

        self._actor_optim.load_state_dict(ckpt['actor_optim_state_dict'])
        self._critic1_optim.load_state_dict(ckpt['critic1_optim_state_dict'])
        self._critic2_optim.load_state_dict(ckpt['critic2_optim_state_dict'])
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic1_optim, self._config.device)
        optimizer_cuda(self._critic2_optim, self._config.device)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._actor_target.to(device)
        self._critic1.to(device)
        self._critic2.to(device)
        self._critic1_target.to(device)
        self._critic2_target.to(device)

    def train(self):
        config = self._config
        for i in range(config.num_batches):
            transitions = self._buffer.sample(config.batch_size)
            train_info = self._update_network(transitions, step=i)
            self._soft_update_target_network(self._actor_target, self._actor,
                                             self._config.polyak)
            self._soft_update_target_network(self._critic1_target,
                                             self._critic1,
                                             self._config.polyak)
            self._soft_update_target_network(self._critic2_target,
                                             self._critic2,
                                             self._config.polyak)
        return train_info

    def act_log(self, ob):
        return self._actor.act_log(ob)

    def act(self, ob, is_train=True):
        ob = to_tensor(ob, self._config.device)
        ac, activation = self._actor.act(ob, is_train=is_train)
        if is_train:
            # exploration noise, clipped to the valid action range
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    ac[k] += self._config.noise_scale * np.random.randn(
                        len(ac[k]))
                    ac[k] = np.clip(ac[k], self._ac_space[k].low,
                                    self._ac_space[k].high)
        return ac, activation

    def target_act(self, ob, is_train=True):
        ac, activation = self._actor_target.act(ob, is_train=is_train)
        return ac, activation

    def target_act_log(self, ob):
        return self._actor_target.act_log(ob)

    def _update_network(self, transitions, step=0):
        config = self._config
        info = {}

        o, o_next = transitions['ob'], transitions['ob_next']
        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        ## actor loss
        actions_real, _ = self.act_log(o)
        actor_loss = -self._critic1(o, actions_real).mean()
        info['actor_loss'] = actor_loss.cpu().item()

        ## critic loss
        with torch.no_grad():
            # target policy smoothing: perturb target actions with clipped noise
            actions_next, _ = self.target_act_log(o_next)
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    epsilon = torch.randn_like(
                        actions_next[k]) * self._config.noise_scale
                    epsilon = torch.clamp(epsilon, -config.noise_clip,
                                          config.noise_clip)
                    actions_next[k] += epsilon

            # clipped double-Q target
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1, q_next_value2)
            target_q_value = rew + \
                (1. - done) * config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value1.min().cpu().item()
        info['min_real2_q'] = real_q_value2.min().cpu().item()
        info['real1_q'] = real_q_value1.mean().cpu().item()
        info['real2_q'] = real_q_value2.mean().cpu().item()
        info['critic1_loss'] = critic1_loss.cpu().item()
        info['critic2_loss'] = critic2_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        self._actor_optim.step()

        # update the critics
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        self._critic2_optim.step()

        return info
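# NOTE: `_soft_update_target_network` is inherited from `BaseAgent` and not
# defined in this file. A minimal standalone sketch of the Polyak averaging it
# is assumed to perform is given below for reference. The helper name
# `polyak_update` and the interpretation of `config.polyak` as the weight kept
# on the target parameters are assumptions, not part of the original API.
def polyak_update(target, source, polyak):
    """Soft update: target <- polyak * target + (1 - polyak) * source."""
    for target_param, source_param in zip(target.parameters(),
                                          source.parameters()):
        target_param.data.copy_(polyak * target_param.data +
                                (1 - polyak) * source_param.data)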
class SACAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        # learnable entropy temperature alpha (optimized in log space)
        self._log_alpha = torch.tensor(np.log(config.alpha),
                                       requires_grad=True,
                                       device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)

        # build up networks
        self._actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)

        # target entropy heuristic: -dim(action space)
        self._target_entropy = -action_size(self._actor._ac_space)

        # build up target networks
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        if config.policy == 'cnn':
            # share the convolutional encoder across critics and actor
            self._critic2.base.copy_conv_weights_from(self._critic1.base)
            self._actor.base.copy_conv_weights_from(self._critic1.base)

            if config.unsup_algo == 'curl':
                self._curl = CURL(config, ob_space, ac_space, self._critic1,
                                  self._critic1_target)
                self._encoder_optim = optim.Adam(
                    self._critic1.base.parameters(), lr=config.lr_encoder)
                self._cpc_optim = optim.Adam(self._curl.parameters(),
                                             lr=config.lr_encoder)

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        self._buffer = ReplayBuffer(config, ob_space, ac_space)

        self._log_creation()

    def _log_creation(self):
        logger.info('Creating a SAC agent')
        logger.info('The actor has %d parameters',
                    count_parameters(self._actor))
        logger.info('The critic1 has %d parameters',
                    count_parameters(self._critic1))
        logger.info('The critic2 has %d parameters',
                    count_parameters(self._critic2))

    def store_sample(self, rollouts):
        self._buffer.store_sample(rollouts)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._critic1.to(device)
        self._critic2.to(device)
        self._critic1_target.to(device)
        self._critic2_target.to(device)
        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            self._curl.to(device)

    def state_dict(self):
        ret = {
            'log_alpha': self._log_alpha.cpu().detach().numpy(),
            'actor_state_dict': self._actor.state_dict(),
            'critic1_state_dict': self._critic1.state_dict(),
            'critic2_state_dict': self._critic2.state_dict(),
            'alpha_optim_state_dict': self._alpha_optim.state_dict(),
            'actor_optim_state_dict': self._actor_optim.state_dict(),
            'critic1_optim_state_dict': self._critic1_optim.state_dict(),
            'critic2_optim_state_dict': self._critic2_optim.state_dict(),
        }
        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            ret['curl_state_dict'] = self._curl.state_dict()
            ret['encoder_optim_state_dict'] = self._encoder_optim.state_dict()
            ret['cpc_optim_state_dict'] = self._cpc_optim.state_dict()
        return ret

    def load_state_dict(self, ckpt):
        self._log_alpha.data = torch.tensor(ckpt['log_alpha'],
                                            requires_grad=True,
                                            device=self._config.device)
        self._actor.load_state_dict(ckpt['actor_state_dict'])
        self._critic1.load_state_dict(ckpt['critic1_state_dict'])
        self._critic2.load_state_dict(ckpt['critic2_state_dict'])
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        self._alpha_optim.load_state_dict(ckpt['alpha_optim_state_dict'])
        self._actor_optim.load_state_dict(ckpt['actor_optim_state_dict'])
        self._critic1_optim.load_state_dict(ckpt['critic1_optim_state_dict'])
        self._critic2_optim.load_state_dict(ckpt['critic2_optim_state_dict'])
        optimizer_cuda(self._alpha_optim, self._config.device)
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic1_optim, self._config.device)
        optimizer_cuda(self._critic2_optim, self._config.device)

        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            self._curl.load_state_dict(ckpt['curl_state_dict'])
            self._encoder_optim.load_state_dict(
                ckpt['encoder_optim_state_dict'])
            self._cpc_optim.load_state_dict(ckpt['cpc_optim_state_dict'])
            optimizer_cuda(self._encoder_optim, self._config.device)
            optimizer_cuda(self._cpc_optim, self._config.device)

        self._network_cuda(self._config.device)

    def train(self):
        for _ in range(self._config.num_batches):
            if self._config.policy == 'cnn' and \
                    self._config.unsup_algo == 'curl':
                transitions = self._buffer.sample_cpc(self._config.batch_size)
            else:
                transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions)
            self._soft_update_target_network(self._critic1_target,
                                             self._critic1,
                                             self._config.polyak)
            self._soft_update_target_network(self._critic2_target,
                                             self._critic2,
                                             self._config.polyak)
        return train_info

    def act_log(self, o):
        return self._actor.act_log(o)

    def _update_critic(self, o, ac, rew, o_next, done):
        info = {}
        alpha = self._log_alpha.exp()

        with torch.no_grad():
            # soft Bellman target: min of target critics minus entropy term
            actions_next, log_pi_next = self.act_log(o_next)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1,
                                     q_next_value2) - alpha * log_pi_next
            target_q_value = rew * self._config.reward_scale + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        # q loss
        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value1.min().cpu().item()
        info['min_real2_q'] = real_q_value2.min().cpu().item()
        info['real1_q'] = real_q_value1.mean().cpu().item()
        info['real2_q'] = real_q_value2.mean().cpu().item()
        info['critic1_loss'] = critic1_loss.cpu().item()
        info['critic2_loss'] = critic2_loss.cpu().item()

        # update the critics
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        self._critic2_optim.step()

        return info

    def _update_actor_and_alpha(self, o):
        info = {}

        actions_real, log_pi = self.act_log(o)

        # update alpha toward the target entropy
        alpha_loss = -(self._log_alpha *
                       (log_pi + self._target_entropy).detach()).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        alpha = self._log_alpha.exp()

        # actor loss
        entropy_loss = (alpha * log_pi).mean()
        actor_loss = -torch.min(self._critic1(o, actions_real),
                                self._critic2(o, actions_real)).mean()
        info['entropy_alpha'] = alpha.cpu().item()
        info['entropy_loss'] = entropy_loss.cpu().item()
        info['actor_loss'] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        self._actor_optim.step()

        return info

    def _update_cpc(self, o_anchor, o_pos, cpc_kwargs):
        info = {}

        # InfoNCE: each anchor's positive key is the matching row of the
        # logits matrix, so the labels are simply the batch indices
        z_a = self._curl.encode(o_anchor)
        z_pos = self._curl.encode(o_pos, ema=True)
        logits = self._curl.compute_logits(z_a, z_pos)
        labels = torch.arange(logits.shape[0]).long().to(self._config.device)
        cpc_loss = F.cross_entropy(logits, labels)
        info['cpc_loss'] = cpc_loss.cpu().item()

        self._encoder_optim.zero_grad()
        self._cpc_optim.zero_grad()
        cpc_loss.backward()
        self._encoder_optim.step()
        self._cpc_optim.step()

        return info

    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']
        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        # update the critics, then the actor and alpha
        critic_info = self._update_critic(o, ac, rew, o_next, done)
        info.update(critic_info)
        actor_alpha_info = self._update_actor_and_alpha(o)
        info.update(actor_alpha_info)

        if self._config.policy == 'cnn' and \
                self._config.unsup_algo == 'curl':
            cpc_kwargs = transitions['cpc_kwargs']
            o_anchor = _to_tensor(cpc_kwargs['ob_anchor'])
            o_pos = _to_tensor(cpc_kwargs['ob_pos'])
            cpc_info = self._update_cpc(o_anchor, o_pos, cpc_kwargs)
            info.update(cpc_info)

        return info
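# NOTE: `CURL` is defined elsewhere. `_update_cpc` above assumes
# `compute_logits` returns a (B, B) score matrix whose diagonal holds the
# positive pairs, so that `F.cross_entropy` with `arange` labels yields the
# InfoNCE loss. A minimal standalone sketch of the bilinear similarity used in
# the public CURL reference implementation is given below; the function name
# and the learned matrix `W` passed as an argument are illustrative
# assumptions, not the original class API.
def curl_logits(z_a, z_pos, W):
    """InfoNCE logits via bilinear similarity: logits[i, j] = z_a[i] W z_pos[j]."""
    Wz = torch.matmul(W, z_pos.T)        # (z_dim, B)
    logits = torch.matmul(z_a, Wz)       # (B, B); diagonal = positive pairs
    # subtract the row-wise max for numerical stability of the softmax
    return logits - logits.max(1)[0][:, None]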
class DQNAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, dqn):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        # build up networks
        self._dqn = dqn(config, ob_space, ac_space)
        self._network_cuda(config.device)

        self._dqn_optim = optim.Adam(self._dqn.parameters(),
                                     lr=config.lr_actor)

        sampler = RandomSampler()
        self._buffer = ReplayBuffer(config.buffer_size, sampler.sample_func,
                                    ob_space, ac_space)

        self._log_creation()

    def _log_creation(self):
        logger.info('Creating a DQN agent')
        logger.info('The DQN has %d parameters',
                    count_parameters(self._dqn))

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def store_sample(self, rollouts):
        self._buffer.store_sample(rollouts)

    def _network_cuda(self, device):
        self._dqn.to(device)

    def state_dict(self):
        return {
            'dqn_state_dict': self._dqn.state_dict(),
            'dqn_optim_state_dict': self._dqn_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._dqn.load_state_dict(ckpt['dqn_state_dict'])
        self._network_cuda(self._config.device)
        self._dqn_optim.load_state_dict(ckpt['dqn_optim_state_dict'])
        optimizer_cuda(self._dqn_optim, self._config.device)

    def train(self):
        for _ in range(self._config.num_batches):
            transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions)
        return train_info

    def act_log(self, o):
        raise NotImplementedError

    def act(self, o):
        # greedy action: argmax over Q-values
        o = to_tensor(o, self._config.device)
        q_value = self._dqn(o)
        action = OrderedDict([('default', q_value.max(1)[1].item())])
        return action, None

    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']
        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        ac = ac['default'].to(torch.long)
        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        with torch.no_grad():
            # bootstrap from the next observation (no separate target
            # network in this implementation); keepdim preserves the
            # (bs, 1) shape so it broadcasts correctly with rew and done
            q_next_values = self._dqn(o_next)
            q_next_value = q_next_values.max(1, keepdim=True)[0]
            target_q_value = rew + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        # Q-values of the taken actions, kept as (bs, 1) to match the target
        q_values = self._dqn(o)
        q_value = q_values.gather(1, ac[:, 0].unsqueeze(1))

        info['target_q'] = target_q_value.mean().cpu().item()
        info['real_q'] = q_value.mean().cpu().item()

        loss = (q_value - target_q_value).pow(2).mean()

        self._dqn_optim.zero_grad()
        loss.backward()
        self._dqn_optim.step()

        return info
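# NOTE: `DQNAgent.act` above is purely greedy. Training scripts for DQN
# typically mix in epsilon-greedy exploration; a minimal sketch is given
# below. The wrapper name and the fixed `epsilon` are illustrative
# assumptions, not part of the original code.
def epsilon_greedy_act(agent, ob, ac_space, epsilon=0.1):
    """With probability epsilon take a uniform random action, else greedy."""
    if np.random.rand() < epsilon:
        action = OrderedDict([('default', ac_space['default'].sample())])
        return action, None
    return agent.act(ob)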