def test_nstep_buffer_add_many_samples_discounted_terminate():
    """N-step rewards are discounted and stop accumulating at a terminal sample."""
    # Assign
    buffer_size = 4
    gamma = 0.9
    buffer = NStepBuffer(n_steps=buffer_size, gamma=gamma)
    populate_buffer(buffer, 20)  # in-place
    last_samples = list(generate_sample_SARS(4, dict_type=True))

    # For every sample: its own reward plus the discounted rewards of the samples
    # that follow it, up to (but excluding) the first one flagged as done.
    expected_rewards = []
    for position, current in enumerate(last_samples):
        discounted = current['reward'][0]
        for offset, following in enumerate(last_samples[position + 1:], start=1):
            if any(following['done']):
                break
            discounted += gamma**offset * following['reward'][0]
        expected_rewards.append(discounted)

    # Act
    for new_sample in last_samples:
        buffer.add(**new_sample)

    # Assert
    assert len(buffer) == buffer_size
    for idx, expected_len in enumerate(reversed(range(buffer_size))):
        sample = buffer.get()
        assert len(buffer) == expected_len
        assert sample.reward[0] == expected_rewards[idx], f"{sample}"
def test_nstep_buffer_add_many_samples():
    """With gamma = 1 and no terminal samples the n-step reward is a plain sum."""
    # Assign
    buffer_size = 4
    gamma = 1.
    buffer = NStepBuffer(n_steps=buffer_size, gamma=gamma)
    populate_buffer(buffer, 20)  # in-place
    last_samples = list(generate_sample_SARS(buffer_size, dict_type=True))
    last_rewards = [entry['reward'][0] for entry in last_samples]

    # Act
    for entry in last_samples:
        entry['done'] = [False]  # Make sure all samples are counted
        buffer.add(**entry)

    # Assert
    assert len(buffer) == buffer_size
    for expected_len in reversed(range(buffer_size)):
        sample = buffer.get()
        assert len(buffer) == expected_len
        # The oldest sample sees all remaining rewards; the newest only its own.
        tail_start = -expected_len - 1
        assert sample.reward[0] == sum(last_rewards[tail_start:])
