    def __post_init__(self, observation_space, action_space):
        device = self.device or ("cuda" if torch.cuda.is_available() else "cpu")
        model = self.Model(observation_space, action_space)
        self.model = model.to(device)
        self.model_target = no_grad(deepcopy(self.model))  # frozen copy used for target values
        self.outputnorm = self.OutputNorm(self.model.critic_output_layers)
        self.outputnorm_target = self.OutputNorm(self.model_target.critic_output_layers)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)  # single optimizer over all model parameters
        self.memory = Memory(self.memory_size, self.batchsize, device)
        self.is_training = False


@dataclass(eq=0)  # dataclass processing is required for the InitVar fields and __post_init__ below
class Agent:
    observation_space: InitVar
    action_space: InitVar

    Model: type = rtrl.sac_models.Mlp
    OutputNorm: type = PopArt
    batchsize: int = 256  # training batch size
    memory_size: int = 1000000  # replay memory size
    lr: float = 0.0003  # learning rate
    discount: float = 0.99  # reward discount factor
    target_update: float = 0.005  # parameter for exponential moving average
    reward_scale: float = 5.  # scales rewards before computing the value target
    entropy_scale: float = 1.  # entropy regularization coefficient
    start_training: int = 10000  # number of stored transitions before the first gradient step
    device: str = None  # defaults to cuda if available
    training_interval: int = 1  # environment steps between gradient updates

    # no-grad view of the model (shared weights), used when sampling next actions
    model_nograd = cached_property(lambda self: no_grad(copy_shared(self.model)))

    num_updates = 0
    training_steps = 0

    def __post_init__(self, observation_space, action_space):
        device = self.device or ("cuda" if torch.cuda.is_available() else "cpu")
        model = self.Model(observation_space, action_space)
        self.model = model.to(device)
        self.model_target = no_grad(deepcopy(self.model))

        self.actor_optimizer = torch.optim.Adam(self.model.actor.parameters(), lr=self.lr)
        self.critic_optimizer = torch.optim.Adam(self.model.critics.parameters(), lr=self.lr)
        self.memory = Memory(self.memory_size, self.batchsize, device)

        self.outputnorm = self.OutputNorm(self.model.critic_output_layers)
        self.outputnorm_target = self.OutputNorm(self.model_target.critic_output_layers)

    def act(self, obs, r, done, info, train=False):
        stats = []
        action, _ = self.model.act(obs, r, done, info, train)

        if train:
            self.memory.append(np.float32(r), np.float32(done), info, obs, action)
            if len(self.memory) >= self.start_training and self.training_steps % self.training_interval == 0:
                stats += self.train(),
            self.training_steps += 1

        return action, stats

    def train(self):
        obs, actions, rewards, next_obs, terminals = self.memory.sample()
        rewards, terminals = rewards[:, None], terminals[:, None]  # expand for correct broadcasting below

        new_action_distribution = self.model.actor(obs)
        new_actions = new_action_distribution.rsample()

        # critic loss
        next_action_distribution = self.model_nograd.actor(next_obs)
        next_actions = next_action_distribution.sample()
        next_value = [c(next_obs, next_actions) for c in self.model_target.critics]
        next_value = reduce(torch.min, next_value)
        next_value = self.outputnorm_target.unnormalize(next_value)
        next_value = next_value - self.entropy_scale * next_action_distribution.log_prob(next_actions)[:, None]

        value_target = self.reward_scale * rewards + (1. - terminals) * self.discount * next_value
        value_target = self.outputnorm.update(value_target)

        values = [c(obs, actions) for c in self.model.critics]
        assert values[0].shape == value_target.shape and not value_target.requires_grad
        loss_critic = sum(mse_loss(v, value_target) for v in values)

        # actor loss
        new_value = [c(obs, new_actions) for c in self.model.critics]
        new_value = reduce(torch.min, new_value)
        new_value = self.outputnorm.unnormalize(new_value)
        loss_actor = self.entropy_scale * new_action_distribution.log_prob(new_actions)[:, None] - new_value
        assert loss_actor.shape == (self.batchsize, 1)
        loss_actor = self.outputnorm.normalize(loss_actor).mean()

        # update actor and critic
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()

        # self.outputnorm.normalize(value_target, update=True)  # This is not the right place to update PopArt

        # update target critics and normalizers
        exponential_moving_average(self.model_target.critics.parameters(), self.model.critics.parameters(), self.target_update)
        exponential_moving_average(self.outputnorm_target.parameters(), self.outputnorm.parameters(), self.target_update)

        return dict(
            loss_actor=loss_actor.detach(),
            loss_critic=loss_critic.detach(),
            outputnorm_mean=float(self.outputnorm.mean),
            outputnorm_std=float(self.outputnorm.std),
            memory_size=len(self.memory),
        )
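

# Usage sketch (illustrative, not part of the original listing): drives the Agent above
# with a classic Gym-style loop. The environment id, the reward/done bookkeeping and the
# old gym API (step() returning obs, reward, done, info) are assumptions made for this
# example only; the actual training loop lives elsewhere in the codebase.
if __name__ == "__main__":
    import gym

    env = gym.make("Pendulum-v0")
    agent = Agent(env.observation_space, env.action_space)

    obs, r, done, info = env.reset(), 0.0, False, {}
    for step in range(100000):
        # with train=True the agent also stores the transition and, once enough samples
        # have been collected, performs a gradient update every training_interval steps
        action, stats = agent.act(obs, r, done, info, train=True)
        obs, r, done, info = env.step(action)
        if done:
            obs, r, done, info = env.reset(), 0.0, False, {}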