class SACAgent(AgentBase):
    """
    Soft Actor-Critic.

    Uses a stochastic policy and a dual value network (two critics).

    Based on "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement
    Learning with a Stochastic Actor" by Haarnoja et al. (2018)
    (http://arxiv.org/abs/1801.01290).
    """

    name = "SAC"

    def __init__(self, in_features: FeatureType, action_size: int, **kwargs):
        """
        Parameters:
            hidden_layers: (default: (128, 128)) Shape of the hidden layers that
                are fully connected networks.
            gamma: (default: 0.99) Discount value.
            tau: (default: 0.02) Soft copy fraction.
            batch_size: (default 64) Number of samples in a batch.
            buffer_size: (default: 1e6) Size of the prioritized experience replay buffer.
            warm_up: (default: 0) Number of samples that needs to be observed
                before starting to learn.
            update_freq: (default: 1) Number of samples between policy updates.
            number_updates: (default: 1) Number of times of batch sampling/training
                per `update_freq`.
            alpha: (default: 0.2) Weight of log probs in value function.
            alpha_lr: (default: None) If provided, it will add alpha as a training
                parameter and `alpha_lr` is its learning rate.
            action_scale: (default: 1.) Scale for returned action values.
            max_grad_norm_alpha: (default: 1.) Gradient clipping for the alpha.
            max_grad_norm_actor: (default 10.) Gradient clipping for the actor.
            max_grad_norm_critic: (default: 10.) Gradient clipping for the critic.
            device: Defaults to CUDA if available.
        """
        super().__init__(**kwargs)
        self.device = kwargs.get("device", DEVICE)

        # `in_features` may be a single int or a shape; keep the tuple form for
        # the critics and the flattened size for the actor input.
        self.in_features: Tuple[int] = (in_features,) if isinstance(in_features, int) else tuple(in_features)
        self.state_size: int = in_features if isinstance(in_features, int) else reduce(operator.mul, in_features)
        self.action_size = action_size

        self.gamma: float = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau: float = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.memory = PERBuffer(self.batch_size, self.buffer_size)

        self.action_min = self._register_param(kwargs, 'action_min', -1)
        self.action_max = self._register_param(kwargs, 'action_max', 1)
        self.action_scale = self._register_param(kwargs, 'action_scale', 1)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(self._register_param(kwargs, 'number_updates', 1))
        self.actor_number_updates = int(self._register_param(kwargs, 'actor_number_updates', 1))
        self.critic_number_updates = int(self._register_param(kwargs, 'critic_number_updates', 1))

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(self._register_param(kwargs, 'hidden_layers', (128, 128)))
        actor_hidden_layers = to_numbers_seq(self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        critic_hidden_layers = to_numbers_seq(self._register_param(kwargs, 'critic_hidden_layers', hidden_layers))

        self.simple_policy = bool(self._register_param(kwargs, "simple_policy", False))
        if self.simple_policy:
            self.policy = MultivariateGaussianPolicySimple(self.action_size, **kwargs)
            self.actor = ActorBody(self.state_size,
                                   self.policy.param_dim * self.action_size,
                                   hidden_layers=actor_hidden_layers,
                                   device=self.device)
        else:
            # The last hidden layer acts as the feature input to the policy head.
            self.policy = GaussianPolicy(actor_hidden_layers[-1],
                                         self.action_size,
                                         out_scale=self.action_scale,
                                         device=self.device)
            self.actor = ActorBody(self.state_size,
                                   actor_hidden_layers[-1],
                                   hidden_layers=actor_hidden_layers[:-1],
                                   device=self.device)

        self.double_critic = DoubleCritic(self.in_features,
                                          self.action_size,
                                          CriticBody,
                                          hidden_layers=critic_hidden_layers,
                                          device=self.device)
        self.target_double_critic = DoubleCritic(self.in_features,
                                                 self.action_size,
                                                 CriticBody,
                                                 hidden_layers=critic_hidden_layers,
                                                 device=self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        # Standard SAC heuristic: target entropy = -dim(action space).
        self.target_entropy = -self.action_size
        alpha_lr = self._register_param(kwargs, "alpha_lr")
        self.alpha_lr = float(alpha_lr) if alpha_lr else None
        alpha_init = float(self._register_param(kwargs, "alpha", 0.2))
        # Optimize log(alpha) so that alpha = exp(log_alpha) stays positive.
        self.log_alpha = torch.tensor(np.log(alpha_init), device=self.device, requires_grad=True)

        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))

        self.actor_params = list(self.actor.parameters()) + list(self.policy.parameters())
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic_params, lr=critic_lr)

        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.alpha_lr)
        self.max_grad_norm_alpha = float(self._register_param(kwargs, "max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        # Breath, my child.
        self.iteration = 0

        self._loss_actor = float('inf')
        self._loss_critic = float('inf')
        self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}

    @property
    def alpha(self):
        """Entropy temperature; derived from the trainable `log_alpha`."""
        return self.log_alpha.exp()

    @property
    def loss(self):
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def reset_agent(self) -> None:
        """Re-initialize all networks and re-sync the target critic."""
        self.actor.reset_parameters()
        self.policy.reset_parameters()
        self.double_critic.reset_parameters()
        hard_update(self.target_double_critic, self.double_critic)

    def state_dict(self) -> Dict[str, dict]:
        """
        Returns network weights in order:
        Actor, Policy, DoubleCritic, TargetDoubleCritic.
        """
        return {
            "actor": self.actor.state_dict(),
            "policy": self.policy.state_dict(),
            "double_critic": self.double_critic.state_dict(),
            "target_double_critic": self.target_double_critic.state_dict(),
        }

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.0, deterministic=False) -> List[float]:
        """Return an action for `state`; random during warm up / with prob. epsilon."""
        if self.iteration < self.warm_up or self._rng.random() < epsilon:
            # Uniform sample from [action_min, action_max].
            # Fix: scale by the range (max - min); the previous (max + min)
            # collapsed to 0 with the default bounds (-1, 1), making every
            # "random" action equal to action_min.
            random_action = torch.rand(self.action_size) * (self.action_max - self.action_min) + self.action_min
            return random_action.cpu().tolist()

        state = to_tensor(state).view(1, self.state_size).float().to(self.device)
        proto_action = self.actor(state)
        action = self.policy(proto_action, deterministic)
        return action.flatten().tolist()

    def step(self, state, action, reward, next_state, done):
        """Record a transition and trigger learning per `update_freq`."""
        self.iteration += 1
        self.memory.add(
            state=state, action=action, reward=reward, next_state=next_state, done=done,
        )

        if self.iteration < self.warm_up:
            return

        if len(self.memory) > self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.memory.sample())

    def compute_value_loss(self, states, actions, rewards, next_states, dones) -> Tuple[Tensor, Tensor]:
        """Return (critic loss, per-sample TD error) for the soft Bellman target."""
        Q1_expected, Q2_expected = self.double_critic(states, actions)

        with torch.no_grad():
            # Fix: the Bellman target requires actions sampled from `next_states`;
            # the original sampled from `states`, corrupting the target value.
            proto_next_action = self.actor(next_states)
            next_actions = self.policy(proto_next_action)
            log_prob = self.policy.logprob
            assert next_actions.shape == (self.batch_size, self.action_size)
            assert log_prob.shape == (self.batch_size, 1)

            Q1_target_next, Q2_target_next = self.target_double_critic.act(next_states, next_actions)
            assert Q1_target_next.shape == Q2_target_next.shape == (self.batch_size, 1)

            # Soft value: min of the twin targets minus the entropy bonus.
            Q_min = torch.min(Q1_target_next, Q2_target_next)
            QH_target = Q_min - self.alpha * log_prob
            assert QH_target.shape == (self.batch_size, 1)

            Q_target = rewards + self.gamma * QH_target * (1 - dones)
            assert Q_target.shape == (self.batch_size, 1)

        Q1_diff = Q1_expected - Q_target
        error_1 = Q1_diff.pow(2)
        mse_loss_1 = error_1.mean()
        self._metrics['value/critic1'] = {'mean': float(Q1_expected.mean()), 'std': float(Q1_expected.std())}
        self._metrics['value/critic1_lse'] = float(mse_loss_1.item())

        Q2_diff = Q2_expected - Q_target
        error_2 = Q2_diff.pow(2)
        mse_loss_2 = error_2.mean()
        self._metrics['value/critic2'] = {'mean': float(Q2_expected.mean()), 'std': float(Q2_expected.std())}
        self._metrics['value/critic2_lse'] = float(mse_loss_2.item())

        Q_diff = Q1_expected - Q2_expected
        self._metrics['value/Q_diff'] = {'mean': float(Q_diff.mean()), 'std': float(Q_diff.std())}

        # Per-sample error used for PER priorities; optimistic min of the twins.
        error = torch.min(error_1, error_2)
        loss = mse_loss_1 + mse_loss_2
        return loss, error

    def compute_policy_loss(self, states):
        """Return the actor loss; also performs the alpha update when enabled."""
        proto_actions = self.actor(states)
        pred_actions = self.policy(proto_actions)
        log_prob = self.policy.logprob
        assert pred_actions.shape == (self.batch_size, self.action_size)

        Q_estimate = torch.min(*self.double_critic(states, pred_actions))
        assert Q_estimate.shape == (self.batch_size, 1)

        self._metrics['policy/entropy'] = -float(log_prob.detach().mean())
        loss = (self.alpha * log_prob - Q_estimate).mean()

        # Update alpha towards the target entropy.
        if self.alpha_lr is not None:
            self.alpha_optimizer.zero_grad()
            loss_alpha = -(self.alpha * (log_prob + self.target_entropy).detach()).mean()
            loss_alpha.backward()
            nn.utils.clip_grad_norm_(self.log_alpha, self.max_grad_norm_alpha)
            self.alpha_optimizer.step()

        return loss

    def learn(self, samples):
        """Update the critics and actors of all the agents."""
        rewards = to_tensor(samples['reward']).float().to(self.device).view(self.batch_size, 1)
        dones = to_tensor(samples['done']).int().to(self.device).view(self.batch_size, 1)
        states = to_tensor(samples['state']).float().to(self.device).view(self.batch_size, self.state_size)
        next_states = to_tensor(samples['next_state']).float().to(self.device).view(self.batch_size, self.state_size)
        actions = to_tensor(samples['action']).to(self.device).view(self.batch_size, self.action_size)

        # Critic (value) update
        for _ in range(self.critic_number_updates):
            value_loss, error = self.compute_value_loss(states, actions, rewards, next_states, dones)
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(self.critic_params, self.max_grad_norm_critic)
            self.critic_optimizer.step()
            self._loss_critic = value_loss.item()

        # Actor (policy) update
        for _ in range(self.actor_number_updates):
            policy_loss = self.compute_policy_loss(states)
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_actor)
            self.actor_optimizer.step()
            self._loss_actor = policy_loss.item()

        if hasattr(self.memory, 'priority_update'):
            assert any(~torch.isnan(error))
            self.memory.priority_update(samples['index'], error.abs())

        soft_update(self.target_double_critic, self.double_critic, self.tau)

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        """Write losses, alpha, cached metrics and (optionally) layer histograms."""
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        data_logger.log_value("loss/alpha", self.alpha, step)

        if self.simple_policy:
            policy_params = {
                str(i): v
                for i, v in enumerate(itertools.chain.from_iterable(self.policy.parameters()))
            }
            data_logger.log_values_dict("policy/param", policy_params, step)

        for name, value in self._metrics.items():
            if isinstance(value, dict):
                data_logger.log_values_dict(name, value, step)
            else:
                data_logger.log_value(name, value, step)

        if full_log:
            # TODO: Add Policy layers
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"policy/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"policy/layer_bias_{idx}", layer.bias, step)

            for idx, layer in enumerate(self.double_critic.critic_1.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic_1/layer_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic_1/layer_bias_{idx}", layer.bias, step)

            for idx, layer in enumerate(self.double_critic.critic_2.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic_2/layer_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic_2/layer_bias_{idx}", layer.bias, step)

    def get_state(self):
        """Full serializable agent state (networks + config)."""
        return dict(
            actor=self.actor.state_dict(),
            policy=self.policy.state_dict(),
            double_critic=self.double_critic.state_dict(),
            target_double_critic=self.target_double_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)
        self.actor.load_state_dict(agent_state['actor'])
        self.policy.load_state_dict(agent_state['policy'])
        self.double_critic.load_state_dict(agent_state['double_critic'])
        self.target_double_critic.load_state_dict(agent_state['target_double_critic'])
class D3PGAgent(AgentBase):
    """Distributional DDPG (D3PG) [1].

    It's closely related to, and sits in-between, D4PG and DDPG. Compared to
    D4PG it lacks the multi actors support. It extends the DDPG agent with:
    1. Distributional critic update.
    2. N-step returns.
    3. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.
    """

    name = "D3PG"

    def __init__(self, state_size: int, action_size: int,
                 hidden_layers: Sequence[int] = (128, 128), **kwargs):
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        # Support atoms of the categorical value distribution.
        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = int(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        # N-step returns: transitions are staged in `n_buffer` before PER.
        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up: int = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq: int = int(self._register_param(kwargs, 'update_freq', 1))

        if kwargs.get("simple_policy", False):
            std_init = kwargs.get("std_init", 1.0)
            std_max = kwargs.get("std_max", 1.5)
            std_min = kwargs.get("std_min", 0.25)
            self.policy = MultivariateGaussianPolicySimple(self.action_size,
                                                           std_init=std_init,
                                                           std_min=std_min,
                                                           std_max=std_max,
                                                           device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device)

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers', hidden_layers))

        # This looks messy but it's not that bad. Actor, critic_net and
        # Critic(critic_net). Then the same for `target_`.
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               hidden_layers=self.actor_hidden_layers,
                               gate_out=torch.tanh,
                               device=self.device)
        critic_net = CriticBody(state_size,
                                action_size,
                                out_features=self.num_atoms,
                                hidden_layers=self.critic_hidden_layers,
                                device=self.device)
        self.critic = CategoricalNet(num_atoms=self.num_atoms,
                                     v_min=v_min,
                                     v_max=v_max,
                                     net=critic_net,
                                     device=self.device)

        self.target_actor = ActorBody(state_size,
                                      self.policy.param_dim * action_size,
                                      hidden_layers=self.actor_hidden_layers,
                                      gate_out=torch.tanh,
                                      device=self.device)
        target_critic_net = CriticBody(state_size,
                                       action_size,
                                       out_features=self.num_atoms,
                                       hidden_layers=self.critic_hidden_layers,
                                       device=self.device)
        self.target_critic = CategoricalNet(num_atoms=self.num_atoms,
                                            v_min=v_min,
                                            v_max=v_max,
                                            net=target_critic_net,
                                            device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        self.actor_params = list(self.actor.parameters()) + list(self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 50.0))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 50.0))

        # Breath, my child.
        self.iteration = 0

        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        # Cached tensors purely for logging/histograms.
        self._display_dist = torch.zeros(self.critic.z_atoms.shape)
        self._metric_batch_error = torch.zeros(self.batch_size)
        self._metric_batch_value_dist = torch.zeros(self.batch_size)

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.0) -> List[float]:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            epsilon: Epsilon value in the epsilon-greedy policy.
        """
        state = to_tensor(state).float().to(self.device)
        if self._rng.random() < epsilon:
            action = self.action_scale * (torch.rand(self.action_size) - 0.5)
        else:
            action_seed = self.actor.act(state).view(1, -1)
            action_dist = self.policy(action_seed)
            action = action_dist.sample()
            action *= self.action_scale
            action = action.squeeze()

        # Purely for logging
        self._display_dist = self.target_critic.act(state, action.to(self.device)).squeeze().cpu()
        self._display_dist = F.softmax(self._display_dist, dim=0)

        return torch.clamp(action, self.action_min, self.action_max).cpu().tolist()

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state, action=action, reward=[reward], done=[done], next_state=next_state)
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones, indices=None):
        """Distributional critic loss via the projected Bellman update."""
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        value_dist = F.softmax(value_dist_estimate.squeeze(), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator
        next_action_seeds = self.target_actor.act(next_states)
        next_actions = self.policy(next_action_seeds).sample()
        assert next_actions.shape == (self.batch_size, self.action_size)

        # Fix: the target critic must be evaluated on `next_states`
        # (originally `states`), per the distributional Bellman backup.
        target_value_dist_estimate = self.target_critic.act(next_states, next_actions)
        assert target_value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        target_value_dist_estimate = target_value_dist_estimate.squeeze()
        assert target_value_dist_estimate.shape == (self.batch_size, self.num_atoms)

        discount = self.gamma**self.n_steps
        # NOTE(review): the raw critic output (logits) is projected here while a
        # softmax of it was previously computed and discarded — confirm whether
        # `dist_projection` expects logits or probabilities.
        target_value_projected = self.target_critic.dist_projection(
            rewards, 1 - dones, discount, target_value_dist_estimate)
        assert target_value_projected.shape == (self.batch_size, self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            assert (~torch.isnan(samples_error)).any()
            self.buffer.priority_update(indices, samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        """Actor loss: negative expected value under the current critic."""
        # Compute actor loss
        pred_action_seeds = self.actor(states)
        pred_actions = self.policy(pred_action_seeds).rsample()
        # Negative because the optimizer minimizes, but we want to maximize the value
        value_dist = self.critic(states, pred_actions)
        self._metric_batch_value_dist = value_dist.detach()
        # Estimate on Z support
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']

        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones, indices)

        # Value (critic) optimization
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        # Fix: clip the CRITIC's parameters during the critic update; the
        # original clipped `self.actor_params`, leaving critic gradients
        # unbounded and needlessly mutating actor gradients.
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.
        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            # Fix: was `self.target_critic()` — a forward pass, not weights.
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

        policy_params = {
            str(i): v
            for i, v in enumerate(itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        data_logger.create_histogram('metric/batch_errors', self._metric_batch_error, step)
        data_logger.create_histogram('metric/batch_value_dist', self._metric_batch_value_dist, step)

        if full_log:
            dist = self._display_dist
            z_atoms = self.critic.z_atoms
            z_delta = self.critic.z_delta
            data_logger.add_histogram('dist/dist_value',
                                      min=z_atoms[0],
                                      max=z_atoms[-1],
                                      num=self.num_atoms,
                                      sum=dist.sum(),
                                      sum_squares=dist.pow(2).sum(),
                                      bucket_limits=z_atoms + z_delta,
                                      bucket_counts=dist,
                                      global_step=step)

    def get_state(self):
        """Full serializable agent state (networks + config)."""
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)
        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
class DDPGAgent(AgentBase):
    """
    Deep Deterministic Policy Gradients (DDPG).

    Instead of the popular Ornstein-Uhlenbeck (OU) process for noise
    this agent uses Gaussian noise.
    """

    name = "DDPG"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 actor_lr: float = 2e-3,
                 critic_lr: float = 2e-3,
                 noise_scale: float = 0.2,
                 noise_sigma: float = 0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(self._register_param(kwargs, 'hidden_layers', (128, 128)))
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers,
                               gate_out=torch.tanh).to(self.device)
        self.critic = CriticBody(state_size, action_size, hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size,
                                      action_size,
                                      hidden_layers=hidden_layers,
                                      gate_out=torch.tanh).to(self.device)
        self.target_critic = CriticBody(state_size, action_size, hidden_layers=hidden_layers).to(self.device)

        # Noise sequence initiation
        self.noise = GaussianNoise(shape=(action_size,),
                                   mu=1e-8,
                                   sigma=noise_sigma,
                                   scale=noise_scale,
                                   device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', actor_lr))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', critic_lr))
        self.actor_optimizer = Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 10.0))
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(self._register_param(kwargs, 'number_updates', 1))

        # Breath, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    def reset_agent(self) -> None:
        """Re-initialize every network (including targets)."""
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, obs, noise: float = 0.0) -> List[float]:
        """Acting on the observations. Returns action.

        Returns:
            action: (list float) Action values.
        """
        obs = to_tensor(obs).float().to(self.device)
        action = self.actor(obs)
        action += noise * self.noise.sample()
        action = torch.clamp(action * self.action_scale, self.action_min, self.action_max)
        return action.cpu().numpy().tolist()

    def step(self, state, action, reward, next_state, done) -> None:
        """Record a transition and trigger learning per `update_freq`."""
        self.iteration += 1
        self.buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones):
        """MSE between expected Q and the (target-network) Bellman target."""
        next_actions = self.target_actor.act(next_states)
        assert next_actions.shape == actions.shape
        Q_target_next = self.target_critic.act(next_states, next_actions)
        Q_target = rewards + self.gamma * Q_target_next * (1 - dones)
        Q_expected = self.critic(states, actions)
        assert Q_expected.shape == Q_target.shape == Q_target_next.shape
        return mse_loss(Q_expected, Q_target)

    def compute_policy_loss(self, states):
        """Compute Policy loss based on provided states.

        Loss = Mean(-Q(s, _a) ), where _a is actor's estimate based on state,
        _a = Actor(s).
        """
        pred_actions = self.actor(states)
        return -self.critic(states, pred_actions).mean()

    def learn(self, experiences) -> None:
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

        # Soft update target weights
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.
        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}", layer.bias, step)

            for idx, layer in enumerate(self.critic.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}", layer.bias, step)

    def get_state(self) -> AgentState:
        """Full serializable agent state (networks, buffer, config)."""
        net = dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
        )
        network_state: NetworkState = NetworkState(net=net)
        return AgentState(model=self.name,
                          state_space=self.state_size,
                          action_space=self.action_size,
                          config=self._config,
                          buffer=self.buffer.get_state(),
                          network=network_state)

    def save_state(self, path: str) -> None:
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, *, path: Optional[str] = None, agent_state: Optional[dict] = None):
        """Restore the agent either from a file (`path`) or an in-memory dict.

        NOTE(review): `save_state` persists an `AgentState` object while this
        method indexes the loaded object like a plain dict — confirm both
        formats round-trip.
        """
        # Fix: raise only when NEITHER source is provided. The original
        # condition (`path is None and agent_state`) raised precisely when the
        # caller correctly supplied `agent_state` alone.
        if path is None and agent_state is None:
            raise ValueError(
                "Either `path` or `agent_state` must be provided to load agent's state."
            )
        if path is not None and agent_state is None:
            agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
class TD3Agent(AgentBase): """ Twin Delayed Deep Deterministic (TD3) Policy Gradient. In short, it's a slightly modified/improved version of the DDPG. Compared to the DDPG in this package, which uses Guassian noise, this TD3 uses Ornstein–Uhlenbeck process as the noise. """ name = "TD3" def __init__(self, state_size: int, action_size: int, noise_scale: float = 0.2, noise_sigma: float = 0.1, **kwargs): """ Parameters: state_size (int): Number of input dimensions. action_size (int): Number of output dimensions noise_scale (float): Added noise amplitude. Default: 0.2. noise_sigma (float): Added noise variance. Default: 0.1. Keyword parameters: hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128). actor_lr (float): Learning rate for the actor (policy). Default: 0.003. critic_lr (float): Learning rate for the critic (value function). Default: 0.003. gamma (float): Discount value. Default: 0.99. tau (float): Soft-copy factor. Default: 0.02. actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hideen_layers`. critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hideen_layers`. max_grad_norm_actor (float) Maximum norm value for actor gradient. Default: 100. max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100. batch_size (int): Number of samples used in learning. Default: 64. buffer_size (int): Maximum number of samples to store. Default: 1e6. warm_up (int): Number of samples to observe before starting any learning step. Default: 0. update_freq (int): Number of steps between each learning step. Default 1. number_updates (int): How many times to use learning step in the learning phase. Default: 1. action_min (float): Minimum returned action value. Default: -1. action_max (float): Maximum returned action value. Default: 1. action_scale (float): Multipler value for action. Default: 1. 
""" super().__init__(**kwargs) self.device = self._register_param( kwargs, "device", DEVICE) # Default device is CUDA if available # Reason sequence initiation. self.state_size = state_size self.action_size = action_size hidden_layers = to_numbers_seq( self._register_param(kwargs, 'hidden_layers', (128, 128))) self.actor = ActorBody(state_size, action_size, hidden_layers=hidden_layers).to(self.device) self.critic = DoubleCritic(state_size, action_size, CriticBody, hidden_layers=hidden_layers).to(self.device) self.target_actor = ActorBody(state_size, action_size, hidden_layers=hidden_layers).to( self.device) self.target_critic = DoubleCritic(state_size, action_size, CriticBody, hidden_layers=hidden_layers).to( self.device) # Noise sequence initiation # self.noise = GaussianNoise(shape=(action_size,), mu=1e-8, sigma=noise_sigma, scale=noise_scale, device=device) self.noise = OUProcess(shape=action_size, scale=noise_scale, sigma=noise_sigma, device=self.device) # Target sequence initiation hard_update(self.target_actor, self.actor) hard_update(self.target_critic, self.critic) # Optimization sequence initiation. 
actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-3)) critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-3)) self.actor_optimizer = AdamW(self.actor.parameters(), lr=actor_lr) self.critic_optimizer = AdamW(self.critic.parameters(), lr=critic_lr) self.max_grad_norm_actor: float = float( kwargs.get("max_grad_norm_actor", 100)) self.max_grad_norm_critic: float = float( kwargs.get("max_grad_norm_critic", 100)) self.action_min = float(self._register_param(kwargs, 'action_min', -1.)) self.action_max = float(self._register_param(kwargs, 'action_max', 1.)) self.action_scale = float( self._register_param(kwargs, 'action_scale', 1.)) self.gamma = float(self._register_param(kwargs, 'gamma', 0.99)) self.tau = float(self._register_param(kwargs, 'tau', 0.02)) self.batch_size = int(self._register_param(kwargs, 'batch_size', 64)) self.buffer_size = int( self._register_param(kwargs, 'buffer_size', int(1e6))) self.buffer = ReplayBuffer(self.batch_size, self.buffer_size) self.warm_up = int(self._register_param(kwargs, 'warm_up', 0)) self.update_freq = int(self._register_param(kwargs, 'update_freq', 1)) self.update_policy_freq = int( self._register_param(kwargs, 'update_policy_freq', 1)) self.number_updates = int( self._register_param(kwargs, 'number_updates', 1)) self.noise_reset_freq = int( self._register_param(kwargs, 'noise_reset_freq', 10000)) # Breath, my child. self.reset_agent() self.iteration = 0 self._loss_actor = 0. self._loss_critic = 0. 
@property def loss(self) -> Dict[str, float]: return {'actor': self._loss_actor, 'critic': self._loss_critic} @loss.setter def loss(self, value): if isinstance(value, dict): self._loss_actor = value['actor'] self._loss_critic = value['critic'] else: self._loss_actor = value self._loss_critic = value def reset_agent(self) -> None: self.actor.reset_parameters() self.critic.reset_parameters() self.target_actor.reset_parameters() self.target_critic.reset_parameters() def act(self, state, epsilon: float = 0.0, training_mode=True) -> List[float]: """ Agent acting on observations. When the training_mode is True (default) a noise is added to each action. """ # Epsilon greedy if self._rng.random() < epsilon: rnd_actions = torch.rand(self.action_size) * ( self.action_max - self.action_min) - self.action_min return rnd_actions.tolist() with torch.no_grad(): state = to_tensor(state).float().to(self.device) action = self.actor(state) if training_mode: action += self.noise.sample() return (self.action_scale * torch.clamp(action, self.action_min, self.action_max)).tolist() def target_act(self, staten, noise: float = 0.0): with torch.no_grad(): staten = to_tensor(staten).float().to(self.device) action = self.target_actor(staten) + noise * self.noise.sample() return torch.clamp(action, self.action_min, self.action_max).cpu().numpy().astype( np.float32) def step(self, state, action, reward, next_state, done): self.iteration += 1 self.buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done) if (self.iteration % self.noise_reset_freq) == 0: self.noise.reset_states() if self.iteration < self.warm_up: return if len(self.buffer) <= self.batch_size: return if not (self.iteration % self.update_freq) or not ( self.iteration % self.update_policy_freq): for _ in range(self.number_updates): # Note: Inside this there's a delayed policy update. # Every `update_policy_freq` it will learn `number_updates` times. 
self.learn(self.buffer.sample()) def learn(self, experiences): """Update critics and actors""" rewards = to_tensor(experiences['reward']).float().to( self.device).unsqueeze(1) dones = to_tensor(experiences['done']).type(torch.int).to( self.device).unsqueeze(1) states = to_tensor(experiences['state']).float().to(self.device) actions = to_tensor(experiences['action']).to(self.device) next_states = to_tensor(experiences['next_state']).float().to( self.device) if (self.iteration % self.update_freq) == 0: self._update_value_function(states, actions, rewards, next_states, dones) if (self.iteration % self.update_policy_freq) == 0: self._update_policy(states) soft_update(self.target_actor, self.actor, self.tau) soft_update(self.target_critic, self.critic, self.tau) def _update_value_function(self, states, actions, rewards, next_states, dones): # critic loss next_actions = self.target_actor.act(next_states) Q_target_next = torch.min( *self.target_critic.act(next_states, next_actions)) Q_target = rewards + (self.gamma * Q_target_next * (1 - dones)) Q1_expected, Q2_expected = self.critic(states, actions) loss_critic = mse_loss(Q1_expected, Q_target) + mse_loss( Q2_expected, Q_target) # Minimize the loss self.critic_optimizer.zero_grad() loss_critic.backward() nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic) self.critic_optimizer.step() self._loss_critic = float(loss_critic.item()) def _update_policy(self, states): # Compute actor loss pred_actions = self.actor(states) loss_actor = -self.critic(states, pred_actions)[0].mean() self.actor_optimizer.zero_grad() loss_actor.backward() nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor) self.actor_optimizer.step() self._loss_actor = loss_actor.item() def state_dict(self) -> Dict[str, dict]: """Describes agent's networks. Returns: state: (dict) Provides actors and critics states. 
""" return { "actor": self.actor.state_dict(), "target_actor": self.target_actor.state_dict(), "critic": self.critic.state_dict(), "target_critic": self.target_critic() } def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False): data_logger.log_value("loss/actor", self._loss_actor, step) data_logger.log_value("loss/critic", self._loss_critic, step) def get_state(self): return dict( actor=self.actor.state_dict(), target_actor=self.target_actor.state_dict(), critic=self.critic.state_dict(), target_critic=self.target_critic.state_dict(), config=self._config, ) def save_state(self, path: str): agent_state = self.get_state() torch.save(agent_state, path) def load_state(self, path: str): agent_state = torch.load(path) self._config = agent_state.get('config', {}) self.__dict__.update(**self._config) self.actor.load_state_dict(agent_state['actor']) self.critic.load_state_dict(agent_state['critic']) self.target_actor.load_state_dict(agent_state['target_actor']) self.target_critic.load_state_dict(agent_state['target_critic'])
class PPOAgent(AgentBase): """ Proximal Policy Optimization (PPO) [1] is an online policy gradient method that could be considered as an implementation-wise simplified version of the Trust Region Policy Optimization (TRPO). [1] "Proximal Policy Optimization Algorithms" (2017) by J. Schulman, F. Wolski, P. Dhariwal, A. Radford, O. Klimov. https://arxiv.org/abs/1707.06347 """ name = "PPO" def __init__(self, state_size: int, action_size: int, **kwargs): """ Parameters: state_size: Number of input dimensions. action_size: Number of output dimensions hidden_layers: (default: (100, 100) ) Tuple defining hidden dimensions in fully connected nets. is_discrete: (default: False) Whether return discrete action. kl_div: (default: False) Whether to use KL divergence in loss. using_gae: (default: True) Whether to use General Advantage Estimator. gae_lambda: (default: 0.96) Value of \lambda in GAE. actor_lr: (default: 0.0003) Learning rate for the actor (policy). critic_lr: (default: 0.001) Learning rate for the critic (value function). actor_betas: (default: (0.9, 0.999) Adam's betas for actor optimizer. critic_betas: (default: (0.9, 0.999) Adam's betas for critic optimizer. gamma: (default: 0.99) Discount value. ppo_ratio_clip: (default: 0.25) Policy ratio clipping value. num_epochs: (default: 1) Number of time to learn from samples. rollout_length: (default: 48) Number of actions to take before update. batch_size: (default: rollout_length) Number of samples used in learning. actor_number_updates: (default: 10) Number of times policy losses are propagated. critic_number_updates: (default: 10) Number of times value losses are propagated. entropy_weight: (default: 0.005) Weight of the entropy term in the loss. value_loss_weight: (default: 0.005) Weight of the entropy term in the loss. 
""" super().__init__(**kwargs) self.device = self._register_param( kwargs, "device", DEVICE) # Default device is CUDA if available self.state_size = state_size self.action_size = action_size self.hidden_layers = to_numbers_seq( self._register_param(kwargs, "hidden_layers", (100, 100))) self.iteration = 0 self.is_discrete = bool( self._register_param(kwargs, "is_discrete", False)) self.using_gae = bool(self._register_param(kwargs, "using_gae", True)) self.gae_lambda = float( self._register_param(kwargs, "gae_lambda", 0.96)) self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4)) self.actor_betas: Tuple[float, float] = to_numbers_seq( self._register_param(kwargs, 'actor_betas', (0.9, 0.999))) self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 1e-3)) self.critic_betas: Tuple[float, float] = to_numbers_seq( self._register_param(kwargs, 'critic_betas', (0.9, 0.999))) self.gamma = float(self._register_param(kwargs, "gamma", 0.99)) self.ppo_ratio_clip = float( self._register_param(kwargs, "ppo_ratio_clip", 0.25)) self.using_kl_div = bool( self._register_param(kwargs, "using_kl_div", False)) self.kl_beta = float(self._register_param(kwargs, 'kl_beta', 0.1)) self.target_kl = float(self._register_param(kwargs, "target_kl", 0.01)) self.kl_div = float('inf') self.num_workers = int(self._register_param(kwargs, "num_workers", 1)) self.num_epochs = int(self._register_param(kwargs, "num_epochs", 1)) self.rollout_length = int( self._register_param(kwargs, "rollout_length", 48)) # "Much less than the episode length" self.batch_size = int( self._register_param(kwargs, "batch_size", self.rollout_length)) self.actor_number_updates = int( self._register_param(kwargs, "actor_number_updates", 10)) self.critic_number_updates = int( self._register_param(kwargs, "critic_number_updates", 10)) self.entropy_weight = float( self._register_param(kwargs, "entropy_weight", 0.5)) self.value_loss_weight = float( self._register_param(kwargs, "value_loss_weight", 1.0)) 
self.local_memory_buffer = {} self.action_scale = float( self._register_param(kwargs, "action_scale", 1)) self.action_min = float(self._register_param(kwargs, "action_min", -1)) self.action_max = float(self._register_param(kwargs, "action_max", 1)) self.max_grad_norm_actor = float( self._register_param(kwargs, "max_grad_norm_actor", 100.0)) self.max_grad_norm_critic = float( self._register_param(kwargs, "max_grad_norm_critic", 100.0)) if kwargs.get("simple_policy", False): self.policy = MultivariateGaussianPolicySimple( self.action_size, **kwargs) else: self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device) self.buffer = RolloutBuffer(batch_size=self.batch_size, buffer_size=self.rollout_length) self.actor = ActorBody(state_size, self.policy.param_dim * action_size, gate_out=torch.tanh, hidden_layers=self.hidden_layers, device=self.device) self.critic = ActorBody(state_size, 1, gate_out=None, hidden_layers=self.hidden_layers, device=self.device) self.actor_params = list(self.actor.parameters()) + list( self.policy.parameters()) self.critic_params = list(self.critic.parameters()) self.actor_opt = optim.Adam(self.actor_params, lr=self.actor_lr, betas=self.actor_betas) self.critic_opt = optim.Adam(self.critic_params, lr=self.critic_lr, betas=self.critic_betas) self._loss_actor = float('nan') self._loss_critic = float('nan') self._metrics: Dict[str, float] = {} @property def loss(self) -> Dict[str, float]: return {'actor': self._loss_actor, 'critic': self._loss_critic} @loss.setter def loss(self, value): if isinstance(value, dict): self._loss_actor = value['actor'] self._loss_critic = value['critic'] else: self._loss_actor = value self._loss_critic = value def __eq__(self, o: object) -> bool: return super().__eq__(o) \ and self._config == o._config \ and self.buffer == o.buffer \ and self.get_network_state() == o.get_network_state() # TODO @dawid: Currently net isn't compared properly def __clear_memory(self): self.buffer.clear() @torch.no_grad() 
    def act(self, state, epsilon: float = 0.):
        """Sample an action per worker from the current policy (no gradients).

        Also caches each worker's value estimate and action log-probability in
        `local_memory_buffer` so that `step` can store them with the transition.
        """
        actions = []
        logprobs = []
        values = []
        state = to_tensor(state).view(self.num_workers, self.state_size).float().to(self.device)
        for worker in range(self.num_workers):
            actor_est = self.actor.act(state[worker].unsqueeze(0))
            assert not torch.any(torch.isnan(actor_est))

            dist = self.policy(actor_est)
            action = dist.sample()

            value = self.critic.act(state[worker].unsqueeze(0))  # Shape: (1, 1)
            logprob = self.policy.log_prob(dist, action)  # Shape: (1,)
            values.append(value)
            logprobs.append(logprob)

            if self.is_discrete:
                # *Technically* it's the max of Softmax but that's monotonic.
                action = int(torch.argmax(action))
            else:
                action = torch.clamp(action * self.action_scale, self.action_min, self.action_max)
                action = action.cpu().numpy().flatten().tolist()
            actions.append(action)

        self.local_memory_buffer['value'] = torch.cat(values)
        self.local_memory_buffer['logprob'] = torch.stack(logprobs)
        assert len(actions) == self.num_workers
        # Single-worker callers get the bare action, not a one-element list.
        return actions if self.num_workers > 1 else actions[0]

    def step(self, state, action, reward, next_state, done, **kwargs):
        """Store a transition (with cached value/logprob) and train each rollout."""
        self.iteration += 1

        self.buffer.add(
            state=torch.tensor(state).reshape(self.num_workers, self.state_size).float(),
            action=torch.tensor(action).reshape(self.num_workers, self.action_size).float(),
            reward=torch.tensor(reward).reshape(self.num_workers, 1),
            done=torch.tensor(done).reshape(self.num_workers, 1),
            logprob=self.local_memory_buffer['logprob'].reshape(self.num_workers, 1),
            value=self.local_memory_buffer['value'].reshape(self.num_workers, 1),
        )

        if self.iteration % self.rollout_length == 0:
            self.train()
            self.__clear_memory()

    def train(self):
        """
        Main loop that initiates the training.
        """
        experiences = self.buffer.all_samples()
        rewards = to_tensor(experiences['reward']).to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        values = to_tensor(experiences['value']).to(self.device)
        logprobs = to_tensor(experiences['logprob']).to(self.device)
        assert rewards.shape == dones.shape == values.shape == logprobs.shape
        assert states.shape == (self.rollout_length, self.num_workers, self.state_size), f"Wrong states shape: {states.shape}"
        assert actions.shape == (self.rollout_length, self.num_workers, self.action_size), f"Wrong action shape: {actions.shape}"

        with torch.no_grad():
            if self.using_gae:
                # GAE: bootstrap from the critic's value of the last state.
                next_value = self.critic.act(states[-1])
                advantages = compute_gae(rewards, dones, values, next_value, self.gamma, self.gae_lambda)
                advantages = normalize(advantages)
                returns = advantages + values
                # returns = normalize(advantages + values)
                assert advantages.shape == returns.shape == values.shape
            else:
                returns = revert_norm_returns(rewards, dones, self.gamma)
                returns = returns.float()
                advantages = normalize(returns - values)
                assert advantages.shape == returns.shape == values.shape

        for _ in range(self.num_epochs):
            idx = 0
            self.kl_div = 0
            # Iterate the rollout in mini-batches of `batch_size`.
            while idx < self.rollout_length:
                _states = states[idx:idx + self.batch_size].view(-1, self.state_size).detach()
                _actions = actions[idx:idx + self.batch_size].view(-1, self.action_size).detach()
                _logprobs = logprobs[idx:idx + self.batch_size].view(-1, 1).detach()
                _returns = returns[idx:idx + self.batch_size].view(-1, 1).detach()
                _advantages = advantages[idx:idx + self.batch_size].view(-1, 1).detach()
                idx += self.batch_size
                self.learn((_states, _actions, _logprobs, _returns, _advantages))

            # Average the accumulated KL estimate over all actor updates.
            self.kl_div = abs(self.kl_div) / (self.actor_number_updates * self.rollout_length / self.batch_size)

            # Adapt the KL penalty coefficient toward `target_kl`.
            if self.kl_div > self.target_kl * 1.75:
                self.kl_beta = min(2 * self.kl_beta, 1e2)  # Max 100
            if self.kl_div < self.target_kl / 1.75:
                self.kl_beta = max(0.5 * self.kl_beta, 1e-6)  # Min 0.000001
            self._metrics['policy/kl_beta'] = self.kl_beta

    def compute_policy_loss(self, samples):
        """Clipped-surrogate (or KL-penalized) policy loss plus entropy bonus.

        Returns:
            (loss, approx_kl_div): the scalar loss tensor and a float estimate
            of the KL divergence between old and new policies.
        """
        states, actions, old_log_probs, _, advantages = samples

        actor_est = self.actor(states)
        dist = self.policy(actor_est)

        entropy = dist.entropy()
        new_log_probs = self.policy.log_prob(dist, actions).view(-1, 1)
        assert new_log_probs.shape == old_log_probs.shape

        r_theta = (new_log_probs - old_log_probs).exp()
        r_theta_clip = torch.clamp(r_theta, 1.0 - self.ppo_ratio_clip, 1.0 + self.ppo_ratio_clip)
        assert r_theta.shape == r_theta_clip.shape

        # KL = E[log(P/Q)] = sum_{P}( P * log(P/Q) ) -- approx --> avg_{P}( log(P) - log(Q) )
        approx_kl_div = (old_log_probs - new_log_probs).mean().item()
        if self.using_kl_div:
            # Ratio threshold for updates is 1.75 (although it should be configurable)
            policy_loss = -torch.mean(r_theta * advantages) + self.kl_beta * approx_kl_div
        else:
            joint_theta_adv = torch.stack((r_theta * advantages, r_theta_clip * advantages))
            assert joint_theta_adv.shape[0] == 2
            policy_loss = -torch.amin(joint_theta_adv, dim=0).mean()
        entropy_loss = -self.entropy_weight * entropy.mean()

        loss = policy_loss + entropy_loss
        self._metrics['policy/kl_div'] = approx_kl_div
        self._metrics['policy/policy_ratio'] = float(r_theta.mean())
        self._metrics['policy/policy_ratio_clip_mean'] = float(r_theta_clip.mean())
        return loss, approx_kl_div

    def compute_value_loss(self, samples):
        """MSE between critic predictions and computed returns.

        NOTE(review): `self.value_loss_weight` is registered in __init__ but is
        not applied here — confirm whether the weighting is intentional.
        """
        states, _, _, returns, _ = samples
        values = self.critic(states)
        self._metrics['value/value_mean'] = values.mean()
        self._metrics['value/value_std'] = values.std()
        return F.mse_loss(values, returns)

    def learn(self, samples):
        """Run actor and critic updates on a single mini-batch."""
        self._loss_actor = 0.
        for _ in range(self.actor_number_updates):
            self.actor_opt.zero_grad()
            loss_actor, kl_div = self.compute_policy_loss(samples)
            self.kl_div += kl_div
            # Early break when the policy moved too far from the sampling policy.
            if kl_div > 1.5 * self.target_kl:
                # Early break
                # print(f"Iter: {i:02} Early break")
                break
            loss_actor.backward()
            nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_actor)
            self.actor_opt.step()
            self._loss_actor = loss_actor.item()

        for _ in range(self.critic_number_updates):
            self.critic_opt.zero_grad()
            loss_critic = self.compute_value_loss(samples)
            loss_critic.backward()
            nn.utils.clip_grad_norm_(self.critic_params, self.max_grad_norm_critic)
            self.critic_opt.step()
            self._loss_critic = float(loss_critic.item())

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        """Log losses, collected metrics, policy params and (optionally) histograms."""
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        for metric_name, metric_value in self._metrics.items():
            data_logger.log_value(metric_name, metric_value, step)

        policy_params = {
            str(i): v
            for i, v in enumerate(itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}", layer.bias, step)
            for idx, layer in enumerate(self.critic.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}", layer.bias, step)

    def get_state(self) -> AgentState:
        """Return a deep-copied snapshot of buffer and networks for checkpointing."""
        return AgentState(model=self.name,
                          state_space=self.state_size,
                          action_space=self.action_size,
                          config=self._config,
                          buffer=copy.deepcopy(self.buffer.get_state()),
                          network=copy.deepcopy(self.get_network_state()))

    def get_network_state(self) -> NetworkState:
        """State dicts of policy, actor and critic wrapped in a NetworkState."""
        return NetworkState(net=dict(
            policy=self.policy.state_dict(),
            actor=self.actor.state_dict(),
            critic=self.critic.state_dict(),
        ))

    def set_buffer(self, buffer_state: BufferState) -> None:
        self.buffer = BufferFactory.from_state(buffer_state)

    def set_network(self, network_state: NetworkState) -> None:
        self.policy.load_state_dict(network_state.net['policy'])
        self.actor.load_state_dict(network_state.net['actor'])
        self.critic.load_state_dict(network_state.net['critic'])

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        """Reconstruct a PPOAgent from a previously captured AgentState."""
        agent = PPOAgent(state.state_space, state.action_space, **state.config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def save_state(self, path: str):
        """Serialize the agent's state to `path` via `torch.save`."""
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        """Restore policy/actor/critic and config from a saved state.

        NOTE(review): `save_state` persists the `AgentState` returned by
        `get_state`, but this method indexes the loaded object like a plain
        dict (`agent_state['policy']`) — confirm the two formats agree.
        """
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.policy.load_state_dict(agent_state['policy'])
        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