class D3PGAgent(AgentBase):
    """Distributional DDPG (D3PG) [1].

    It's closely related to, and sits in-between, D4PG and DDPG. Compared to D4PG it lacks
    the multi actors support. It extends the DDPG agent with:
    1. Distributional critic update.
    2. N-step returns.
    3. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.
    """

    name = "D3PG"

    def __init__(self, state_size: int, action_size: int, hidden_layers: Sequence[int] = (128, 128), **kwargs):
        """Builds actor/critic networks, their targets, buffers and optimizers.

        Parameters:
            state_size: Passed to the actor and critic bodies as the state dimension.
            action_size: Passed to the critic body and the policy as the action dimension.
            hidden_layers: Default hidden layers for both actor and critic; can be
                overridden separately with `actor_hidden_layers` / `critic_hidden_layers`.

        Keyword arguments (read with the defaults visible below):
            device, num_atoms (51), v_min (-10), v_max (10), action_min (-1), action_max (1),
            action_scale (1), gamma (0.99), tau (0.02), batch_size (64), buffer_size (1e6),
            n_steps (3), warm_up (0), update_freq (1), simple_policy (False),
            std_init / std_min / std_max (only with `simple_policy`),
            actor_lr (3e-4), critic_lr (3e-4), max_grad_norm_actor (50), max_grad_norm_critic (50).
        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = int(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up: int = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq: int = int(self._register_param(kwargs, 'update_freq', 1))

        if kwargs.get("simple_policy", False):
            std_init = kwargs.get("std_init", 1.0)
            std_max = kwargs.get("std_max", 1.5)
            std_min = kwargs.get("std_min", 0.25)
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, std_init=std_init, std_min=std_min, std_max=std_max, device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device)

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers', hidden_layers))

        # This looks messy but it's not that bad.
        # Actor, critic_net and Critic(critic_net). Then the same for `target_`.
        self.actor = ActorBody(
            state_size, self.policy.param_dim * action_size,
            hidden_layers=self.actor_hidden_layers, gate_out=torch.tanh, device=self.device)
        critic_net = CriticBody(
            state_size, action_size, out_features=self.num_atoms,
            hidden_layers=self.critic_hidden_layers, device=self.device)
        self.critic = CategoricalNet(
            num_atoms=self.num_atoms, v_min=v_min, v_max=v_max, net=critic_net, device=self.device)

        self.target_actor = ActorBody(
            state_size, self.policy.param_dim * action_size,
            hidden_layers=self.actor_hidden_layers, gate_out=torch.tanh, device=self.device)
        target_critic_net = CriticBody(
            state_size, action_size, out_features=self.num_atoms,
            hidden_layers=self.critic_hidden_layers, device=self.device)
        self.target_critic = CategoricalNet(
            num_atoms=self.num_atoms, v_min=v_min, v_max=v_max, net=target_critic_net, device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        # The policy parameters are optimized together with the actor network.
        self.actor_params = list(self.actor.parameters()) + list(self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 50.0))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 50.0))

        # Breath, my child.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._display_dist = torch.zeros(self.critic.z_atoms.shape)
        self._metric_batch_error = torch.zeros(self.batch_size)
        self._metric_batch_value_dist = torch.zeros(self.batch_size)

    @property
    def loss(self) -> Dict[str, float]:
        """Most recent actor and critic losses, keyed by `actor` and `critic`."""
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        # A dict sets each loss separately; any other value is used for both.
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.0) -> List[float]:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            epsilon: Epsilon value in the epsilon-greedy policy.

        Returns:
            Action values clamped to [action_min, action_max], as a list.

        """
        state = to_tensor(state).float().to(self.device)
        if self._rng.random() < epsilon:
            # Exploration: uniform sample centred around zero.
            action = self.action_scale * (torch.rand(self.action_size) - 0.5)
        else:
            action_seed = self.actor.act(state).view(1, -1)
            action_dist = self.policy(action_seed)
            action = action_dist.sample()
            action *= self.action_scale
            action = action.squeeze()

        # Purely for logging
        self._display_dist = self.target_critic.act(state, action.to(self.device)).squeeze().cpu()
        self._display_dist = F.softmax(self._display_dist, dim=0)

        return torch.clamp(action, self.action_min, self.action_max).cpu().tolist()

    def step(self, state, action, reward, next_state, done):
        """Records a transition and, when the conditions below hold, runs a learning step.

        Parameters:
            state: S(t)
            action: A(t)
            reward: R(t)
            next_state: S(t+1)
            done: Whether the state is terminal.

        """
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state, action=action, reward=[reward], done=[done], next_state=next_state)
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones, indices=None):
        """Computes the critic loss and, when `indices` are given, refreshes buffer priorities.

        Parameters:
            states, actions, next_states, rewards, dones: Batched experience tensors.
            indices: Buffer indices of the sampled experiences. When provided and the buffer
                exposes `priority_update`, the per-sample errors are written back as priorities.

        Returns:
            Scalar critic loss (mean of squared per-sample summed errors).

        """
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        value_dist = F.softmax(value_dist_estimate.squeeze(), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator
        next_action_seeds = self.target_actor.act(next_states)
        next_actions = self.policy(next_action_seeds).sample()
        assert next_actions.shape == (self.batch_size, self.action_size)

        # The bootstrap target is the value distribution of the *next* state paired with
        # the action the target actor picks for it, i.e. Z'(s', a').
        target_value_dist_estimate = self.target_critic.act(next_states, next_actions)
        assert target_value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        target_value_dist_estimate = target_value_dist_estimate.squeeze()
        assert target_value_dist_estimate.shape == (self.batch_size, self.num_atoms)

        # Rewards coming from the n-step buffer already span n steps, hence gamma^n.
        discount = self.gamma**self.n_steps
        target_value_projected = self.target_critic.dist_projection(
            rewards, 1 - dones, discount, target_value_dist_estimate)
        assert target_value_projected.shape == (self.batch_size, self.num_atoms)

        target_value_dist = F.softmax(target_value_dist_estimate, dim=-1).detach()
        assert target_value_dist.shape == (self.batch_size, self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            assert (~torch.isnan(samples_error)).any()
            self.buffer.priority_update(indices, samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        """Computes the actor loss as the negated mean of the critic output weighted by `z_atoms`."""
        # Compute actor loss
        pred_action_seeds = self.actor(states)
        # `rsample` keeps the sample differentiable w.r.t. the actor/policy parameters.
        pred_actions = self.policy(pred_action_seeds).rsample()
        value_dist = self.critic(states, pred_actions)
        self._metric_batch_value_dist = value_dist.detach()
        # Estimate on Z support.
        # Negative because the optimizer minimizes, but we want to maximize the value
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']
        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones, indices)

        # Value (critic) optimization
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        # Clip the gradients that the critic optimizer is about to apply.
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        """Writes losses, policy parameters and batch metrics to the provided logger.

        Parameters:
            data_logger: Logger that receives the values and histograms.
            step: Ordering value passed along with every logged entry.
            full_log: When True, additionally logs the value distribution recorded in `act`.

        """
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        policy_params = {
            str(i): v
            for i, v in enumerate(itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        data_logger.create_histogram('metric/batch_errors', self._metric_batch_error, step)
        data_logger.create_histogram('metric/batch_value_dist', self._metric_batch_value_dist, step)

        if full_log:
            dist = self._display_dist
            z_atoms = self.critic.z_atoms
            z_delta = self.critic.z_delta
            data_logger.add_histogram(
                'dist/dist_value', min=z_atoms[0], max=z_atoms[-1], num=self.num_atoms,
                sum=dist.sum(), sum_squares=dist.pow(2).sum(), bucket_limits=z_atoms + z_delta,
                bucket_counts=dist, global_step=step)

    def get_state(self):
        """Returns a dict with all four networks' state dicts and the agent's config."""
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        """Saves the output of `get_state` under `path` using `torch.save`."""
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        """Restores config and networks from a file written by `save_state`."""
        # NOTE: `torch.load` unpickles the file; only load checkpoints from trusted sources.
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
class RainbowAgent(AgentBase):
    """Rainbow agent as described in [1].

    Rainbow is a DQN agent with some improvements that were suggested before 2017.
    As mentioned by the authors it's not an exhaustive improvement but all changes are in
    relatively separate areas so their connection makes sense. These improvements are:
    * Priority Experience Replay
    * Multi-step
    * Double Q net
    * Dueling nets
    * NoisyNet
    * CategoricalNet for Q estimate

    Consider this class as a particular version of the DQN agent.

    [1] "Rainbow: Combining Improvements in Deep Reinforcement Learning" by Hessel et al. (DeepMind team)
    https://arxiv.org/abs/1710.02298
    """

    name = "Rainbow"

    def __init__(self,
                 input_shape: Union[Sequence[int], int],
                 output_shape: Union[Sequence[int], int],
                 state_transform: Optional[Callable] = None,
                 reward_transform: Optional[Callable] = None,
                 **kwargs):
        """
        A wrapper over the DQN thus majority of the logic is in the DQNAgent.
        Special treatment is required because the Rainbow agent uses categorical nets
        which operate on probability distributions. Each action is taken as the estimate
        from such distributions.

        Parameters:
            input_shape (tuple of ints): Most likely that's your *state* shape.
            output_shape (tuple of ints): Most likely that's your *action* shape.
            state_transform (callable, optional): Applied to every state before it is used.
                Default: identity.
            reward_transform (callable, optional): Applied to every reward before it is stored.
                Default: identity.
            pre_network_fn (function that takes input_shape and returns network):
                Used to preprocess state before it is used in the value- and advantage-function
                in the dueling nets.
            hidden_layers (tuple of ints): Shape and sizes of fully connected networks used.
                Default: (100, 100).
            lr (default: 3e-4): Learning rate value.
            gamma (float): Discount factor. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.002.
            update_freq (int): Number of steps between each learning step. Default 1.
            batch_size (int): Number of samples to use at each learning step. Default: 80.
            buffer_size (int): Number of most recent samples to keep in memory for learning.
                Default: 1e5.
            warm_up (int): Number of samples to observe before starting any learning step.
                Default: 0.
            number_updates (int): How many times to use learning step in the learning phase.
                Default: 1.
            max_grad_norm (float): Maximum norm of the gradient used in learning. Default: 10.
            using_double_q (bool): Whether to use Double Q Learning network. Default: True.
            n_steps (int): Number of lookahead steps when estimating reward.
                See :ref:`NStepBuffer`. Default: 3.
            v_min (float): Lower bound for distributional value V. Default: -10.
            v_max (float): Upper bound for distributional value V. Default: 10.
            num_atoms (int): Number of atoms (discrete states) in the value V distribution.
                Default: 21.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE, update=True)
        self.input_shape: Sequence[int] = input_shape if not isinstance(input_shape, int) else (input_shape, )
        self.state_size: int = self.input_shape[0]
        self.output_shape: Sequence[int] = output_shape if not isinstance(output_shape, int) else (output_shape, )
        self.action_size: int = self.output_shape[0]

        self.lr = float(self._register_param(kwargs, 'lr', 3e-4))
        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.002))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 80, update=True))
        self.buffer_size = int(self._register_param(kwargs, 'buffer_size', int(1e5), update=True))
        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.number_updates = int(self._register_param(kwargs, 'number_updates', 1))
        self.max_grad_norm = float(self._register_param(kwargs, 'max_grad_norm', 10))

        self.iteration: int = 0
        self.using_double_q = bool(self._register_param(kwargs, "using_double_q", True))

        self.state_transform = state_transform if state_transform is not None else lambda x: x
        self.reward_transform = reward_transform if reward_transform is not None else lambda x: x

        v_min = float(self._register_param(kwargs, "v_min", -10))
        v_max = float(self._register_param(kwargs, "v_max", 10))
        self.num_atoms = int(self._register_param(kwargs, "num_atoms", 21, drop=True))
        self.z_atoms = torch.linspace(v_min, v_max, self.num_atoms, device=self.device)
        self.z_delta = self.z_atoms[1] - self.z_atoms[0]

        self.buffer = PERBuffer(**kwargs)
        # Row selector used in `learn` to pick one action's distribution per batch entry.
        self.__batch_indices = torch.arange(self.batch_size, device=self.device)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        # Note that in case a pre_network is provided, e.g. a shared net that extracts pixels values,
        # it should be explicitly passed in kwargs
        kwargs["hidden_layers"] = to_numbers_seq(self._register_param(kwargs, "hidden_layers", (100, 100)))
        self.net = RainbowNet(self.input_shape, self.output_shape, num_atoms=self.num_atoms, **kwargs)
        self.target_net = RainbowNet(self.input_shape, self.output_shape, num_atoms=self.num_atoms, **kwargs)

        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        self.dist_probs = None
        self._loss = float('inf')

    @property
    def loss(self):
        """Most recent learning loss wrapped in a dict under the `loss` key."""
        return {'loss': self._loss}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            value = value['loss']
        self._loss = value

    def step(self, state, action, reward, next_state, done) -> None:
        """Letting the agent to take a step.

        On some steps the agent will initiate learning step. This is dependent on
        the `update_freq` value.

        Parameters:
            state: S(t)
            action: A(t)
            reward: R(t)
            next_state: S(t+1)
            done: (bool) Whether the state is terminal.

        """
        self.iteration += 1
        state = to_tensor(self.state_transform(state)).float().to("cpu")
        next_state = to_tensor(self.state_transform(next_state)).float().to("cpu")
        reward = self.reward_transform(reward)

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(
            state=state.numpy(), action=[int(action)], reward=[reward],
            done=[done], next_state=next_state.numpy())
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) >= self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

            # Update networks only once - sync local & target
            soft_update(self.target_net, self.net, self.tau)

    def act(self, state, eps: float = 0.) -> int:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            eps: Epsilon value in the epsilon-greedy policy.

        Returns:
            Index of the selected action.

        """
        # Epsilon-greedy action selection
        if self._rng.random() < eps:
            return self._rng.randint(0, self.action_size - 1)

        state = to_tensor(self.state_transform(state)).float().unsqueeze(0).to(self.device)
        self.dist_probs = self.net.act(state)
        # Expected value per action: probabilities weighted by the atoms' values.
        q_values = (self.dist_probs * self.z_atoms).sum(-1)
        return int(q_values.argmax(-1))  # Action maximizes state-action value Q(s, a)

    def learn(self, experiences: Dict[str, List]) -> None:
        """
        Parameters:
            experiences: Contains all experiences for the agent. Typically sampled from the memory buffer.
                Five keys are expected, i.e. `state`, `action`, `reward`, `next_state`, `done`.
                Each key contains an array and all arrays have to have the same length.
                When the buffer supports `priority_update` an `index` key is read as well.

        """
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        actions = to_tensor(experiences['action']).type(torch.long).to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, 1)  # Discrete domain

        with torch.no_grad():
            prob_next = self.target_net.act(next_states)
            if self.using_double_q:
                # Double Q: the online net selects the action, the target net evaluates it.
                duel_prob_next = self.net.act(next_states)
                a_next = torch.argmax((duel_prob_next * self.z_atoms).sum(-1), dim=-1)
            else:
                q_next = (prob_next * self.z_atoms).sum(-1) * self.z_delta
                a_next = torch.argmax(q_next, dim=-1)

            prob_next = prob_next[self.__batch_indices, a_next, :]

        m = self.net.dist_projection(rewards, 1 - dones, self.gamma**self.n_steps, prob_next)
        assert m.shape == (self.batch_size, self.num_atoms)

        log_prob = self.net(states, log_prob=True)
        assert log_prob.shape == (self.batch_size, self.action_size, self.num_atoms)
        log_prob = log_prob[self.__batch_indices, actions.squeeze(), :]
        assert log_prob.shape == m.shape == (self.batch_size, self.num_atoms)

        # Cross-entropy loss error and the loss is batch mean
        error = -torch.sum(m * log_prob, 1)
        assert error.shape == (self.batch_size, )
        loss = error.mean()
        assert loss >= 0

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self._loss = float(loss.item())

        if hasattr(self.buffer, 'priority_update'):
            assert (~torch.isnan(error)).any()
            self.buffer.priority_update(experiences['index'], error.detach().cpu().numpy())

        # Update networks - sync local & target
        soft_update(self.target_net, self.net, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Returns agent's state dictionary.

        Returns:
            State dictionary for internal networks.

        """
        return {
            "net": self.net.state_dict(),
            "target_net": self.target_net.state_dict()
        }

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        """Writes the loss and, with `full_log`, value distributions and layer histograms.

        Parameters:
            data_logger: Logger that receives the values and histograms.
            step: Ordering value passed along with every logged entry.
            full_log: Whether to also log per-action distributions and network weights.

        """
        data_logger.log_value("loss/agent", self._loss, step)

        if full_log and self.dist_probs is not None:
            for action_idx in range(self.action_size):
                dist = self.dist_probs[0, action_idx]
                data_logger.log_value(f'dist/expected_{action_idx}', (dist * self.z_atoms).sum().item(), step)
                data_logger.add_histogram(
                    f'dist/Q_{action_idx}', min=self.z_atoms[0], max=self.z_atoms[-1], num=len(self.z_atoms),
                    sum=dist.sum(), sum_squares=dist.pow(2).sum(), bucket_limits=self.z_atoms + self.z_delta,
                    bucket_counts=dist, global_step=step)

        # This method, `log_metrics`, isn't executed on every iteration but just in case we delay plotting weights.
        # It simply might be quite costly. Thread wisely.
        if full_log:
            for idx, layer in enumerate(self.net.value_net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"value_net/layer_weights_{idx}", layer.weight.cpu(), step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"value_net/layer_bias_{idx}", layer.bias.cpu(), step)
            for idx, layer in enumerate(self.net.advantage_net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"advantage_net/layer_{idx}", layer.weight.cpu(), step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"advantage_net/layer_bias_{idx}", layer.bias.cpu(), step)

    def get_state(self) -> AgentState:
        """Provides agent's internal state."""
        return AgentState(
            model=self.name,
            state_space=self.state_size,
            action_space=self.action_size,
            config=self._config,
            buffer=copy.deepcopy(self.buffer.get_state()),
            network=copy.deepcopy(self.get_network_state()),
        )

    def get_network_state(self) -> NetworkState:
        """Wraps both networks' state dicts in a `NetworkState`."""
        return NetworkState(net=dict(net=self.net.state_dict(), target_net=self.target_net.state_dict()))

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        """Creates a new agent from `state`, restoring network and buffer when present."""
        config = copy.copy(state.config)
        config.update({'input_shape': state.state_space, 'output_shape': state.action_space})
        agent = RainbowAgent(**config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def set_network(self, network_state: NetworkState) -> None:
        """Loads weights for the local and target networks from `network_state`."""
        self.net.load_state_dict(network_state.net['net'])
        self.target_net.load_state_dict(network_state.net['target_net'])

    def set_buffer(self, buffer_state: BufferState) -> None:
        """Replaces the experience buffer with one created from `buffer_state`."""
        self.buffer = BufferFactory.from_state(buffer_state)

    def save_state(self, path: str) -> None:
        """Saves agent's state into a file.

        Parameters:
            path: String path where to write the state.

        """
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str) -> None:
        """Loads state from a file under provided path.

        Accepts both the `AgentState` object written by `save_state` and a flat
        dictionary with `net`, `target_net` and (optionally) `config` keys.

        Parameters:
            path: String path indicating where the state is stored.

        """
        # NOTE: `torch.load` unpickles the file; only load checkpoints from trusted sources.
        agent_state = torch.load(path)

        if isinstance(agent_state, dict):
            # Flat dictionary layout.
            self._config = agent_state.get('config', {})
            self.__dict__.update(**self._config)
            self.net.load_state_dict(agent_state['net'])
            self.target_net.load_state_dict(agent_state['target_net'])
            return

        # `AgentState` layout, i.e. what `get_state` / `save_state` produce.
        self._config = agent_state.config
        self.__dict__.update(**self._config)
        if agent_state.network is not None:
            self.set_network(agent_state.network)
        if agent_state.buffer is not None:
            self.set_buffer(agent_state.buffer)

    def save_buffer(self, path: str) -> None:
        """Saves data from the buffer into a file under provided path.

        Parameters:
            path: String path where to write the buffer.

        """
        import json
        dump = self.buffer.dump_buffer(serialize=True)
        with open(path, 'w') as f:
            json.dump(dump, f)

    def load_buffer(self, path: str) -> None:
        """Loads data into the buffer from provided file path.

        Parameters:
            path: String path indicating where the buffer is stored.

        """
        import json
        with open(path, 'r') as f:
            buffer_dump = json.load(f)
        self.buffer.load_buffer(buffer_dump)

    def __eq__(self, o: object) -> bool:
        # Equal when the base class agrees and config, buffer and network states match.
        return super().__eq__(o) \
            and self._config == o._config \
            and self.buffer == o.buffer \
            and self.get_network_state() == o.get_network_state()
class DQNAgent(AgentBase):
    """Deep Q-Learning Network (DQN).

    The agent is not a vanilla DQN, although can be configured as such.
    The default config includes dual dueling nets and the priority experience buffer.
    Learning is also delayed by slowly copying to target nets (via tau parameter).
    Although NStep is implemented the default value is 1-step reward.

    There is also a specific implementation of the DQN called the Rainbow which differs
    to this implementation by working on the discrete space projection of the Q(s,a) function.
    """

    name = "DQN"

    def __init__(self,
                 input_shape: Union[Sequence[int], int],
                 output_shape: Union[Sequence[int], int],
                 network_fn: Optional[Callable[[], NetworkType]] = None,
                 network_class: Optional[Type[NetworkTypeClass]] = None,
                 state_transform: Optional[Callable] = None,
                 reward_transform: Optional[Callable] = None,
                 **kwargs):
        """Initiates the DQN agent.

        Parameters:
            input_shape: State shape; an int is wrapped into a one-element tuple.
            output_shape: Action shape; an int is wrapped into a one-element tuple.
            network_fn: Zero-argument factory called twice, for the local and target nets.
                Takes precedence over `network_class`.
            network_class: Network class instantiated twice when `network_fn` is not given.
                When both are None a `DuelingNet` is used.
            state_transform: Applied to every state before it is used. Default: identity.
            reward_transform: Applied to every reward before it is stored. Default: identity.
            hidden_layers: (default: (64, 64) ) Tuple defining hidden dimensions in fully connected nets.
            lr: (default: 3e-4) learning rate
            gamma: (default: 0.99) discount factor
            tau: (default: 0.002) soft-copy factor
            update_freq: (default: 1)
            batch_size: (default: 64)
            buffer_size: (default: 1e5)
            warm_up: (default: 0)
            number_updates: (default: 1)
            max_grad_norm: (default: 10)
            using_double_q: (default: True) Whether to use double Q value
            n_steps: (int: 1) N steps reward lookahead

        """
        super().__init__(**kwargs)

        self.device = self._register_param(kwargs, "device", DEVICE, update=True)
        # TODO: All this should be condensed with some structure, e.g. gym spaces
        self.input_shape: Sequence[int] = input_shape if not isinstance(input_shape, int) else (input_shape, )
        self.state_size: int = self.input_shape[0]
        self.output_shape: Sequence[int] = output_shape if not isinstance(output_shape, int) else (output_shape, )
        self.action_size: int = self.output_shape[0]
        self._config['state_size'] = self.state_size
        self._config['action_size'] = self.action_size

        self.lr = float(self._register_param(kwargs, 'lr', 3e-4))  # Learning rate
        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))  # Discount value
        self.tau = float(self._register_param(kwargs, 'tau', 0.002))  # Soft update

        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64, update=True))
        self.buffer_size = int(self._register_param(kwargs, 'buffer_size', int(1e5), update=True))
        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.number_updates = int(self._register_param(kwargs, 'number_updates', 1))
        self.max_grad_norm = float(self._register_param(kwargs, 'max_grad_norm', 10))

        self.iteration: int = 0
        self.buffer = PERBuffer(**kwargs)
        self.using_double_q = bool(self._register_param(kwargs, "using_double_q", True))

        self.n_steps = int(self._register_param(kwargs, 'n_steps', 1))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        hidden_layers = to_numbers_seq(self._register_param(kwargs, 'hidden_layers', (64, 64)))
        self.state_transform = state_transform if state_transform is not None else lambda x: x
        self.reward_transform = reward_transform if reward_transform is not None else lambda x: x

        if network_fn is not None:
            self.net = network_fn()
            self.target_net = network_fn()
        elif network_class is not None:
            self.net = network_class(
                self.input_shape, self.action_size, hidden_layers=hidden_layers, device=self.device)
            self.target_net = network_class(
                self.input_shape, self.action_size, hidden_layers=hidden_layers, device=self.device)
        else:
            self.net = DuelingNet(
                self.input_shape, self.output_shape, hidden_layers=hidden_layers, device=self.device)
            self.target_net = DuelingNet(
                self.input_shape, self.output_shape, hidden_layers=hidden_layers, device=self.device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        self._loss: float = float('inf')

    @property
    def loss(self) -> Dict[str, float]:
        """Most recent learning loss wrapped in a dict under the `loss` key."""
        return {'loss': self._loss}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            value = value['loss']
        self._loss = value

    def __eq__(self, o: object) -> bool:
        # Equal when the base class agrees and config, both buffers and network states match.
        return super().__eq__(o) \
            and self._config == o._config \
            and self.buffer == o.buffer \
            and self.n_buffer == o.n_buffer \
            and self.get_network_state() == o.get_network_state()

    def reset(self):
        """Resets both networks' parameters and recreates the optimizer."""
        self.net.reset_parameters()
        self.target_net.reset_parameters()
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

    def step(self, state, action, reward, next_state, done) -> None:
        """Letting the agent to take a step.

        On some steps the agent will initiate learning step. This is dependent on
        the `update_freq` value.

        Parameters:
            state: S(t)
            action: A(t)
            reward: R(t)
            next_state: S(t+1)
            done: (bool) Whether the state is terminal.

        """
        self.iteration += 1
        state = to_tensor(self.state_transform(state)).float().to("cpu")
        next_state = to_tensor(self.state_transform(next_state)).float().to("cpu")
        reward = self.reward_transform(reward)

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(
            state=state.numpy(), action=[int(action)], reward=[reward],
            done=[done], next_state=next_state.numpy())
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) >= self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

            # Update networks only once - sync local & target
            soft_update(self.target_net, self.net, self.tau)

    def act(self, state, eps: float = 0.) -> int:
        """Returns actions for given state as per current policy.

        Parameters:
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection

        Returns:
            Categorical value for the action.

        """
        # Epsilon-greedy action selection
        if self._rng.random() < eps:
            return self._rng.randint(0, self.action_size - 1)

        state = to_tensor(self.state_transform(state)).float()
        state = state.unsqueeze(0).to(self.device)
        action_values = self.net.act(state)
        return int(torch.argmax(action_values.cpu()))

    def learn(self, experiences: Dict[str, list]) -> None:
        """Updates agent's networks based on provided experience.

        Parameters:
            experiences: Samples experiences from the experience buffer.

        """
        rewards = to_tensor(experiences['reward']).type(torch.float32).to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).type(torch.float32).to(self.device)
        next_states = to_tensor(experiences['next_state']).type(torch.float32).to(self.device)
        actions = to_tensor(experiences['action']).type(torch.long).to(self.device)

        with torch.no_grad():
            Q_targets_next = self.target_net.act(next_states).detach()
            if self.using_double_q:
                # Double Q: the online net selects the action, the target net evaluates it.
                _a = torch.argmax(self.net(next_states), dim=-1).unsqueeze(-1)
                max_Q_targets_next = Q_targets_next.gather(1, _a)
            else:
                max_Q_targets_next = Q_targets_next.max(1)[0].unsqueeze(1)
        # `(1 - dones)` removes the bootstrap term for terminal transitions.
        Q_targets = rewards + self.n_buffer.n_gammas[-1] * max_Q_targets_next * (1 - dones)
        Q_expected: torch.Tensor = self.net(states).gather(1, actions)
        loss = F.mse_loss(Q_expected, Q_targets)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self._loss = float(loss.item())

        if hasattr(self.buffer, 'priority_update'):
            # Detached so the replay buffer never holds a reference to the autograd graph.
            error = (Q_expected - Q_targets).detach()
            assert any(~torch.isnan(error))
            self.buffer.priority_update(experiences['index'], error.abs())

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides local and target networks' states.

        """
        return {
            "net": self.net.state_dict(),
            "target_net": self.target_net.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step: int, full_log: bool = False):
        """Uses provided DataLogger to provide agent's metrics.

        Parameters:
            data_logger (DataLogger): Instance of the SummaryView, e.g. torch.utils.tensorboard.SummaryWriter.
            step (int): Ordering value, e.g. episode number.
            full_log (bool): Whether to log all available information. Useful to log with lesser frequency.
                Currently unused by this agent.

        """
        data_logger.log_value("loss/agent", self._loss, step)

    def get_state(self) -> AgentState:
        """Provides agent's internal state."""
        return AgentState(
            model=self.name,
            state_space=self.state_size,
            action_space=self.action_size,
            config=self._config,
            buffer=copy.deepcopy(self.buffer.get_state()),
            network=copy.deepcopy(self.get_network_state()),
        )

    def get_network_state(self) -> NetworkState:
        """Wraps both networks' state dicts in a `NetworkState`."""
        return NetworkState(net=dict(net=self.net.state_dict(), target_net=self.target_net.state_dict()))

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        """Creates a new agent from `state`, restoring network and buffer when present."""
        agent = DQNAgent(state.state_space, state.action_space, **state.config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def set_buffer(self, buffer_state: BufferState) -> None:
        """Replaces the experience buffer with one created from `buffer_state`."""
        self.buffer = BufferFactory.from_state(buffer_state)

    def set_network(self, network_state: NetworkState) -> None:
        """Loads weights for the local and target networks from `network_state`."""
        self.net.load_state_dict(network_state.net['net'])
        self.target_net.load_state_dict(network_state.net['target_net'])

    def save_state(self, path: str):
        """Saves agent's state into a file.

        Parameters:
            path: String path where to write the state.

        """
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, *, path: Optional[str] = None, state: Optional[AgentState] = None) -> None:
        """Loads state from a file under provided path, or from a provided state object.

        Parameters:
            path: String path indicating where the state is stored. When given,
                it takes precedence over `state`.
            state: Already loaded state to populate the agent from.

        Raises:
            ValueError: When neither `path` nor `state` is provided.

        """
        if path is None and state is None:
            raise ValueError("Either `path` or `state` must be provided to load agent's state.")
        if path is not None:
            # NOTE: `torch.load` unpickles the file; only load checkpoints from trusted sources.
            state = torch.load(path)

        # Populate agent.
        # The `AgentState` produced by `get_state` carries `config` directly; a state that
        # nests it under an `agent` attribute is still honoured.
        agent_state = getattr(state, 'agent', state)
        self._config = agent_state.config
        self.__dict__.update(**self._config)

        # Populate network
        network_state = state.network
        self.net.load_state_dict(network_state.net['net'])
        self.target_net.load_state_dict(network_state.net['target_net'])

        self.buffer = PERBuffer(**self._config)

    def save_buffer(self, path: str) -> None:
        """Saves data from the buffer into a file under provided path.

        Parameters:
            path: String path where to write the buffer.

        """
        import json
        dump = self.buffer.dump_buffer(serialize=True)
        with open(path, 'w') as f:
            json.dump(dump, f)

    def load_buffer(self, path: str) -> None:
        """Loads data into the buffer from provided file path.

        Parameters:
            path: String path indicating where the buffer is stored.

        """
        import json
        with open(path, 'r') as f:
            buffer_dump = json.load(f)
        self.buffer.load_buffer(buffer_dump)