class DDPGAgent(AgentBase): """ Deep Deterministic Policy Gradients (DDPG). Instead of popular Ornstein-Uhlenbeck (OU) process for noise this agent uses Gaussian noise. """ name = "DDPG" def __init__(self, state_size: int, action_size: int, noise_scale: float = 0.2, noise_sigma: float = 0.1, **kwargs): """ Parameters: state_size: Number of input dimensions. action_size: Number of output dimensions noise_scale (float): Added noise amplitude. Default: 0.2. noise_sigma (float): Added noise variance. Default: 0.1. Keyword parameters: hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (64, 64). gamma (float): Discount value. Default: 0.99. tau (float): Soft-copy factor. Default: 0.002. actor_lr (float): Learning rate for the actor (policy). Default: 0.0003. critic_lr (float): Learning rate for the critic (value function). Default: 0.0003. max_grad_norm_actor (float) Maximum norm value for actor gradient. Default: 10. max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 10. batch_size (int): Number of samples used in learning. Default: 64. buffer_size (int): Maximum number of samples to store. Default: 1e6. warm_up (int): Number of samples to observe before starting any learning step. Default: 0. update_freq (int): Number of steps between each learning step. Default 1. number_updates (int): How many times to use learning step in the learning phase. Default: 1. action_min (float): Minimum returned action value. Default: -1. action_max (float): Maximum returned action value. Default: 1. action_scale (float): Multipler value for action. Default: 1. """ super().__init__(**kwargs) self.device = self._register_param(kwargs, "device", DEVICE) self.state_size = state_size self.action_size = action_size # Reason sequence initiation. 
        hidden_layers = to_numbers_seq(self._register_param(kwargs, 'hidden_layers', (64, 64)))
        self.actor = ActorBody(state_size, action_size, hidden_layers=hidden_layers, gate_out=torch.tanh).to(self.device)
        self.critic = CriticBody(state_size, action_size, hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size, action_size, hidden_layers=hidden_layers, gate_out=torch.tanh).to(self.device)
        self.target_critic = CriticBody(state_size, action_size, hidden_layers=hidden_layers).to(self.device)

        # Noise sequence initiation
        self.noise = GaussianNoise(shape=(action_size, ), mu=1e-8, sigma=noise_sigma, scale=noise_scale, device=self.device)

        # Target sequence initiation: targets start as exact copies.
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.actor_optimizer = Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(self._register_param(kwargs, 'number_updates', 1))

        # Breath, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    def reset_agent(self) -> None:
        """Re-initialize parameters of all four networks."""
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    @property
    def loss(self) -> Dict[str, float]:
        """Last recorded actor and critic losses."""
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        # Accept either a {'actor': ..., 'critic': ...} dict or a single scalar.
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def __eq__(self, o: object) -> bool:
        return super().__eq__(o) \
            and self._config == o._config \
            and self.buffer == o.buffer \
            and self.get_network_state() == o.get_network_state()

    @torch.no_grad()
    def act(self, obs, noise: float = 0.0) -> List[float]:
        """Acting on the observations. Returns action.

        Returns:
            action: (list float) Action values.
        """
        obs = to_tensor(obs).float().to(self.device)
        action = self.actor(obs)
        action += noise * self.noise.sample()
        action = torch.clamp(action * self.action_scale, self.action_min, self.action_max)
        return action.cpu().numpy().tolist()

    def step(self, state, action, reward, next_state, done) -> None:
        """Record a transition and trigger learning on the configured schedule."""
        self.iteration += 1
        self.buffer.add(state=state, action=action, reward=reward, next_state=next_state, done=done)

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones):
        """TD(0) critic loss: MSE between Q(s, a) and the bootstrapped target."""
        next_actions = self.target_actor.act(next_states)
        assert next_actions.shape == actions.shape
        Q_target_next = self.target_critic.act(next_states, next_actions)
        Q_target = rewards + self.gamma * Q_target_next * (1 - dones)
        Q_expected = self.critic(states, actions)
        assert Q_expected.shape == Q_target.shape == Q_target_next.shape
        return mse_loss(Q_expected, Q_target)

    # Annotation fixed: this returns a loss tensor, not None.
    def compute_policy_loss(self, states) -> torch.Tensor:
        """Compute Policy loss based on provided states.

        Loss = Mean(-Q(s, _a) ),
        where _a is actor's estimate based on state, _a = Actor(s).
        """
        pred_actions = self.actor(states)
        return -self.critic(states, pred_actions).mean()

    def learn(self, experiences) -> None:
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

        # Soft update target weights
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.
""" return { "actor": self.actor.state_dict(), "target_actor": self.target_actor.state_dict(), "critic": self.critic.state_dict(), "target_critic": self.target_critic.state_dict() } def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False): data_logger.log_value("loss/actor", self._loss_actor, step) data_logger.log_value("loss/critic", self._loss_critic, step) if full_log: for idx, layer in enumerate(self.actor.layers): if hasattr(layer, "weight"): data_logger.create_histogram(f"actor/layer_weights_{idx}", layer.weight, step) if hasattr(layer, "bias") and layer.bias is not None: data_logger.create_histogram(f"actor/layer_bias_{idx}", layer.bias, step) for idx, layer in enumerate(self.critic.layers): if hasattr(layer, "weight"): data_logger.create_histogram(f"critic/layer_weights_{idx}", layer.weight, step) if hasattr(layer, "bias") and layer.bias is not None: data_logger.create_histogram(f"critic/layer_bias_{idx}", layer.bias, step) def get_state(self) -> AgentState: return AgentState( model=self.name, state_space=self.state_size, action_space=self.action_size, config=self._config, buffer=copy.deepcopy(self.buffer.get_state()), network=copy.deepcopy(self.get_network_state()), ) def get_network_state(self) -> NetworkState: net = dict( actor=self.actor.state_dict(), target_actor=self.target_actor.state_dict(), critic=self.critic.state_dict(), target_critic=self.target_critic.state_dict(), ) return NetworkState(net=net) @staticmethod def from_state(state: AgentState) -> AgentBase: config = copy.copy(state.config) config.update({ 'state_size': state.state_space, 'action_size': state.action_space }) agent = DDPGAgent(**config) if state.network is not None: agent.set_network(state.network) if state.buffer is not None: agent.set_buffer(state.buffer) return agent def set_buffer(self, buffer_state: BufferState) -> None: self.buffer = BufferFactory.from_state(buffer_state) def set_network(self, network_state: NetworkState) -> None: 
self.actor.load_state_dict(copy.deepcopy(network_state.net['actor'])) self.target_actor.load_state_dict(network_state.net['target_actor']) self.critic.load_state_dict(network_state.net['critic']) self.target_critic.load_state_dict(network_state.net['target_critic']) def save_state(self, path: str) -> None: agent_state = self.get_state() torch.save(agent_state, path) def load_state(self, *, path: Optional[str] = None, agent_state: Optional[dict] = None): if path is None and agent_state: raise ValueError( "Either `path` or `agent_state` must be provided to load agent's state." ) if path is not None and agent_state is None: agent_state = torch.load(path) self._config = agent_state.get('config', {}) self.__dict__.update(**self._config) self.actor.load_state_dict(agent_state['actor']) self.critic.load_state_dict(agent_state['critic']) self.target_actor.load_state_dict(agent_state['target_actor']) self.target_critic.load_state_dict(agent_state['target_critic'])
class D4PGAgent(AgentBase):
    """
    Distributed Distributional DDPG (D4PG) [1].

    Extends the DDPG agent with:
    1. Distributional critic update.
    2. The use of distributed parallel actors.
    3. N-step returns.
    4. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.
    """

    name = "D4PG"

    def __init__(self, state_size: int, action_size: int, hidden_layers: Sequence[int] = (128, 128), **kwargs):
        """
        Parameters:
            state_size (int): Number of input dimensions.
            action_size (int): Number of output dimensions.
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128).

        Keyword parameters:
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_lr (float): Learning rate for the actor (policy). Default: 0.0003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.0003.
            actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hidden_layers`.
            critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hidden_layers`.
            max_grad_norm_actor (float): Maximum norm value for actor gradient. Default: 100.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100.
            num_atoms (int): Number of discrete values for the value distribution. Default: 51.
            v_min (float): Value distribution minimum (left most) value. Default: -10.
            v_max (float): Value distribution maximum (right most) value. Default: 10.
            n_steps (int): Number of steps (N-steps) for the TD. Default: 3.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default 1.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.
            num_workers (int): Number of workers that will assume this agent. Default: 1.
        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)

        self.state_size = state_size
        self.action_size = action_size
        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers', hidden_layers))

        if kwargs.get("simple_policy", False):
            std_init = float(self._register_param(kwargs, "std_init", 1.0))
            std_max = float(self._register_param(kwargs, "std_max", 2.0))
            std_min = float(self._register_param(kwargs, "std_min", 0.05))
            self.policy = MultivariateGaussianPolicySimple(self.action_size,
                                                           std_init=std_init,
                                                           std_min=std_min,
                                                           std_max=std_max,
                                                           device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device)

        # This looks messy but it's not that bad. Actor, critic_net and
        # Critic(critic_net). Then the same for `target_`.
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               hidden_layers=self.actor_hidden_layers,
                               gate_out=torch.tanh,
                               device=self.device)
        critic_net = CriticBody(state_size,
                                action_size,
                                out_features=self.num_atoms,
                                hidden_layers=self.critic_hidden_layers,
                                device=self.device)
        self.critic = CategoricalNet(num_atoms=self.num_atoms,
                                     v_min=v_min,
                                     v_max=v_max,
                                     net=critic_net,
                                     device=self.device)

        self.target_actor = ActorBody(state_size,
                                      self.policy.param_dim * action_size,
                                      hidden_layers=self.actor_hidden_layers,
                                      gate_out=torch.tanh,
                                      device=self.device)
        target_critic_net = CriticBody(state_size,
                                       action_size,
                                       out_features=self.num_atoms,
                                       hidden_layers=self.critic_hidden_layers,
                                       device=self.device)
        self.target_critic = CategoricalNet(num_atoms=self.num_atoms,
                                            v_min=v_min,
                                            v_max=v_max,
                                            net=target_critic_net,
                                            device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        # Policy parameters are trained jointly with the actor network.
        self.actor_params = list(self.actor.parameters()) + list(self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 100))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 100))

        self.num_workers = int(self._register_param(kwargs, "num_workers", 1))

        # Breath, my child.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.) -> List[float]:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            epsilon: Epsilon value in the epsilon-greedy policy.
        """
        actions = []
        state = to_tensor(state).view(self.num_workers, self.state_size).float().to(self.device)
        for worker in range(self.num_workers):
            if self._rng.random() < epsilon:
                # Random (exploration) action, centered around 0.
                action = self.action_scale * (torch.rand(self.action_size) - 0.5)
            else:
                action_seed = self.actor.act(state[worker].view(1, -1))
                action_dist = self.policy(action_seed)
                action = action_dist.sample()
                action *= self.action_scale
                action = torch.clamp(action.squeeze(), self.action_min, self.action_max).cpu()
            actions.append(action.tolist())
        assert len(actions) == self.num_workers
        return actions

    def step(self, states, actions, rewards, next_states, dones):
        """Record one environment transition (per worker) and learn when due."""
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(
            state=torch.tensor(states).reshape(self.num_workers, self.state_size).float(),
            next_state=torch.tensor(next_states).reshape(self.num_workers, self.state_size).float(),
            action=torch.tensor(actions).reshape(self.num_workers, self.action_size).float(),
            reward=torch.tensor(rewards).reshape(self.num_workers, 1),
            done=torch.tensor(dones).reshape(self.num_workers, 1),
        )
        if not self.n_buffer.available:
            return

        samples = self.n_buffer.get().get_dict()
        for worker_idx in range(self.num_workers):
            self.buffer.add(
                state=samples['state'][worker_idx],
                next_state=samples['next_state'][worker_idx],
                action=samples['action'][worker_idx],
                done=samples['done'][worker_idx],
                reward=samples['reward'][worker_idx],
            )

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones, indices=None):
        """Categorical (distributional) critic loss.

        Compares the critic's value distribution at (s, a) against the
        Bellman-projected target distribution from the target networks at
        (s', a'). When a PER buffer is used (`indices` provided), per-sample
        errors also update the sampling priorities.
        """
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        value_dist = F.softmax(value_dist_estimate.squeeze(), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator
        next_action_seeds = self.target_actor.act(next_states)
        next_actions = self.policy(next_action_seeds).sample()
        assert next_actions.shape == (self.batch_size, self.action_size)

        # Bug fix: the Bellman backup evaluates the target critic at the *next*
        # state with the next action, i.e. (s', a'); the original passed `states`.
        target_value_dist_estimate = self.target_critic.act(next_states, next_actions)
        assert target_value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)

        target_value_dist_estimate = target_value_dist_estimate.squeeze()
        assert target_value_dist_estimate.shape == (self.batch_size, self.num_atoms)

        # N-step returns use gamma^n as the effective discount.
        discount = self.gamma ** self.n_steps
        target_value_projected = self.target_critic.dist_projection(
            rewards, 1 - dones, discount, target_value_dist_estimate)
        assert target_value_projected.shape == (self.batch_size, self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            assert (~torch.isnan(samples_error)).any()
            self.buffer.priority_update(indices, samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        """Actor loss: negative expected value of the critic's distribution."""
        # Compute actor loss. The reparameterized sample (rsample) keeps the
        # graph differentiable w.r.t. the actor/policy parameters.
        pred_action_seeds = self.actor(states)
        pred_actions = self.policy(pred_action_seeds).rsample()
        # Negative because the optimizer minimizes, but we want to maximize the value
        value_dist = self.critic(states, pred_actions)
        self._batch_value_dist_metric = value_dist.detach()
        # Estimate on Z support
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors"""
        # No need for size assertion since .view() has explicit sizes
        rewards = to_tensor(experiences['reward']).view(self.batch_size, 1).float().to(self.device)
        dones = to_tensor(experiences['done']).view(self.batch_size, 1).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).view(self.batch_size, self.state_size).float().to(self.device)
        actions = to_tensor(experiences['action']).view(self.batch_size, self.action_size).to(self.device)
        next_states = to_tensor(experiences['next_state']).view(self.batch_size, self.state_size).float().to(self.device)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones, indices)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        # Bug fix: clip the *critic's* gradients here; the original clipped
        # `self.actor_params`, leaving the critic's gradients unbounded.
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.
        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            # Bug fix: original called the network (`self.target_critic()`)
            # instead of returning its state dict.
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step, full_log=False):
        """Log losses, policy parameters and batch metrics; full_log adds layer histograms."""
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

        policy_params = {
            str(i): v
            for i, v in enumerate(itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        data_logger.create_histogram('metric/batch_errors', self._metric_batch_error.sum(-1), step)
        data_logger.create_histogram('metric/batch_value_dist', self._batch_value_dist_metric, step)

        # This method, `log_metrics`, isn't executed on every iteration but just
        # in case we delay plotting weights. It simply might be quite costly.
        # Thread wisely.
        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}", layer.bias, step)

            for idx, layer in enumerate(self.critic.net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}", layer.bias, step)

    def get_state(self):
        """Return a dict snapshot of all network weights plus the agent config."""
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        """Serialize the agent's state to `path` via torch.save."""
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        """Restore config and all network weights from a file written by `save_state`."""
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])