class D4PGAgent(AgentBase):
    """
    Distributed Distributional DDPG (D4PG) [1].

    Extends the DDPG agent with:
    1. Distributional critic update.
    2. The use of distributed parallel actors.
    3. N-step returns.
    4. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.

    """

    name = "D4PG"

    def __init__(self, state_size: int, action_size: int, hidden_layers: Sequence[int] = (128, 128), **kwargs):
        """
        Parameters:
            state_size (int): Number of input dimensions.
            action_size (int): Number of output dimensions
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets.
                Default: (128, 128).

        Keyword parameters:
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_lr (float): Learning rate for the actor (policy). Default: 0.0003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.0003.
            actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hidden_layers`.
            critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hidden_layers`.
            max_grad_norm_actor (float) Maximum norm value for actor gradient. Default: 100.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100.
            num_atoms (int): Number of discrete values for the value distribution. Default: 51.
            v_min (float): Value distribution minimum (left most) value. Default: -10.
            v_max (float): Value distribution maximum (right most) value. Default: 10.
            n_steps (int): Number of steps (N-steps) for the TD. Default: 3.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default 1.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.
            num_workers (int): Number of workers that will assume this agent. Default: 1.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        # Documented in the keyword parameters above; default of 1 keeps a single `learn` per phase.
        self.number_updates = int(self._register_param(kwargs, 'number_updates', 1))

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers', hidden_layers))

        if kwargs.get("simple_policy", False):
            std_init = float(self._register_param(kwargs, "std_init", 1.0))
            std_max = float(self._register_param(kwargs, "std_max", 2.0))
            std_min = float(self._register_param(kwargs, "std_min", 0.05))
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, std_init=std_init, std_min=std_min, std_max=std_max, device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device)

        # This looks messy but it's not that bad. Actor, critic_net and Critic(critic_net). Then the same for `target_`.
        self.actor = ActorBody(
            state_size, self.policy.param_dim * action_size,
            hidden_layers=self.actor_hidden_layers, gate_out=torch.tanh, device=self.device)
        critic_net = CriticBody(
            state_size, action_size, out_features=self.num_atoms,
            hidden_layers=self.critic_hidden_layers, device=self.device)
        self.critic = CategoricalNet(
            num_atoms=self.num_atoms, v_min=v_min, v_max=v_max, net=critic_net, device=self.device)

        self.target_actor = ActorBody(
            state_size, self.policy.param_dim * action_size,
            hidden_layers=self.actor_hidden_layers, gate_out=torch.tanh, device=self.device)
        target_critic_net = CriticBody(
            state_size, action_size, out_features=self.num_atoms,
            hidden_layers=self.critic_hidden_layers, device=self.device)
        self.target_critic = CategoricalNet(
            num_atoms=self.num_atoms, v_min=v_min, v_max=v_max, net=target_critic_net, device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        # The actor optimizer updates both the actor body and the policy parameters.
        self.actor_params = list(self.actor.parameters()) + list(self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 100))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 100))

        self.num_workers = int(self._register_param(kwargs, "num_workers", 1))

        # Breath, my child.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        # Populated by `compute_value_loss` / `compute_policy_loss`; None until the first learning step
        # so that `log_metrics` can be called safely before any learning happened.
        self._metric_batch_error = None
        self._batch_value_dist_metric = None

    @property
    def loss(self) -> Dict[str, float]:
        """Most recent actor and critic losses, keyed by 'actor' and 'critic'."""
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        """Sets both losses; accepts a dict with 'actor' and 'critic' keys or a single value used for both."""
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.) -> List[float]:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
                It is viewed as (num_workers, state_size).
            epsilon: Epsilon value in the epsilon-greedy policy.

        Returns:
            List with one entry per worker; each entry is the action tensor converted with `tolist()`.

        """
        actions = []
        state = to_tensor(state).view(self.num_workers, self.state_size).float().to(self.device)
        for worker in range(self.num_workers):
            if self._rng.random() < epsilon:
                # Exploration: uniform noise centred on zero, scaled by `action_scale`.
                action = self.action_scale * (torch.rand(self.action_size) - 0.5)
            else:
                action_seed = self.actor.act(state[worker].view(1, -1))
                action_dist = self.policy(action_seed)
                action = action_dist.sample()
                action *= self.action_scale
                action = torch.clamp(action.squeeze(), self.action_min, self.action_max).cpu()
            actions.append(action.tolist())

        assert len(actions) == self.num_workers
        return actions

    def step(self, states, actions, rewards, next_states, dones):
        """Records one transition per worker and, when due, runs the learning phase.

        Transitions first go through the N-step buffer; only once it reports a sample as
        available is that sample split per worker and stored in the replay buffer.
        Learning runs `number_updates` times when the iteration is past `warm_up`, the replay
        buffer holds more than `batch_size` samples and the iteration is a multiple of `update_freq`.

        Parameters:
            states: States for all workers; reshaped to (num_workers, state_size).
            actions: Actions for all workers; reshaped to (num_workers, action_size).
            rewards: Rewards for all workers; reshaped to (num_workers, 1).
            next_states: Next states for all workers; reshaped to (num_workers, state_size).
            dones: Terminal flags for all workers; reshaped to (num_workers, 1).

        """
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(
            state=torch.tensor(states).reshape(self.num_workers, self.state_size).float(),
            next_state=torch.tensor(next_states).reshape(self.num_workers, self.state_size).float(),
            action=torch.tensor(actions).reshape(self.num_workers, self.action_size).float(),
            reward=torch.tensor(rewards).reshape(self.num_workers, 1),
            done=torch.tensor(dones).reshape(self.num_workers, 1),
        )
        if not self.n_buffer.available:
            return

        samples = self.n_buffer.get().get_dict()
        for worker_idx in range(self.num_workers):
            self.buffer.add(
                state=samples['state'][worker_idx],
                next_state=samples['next_state'][worker_idx],
                action=samples['action'][worker_idx],
                done=samples['done'][worker_idx],
                reward=samples['reward'][worker_idx],
            )

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones, indices=None):
        """Computes the critic loss between the current and the projected target value distributions.

        Side effects: stores per-sample errors in `self._metric_batch_error` and, when the buffer
        exposes `priority_update` and `indices` is provided, updates the sample priorities.

        Parameters:
            states: Batch of states, (batch_size, state_size).
            actions: Batch of actions taken in `states`, (batch_size, action_size).
            next_states: Batch of states following the N-step transition, (batch_size, state_size).
            rewards: Batch of rewards, (batch_size, 1).
            dones: Batch of terminal flags, (batch_size, 1).
            indices: Optional buffer indices of the samples, used for the priority update.

        Returns:
            Scalar tensor: mean over the batch of the squared per-sample (atom-summed) BCE.

        """
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        value_dist = F.softmax(value_dist_estimate.squeeze(), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator
        next_action_seeds = self.target_actor.act(next_states)
        next_actions = self.policy(next_action_seeds).sample()
        assert next_actions.shape == (self.batch_size, self.action_size)

        # The target distribution must be evaluated at the *next* state paired with the
        # target policy's action for that state: Z'(s', pi'(s')).
        target_value_dist_estimate = self.target_critic.act(next_states, next_actions)
        assert target_value_dist_estimate.shape == (self.batch_size, 1, self.num_atoms)
        target_value_dist_estimate = target_value_dist_estimate.squeeze()
        assert target_value_dist_estimate.shape == (self.batch_size, self.num_atoms)

        # N-step return: the bootstrap term is discounted by gamma^N.
        discount = self.gamma ** self.n_steps
        target_value_projected = self.target_critic.dist_projection(
            rewards, 1 - dones, discount, target_value_dist_estimate)
        assert target_value_projected.shape == (self.batch_size, self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            assert (~torch.isnan(samples_error)).any()
            self.buffer.priority_update(indices, samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        """Computes the actor loss for a batch of states.

        Side effect: stores the detached critic output in `self._batch_value_dist_metric`.

        Parameters:
            states: Batch of states, (batch_size, state_size).

        Returns:
            Scalar tensor: negative mean of the critic output weighted by the critic's `z_atoms`.

        """
        # Compute actor loss
        pred_action_seeds = self.actor(states)
        # `rsample` keeps the sample differentiable with respect to the policy parameters.
        pred_actions = self.policy(pred_action_seeds).rsample()

        # Negative because the optimizer minimizes, but we want to maximize the value
        value_dist = self.critic(states, pred_actions)
        self._batch_value_dist_metric = value_dist.detach()
        # Estimate on Z support
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors

        Parameters:
            experiences: Mapping with 'reward', 'done', 'state', 'action' and 'next_state' entries,
                plus 'index' when the buffer supports `priority_update`.

        """
        # No need for size assertion since .view() has explicit sizes
        rewards = to_tensor(experiences['reward']).view(self.batch_size, 1).float().to(self.device)
        dones = to_tensor(experiences['done']).view(self.batch_size, 1).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).view(self.batch_size, self.state_size).float().to(self.device)
        actions = to_tensor(experiences['action']).view(self.batch_size, self.action_size).to(self.device)
        next_states = to_tensor(experiences['next_state']).view(
            self.batch_size, self.state_size).float().to(self.device)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones, indices)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        # Clip the parameters the critic optimizer is about to update.
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        # `actor_params` is exactly what the actor optimizer updates (actor body + policy).
        nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self, data_logger: DataLogger, step, full_log=False):
        """Logs losses, policy parameters and batch metrics through the provided logger.

        Parameters:
            data_logger (DataLogger): Logger receiving the values and histograms.
            step: Ordering value passed to every logging call.
            full_log (bool): When True, also logs histograms of the actor and critic
                layers' weights and biases.

        """
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        policy_params = {
            str(i): v for i, v in enumerate(itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        # Batch metrics exist only after at least one learning step.
        if self._metric_batch_error is not None:
            data_logger.create_histogram('metric/batch_errors', self._metric_batch_error.sum(-1), step)
        if self._batch_value_dist_metric is not None:
            data_logger.create_histogram('metric/batch_value_dist', self._batch_value_dist_metric, step)

        # This method, `log_metrics`, isn't executed on every iteration but just in case we delay plotting weights.
        # It simply might be quite costly. Thread wisely.
        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}", layer.bias, step)

            for idx, layer in enumerate(self.critic.net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_{idx}", layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}", layer.bias, step)

    def get_state(self):
        """Returns a dict with the four networks' state dicts and the agent's config."""
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        """Saves the result of `get_state` with `torch.save` under the provided path."""
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        """Loads config and networks' weights from a file written by `save_state`.

        Config entries are written directly onto the instance attributes.
        """
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
class DQNAgent(AgentType):
    """Deep Q-Learning Network.
    Dual DQN implementation.
    """

    name = "DQN"

    def __init__(self,
                 state_size: Union[Sequence[int], int],
                 action_size: int,
                 lr: float = 0.001,
                 gamma: float = 0.99,
                 tau: float = 0.002,
                 network_fn: Optional[Callable[[], NetworkType]] = None,
                 hidden_layers: Sequence[int] = (64, 64),
                 state_transform: Optional[Callable] = None,
                 reward_transform: Optional[Callable] = None,
                 device=None,
                 **kwargs):
        """
        Accepted parameters:
        :param state_size: state dimensions; an int is wrapped into a 1-tuple
        :param int action_size: number of discrete actions
        :param float lr: learning rate (default: 1e-3)
        :param float gamma: discount factor (default: 0.99)
        :param float tau: soft-copy factor (default: 0.002)
        :param network_fn: optional factory called twice, for the local and the target network
        :param hidden_layers: hidden sizes for the default DuelingNet (default: (64, 64))
        :param state_transform: optional callable applied to every state (default: identity)
        :param reward_transform: optional callable applied to every reward (default: identity)
        :param device: device for the networks (default: module-level DEVICE)

        Keyword parameters read from `kwargs`: update_freq (1), batch_size (32), warm_up (0),
        number_updates (1), max_grad_norm (10), using_double_q (False), n_steps (1).
        """
        self.device = device if device is not None else DEVICE
        self.state_size = state_size if not isinstance(state_size, int) else (state_size, )
        self.action_size = action_size

        self.lr = float(kwargs.get('lr', lr))
        self.gamma = float(kwargs.get('gamma', gamma))
        self.tau = float(kwargs.get('tau', tau))

        self.update_freq = int(kwargs.get('update_freq', 1))
        self.batch_size = int(kwargs.get('batch_size', 32))
        self.warm_up = int(kwargs.get('warm_up', 0))
        self.number_updates = int(kwargs.get('number_updates', 1))
        self.max_grad_norm = float(kwargs.get('max_grad_norm', 10))

        self.iteration: int = 0
        self.buffer = PERBuffer(self.batch_size)
        self.using_double_q = bool(kwargs.get("using_double_q", False))

        self.n_steps = int(kwargs.get("n_steps", 1))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.state_transform = state_transform if state_transform is not None else lambda x: x
        self.reward_transform = reward_transform if reward_transform is not None else lambda x: x

        if network_fn:
            self.net = network_fn().to(self.device)
            self.target_net = network_fn().to(self.device)
        else:
            hidden_layers = kwargs.get('hidden_layers', hidden_layers)
            self.net = DuelingNet(
                self.state_size[0], self.action_size, hidden_layers=hidden_layers).to(self.device)
            self.target_net = DuelingNet(
                self.state_size[0], self.action_size, hidden_layers=hidden_layers).to(self.device)
        self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr)

        # Defined up-front so that logging before the first learning step doesn't fail.
        self.loss = float('inf')

    def step(self, state, action, reward, next_state, done) -> None:
        """Records a transition and, when due, runs the learning phase.

        The transition first goes through the N-step buffer; only once it reports a sample
        as available is that sample stored in the replay buffer. Learning runs
        `number_updates` times when the iteration is past `warm_up`, the replay buffer holds
        more than `batch_size` samples and the iteration is a multiple of `update_freq`.

        Parameters:
            state: S(t)
            action: A(t)
            reward: R(t)
            next_state: S(t+1)
            done: Whether the state is terminal.
        """
        self.iteration += 1
        state = self.state_transform(state)
        next_state = self.state_transform(next_state)
        reward = self.reward_transform(reward)

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state, action=[action], reward=[reward], done=[done], next_state=next_state)
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

    def act(self, state, eps: float = 0.) -> int:
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection

        Returns
        =======
            Index of the selected action.
        """
        # Epsilon-greedy action selection
        if np.random.random() < eps:
            return np.random.randint(self.action_size)

        state = self.state_transform(state)
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        action_values = self.net.act(state)
        # Cast so the return matches the annotation instead of leaking a NumPy integer.
        return int(np.argmax(action_values.cpu().data.numpy()))

    def learn(self, experiences) -> None:
        """Runs a single optimisation step on the local network and soft-updates the target.

        Parameters:
            experiences: Mapping with 'reward', 'done', 'state', 'next_state' and 'action'
                entries, plus 'index' when the buffer supports `priority_update`.
        """
        rewards = torch.tensor(experiences['reward'], dtype=torch.float32).to(self.device)
        dones = torch.tensor(experiences['done']).type(torch.int).to(self.device)
        states = torch.tensor(experiences['state'], dtype=torch.float32).to(self.device)
        next_states = torch.tensor(experiences['next_state'], dtype=torch.float32).to(self.device)
        actions = torch.tensor(experiences['action'], dtype=torch.long).to(self.device)

        with torch.no_grad():
            Q_targets_next = self.target_net(next_states).detach()
            if self.using_double_q:
                # Double Q: the local net selects the action, the target net evaluates it.
                _a = torch.argmax(self.net(next_states), dim=-1).unsqueeze(-1)
                max_Q_targets_next = Q_targets_next.gather(1, _a)
            else:
                max_Q_targets_next = Q_targets_next.max(1)[0].unsqueeze(1)
        # `(1 - dones)` removes the bootstrap term for terminal transitions.
        Q_targets = rewards + self.n_buffer.n_gammas[-1] * max_Q_targets_next * (1 - dones)
        Q_expected = self.net(states).gather(1, actions)
        loss = F.mse_loss(Q_expected, Q_targets)

        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.loss = loss.item()

        if hasattr(self.buffer, 'priority_update'):
            # Priority grows with the magnitude of the TD error, so that the samples the
            # network predicts worst are replayed more often.
            td_error = (Q_expected - Q_targets).detach()
            assert (~torch.isnan(td_error)).any()
            self.buffer.priority_update(experiences['index'], td_error.abs())

        # Update networks - sync local & target
        soft_update(self.target_net, self.net, self.tau)

    def describe_agent(self) -> Dict:
        """Returns agent's state dictionary."""
        return self.net.state_dict()

    def log_writer(self, writer, episode):
        """Writes the most recent loss to `writer` under the `loss/agent` tag."""
        writer.add_scalar("loss/agent", self.loss, episode)

    def save_state(self, path: str):
        """Saves the local and target networks' weights with `torch.save` under the provided path."""
        agent_state = dict(net=self.net.state_dict(), target_net=self.target_net.state_dict())
        torch.save(agent_state, path)

    def load_state(self, path: str):
        """Loads the local and target networks' weights from a file written by `save_state`."""
        agent_state = torch.load(path)
        self.net.load_state_dict(agent_state['net'])
        self.target_net.load_state_dict(agent_state['target_net'])