class TD3(AttributeSavingMixin, BatchAgent):
    """Twin Delayed Deep Deterministic Policy Gradients (TD3).

    See http://arxiv.org/abs/1802.09477

    Args:
        policy (Policy): Policy.
        q_func1 (Module): First Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        q_func2 (Module): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer setup with the policy
        q_func1_optimizer (Optimizer): Optimizer setup with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer setup with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        phi (callable): Feature extractor applied to observations
        soft_update_tau (float): Tau of soft target update.
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `pfrl.utils.batch_states.batch_states`
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
        policy_update_delay (int): Delay of policy updates. Policy is updated
            once in `policy_update_delay` times of Q-function updates.
        target_policy_smoothing_func (callable): Callable that takes a batch
            of actions as input and outputs a noisy version of it. It is
            used for target policy smoothing when computing target Q-values.
    """

    saved_attributes = (
        "policy",
        "q_func1",
        "q_func2",
        "target_policy",
        "target_q_func1",
        "target_q_func2",
        "policy_optimizer",
        "q_func1_optimizer",
        "q_func2_optimizer",
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        n_times_update=1,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        policy_update_delay=2,
        target_policy_smoothing_func=default_target_policy_smoothing_func,
    ):
        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.policy_update_delay = policy_update_delay
        self.target_policy_smoothing_func = target_policy_smoothing_func

        self.t = 0
        self.policy_n_updates = 0
        self.q_func_n_updates = 0
        self.last_state = None
        self.last_action = None

        # Target model
        self.target_policy = copy.deepcopy(
            self.policy).eval().requires_grad_(False)
        self.target_q_func1 = copy.deepcopy(
            self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(
            self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.policy_loss_record = collections.deque(maxlen=100)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.policy,
            dst=self.target_policy,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method="soft",
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""
        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(
                self.target_policy), pfrl.utils.evaluating(
                    self.target_q_func1), pfrl.utils.evaluating(
                        self.target_q_func2):
            next_actions = self.target_policy_smoothing_func(
                self.target_policy(batch_next_state).sample())
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)
            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal) * torch.flatten(next_q)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = F.mse_loss(target_q, predict_q1)
        loss2 = F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

        self.q_func_n_updates += 1

    def update_policy(self, batch):
        """Compute loss for actor."""
        batch_state = batch["state"]

        onpolicy_actions = self.policy(batch_state).rsample()
        q = self.q_func1((batch_state, onpolicy_actions))

        # Since we want to maximize Q, loss is negation of Q
        loss = -torch.mean(q)

        self.policy_loss_record.append(float(loss))
        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        self.policy_n_updates += 1

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""
        batch = batch_experiences(experiences, self.device, self.phi,
                                  self.gamma)
        self.update_q_func(batch)
        if self.q_func_n_updates % self.policy_update_delay == 0:
            self.update_policy(batch)
            self.sync_target_network()

    def batch_select_onpolicy_action(self, batch_obs):
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            batch_action = self.policy(batch_xs).sample().cpu().numpy()
        return list(batch_action)

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done,
                                      batch_reset)

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self.batch_select_onpolicy_action(batch_obs)

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.policy_n_updates == 0:
            batch_action = [
                self.burnin_action_func() for _ in range(len(batch_obs))
            ]
        else:
            batch_onpolicy_action = self.batch_select_onpolicy_action(
                batch_obs)
            batch_action = [
                self.explorer.select_action(
                    self.t, lambda: batch_onpolicy_action[i])
                for i in range(len(batch_onpolicy_action))
            ]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)
        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done,
                             batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q1", _mean_or_nan(self.q1_record)),
            ("average_q2", _mean_or_nan(self.q2_record)),
            ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
            ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
            ("average_policy_loss", _mean_or_nan(self.policy_loss_record)),
            ("policy_n_updates", self.policy_n_updates),
            ("q_func_n_updates", self.q_func_n_updates),
        ]
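# Example (not part of the library): `target_policy_smoothing_func` receives a
# batch of target-policy actions and must return a noisy version of it. A
# minimal sketch consistent with the TD3 paper's defaults (Gaussian noise with
# sigma=0.2, clipped to +/-0.5, actions assumed to lie in [-1, 1]); the bundled
# `default_target_policy_smoothing_func` may differ in details.
def clipped_noise_smoothing_example(batch_action, sigma=0.2, noise_clip=0.5,
                                    action_low=-1.0, action_high=1.0):
    """Add clipped Gaussian noise to a batch of target actions."""
    noise = torch.clamp(sigma * torch.randn_like(batch_action),
                        -noise_clip, noise_clip)
    return torch.clamp(batch_action + noise, action_low, action_high)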
class DQN(agent.AttributeSavingMixin, agent.BatchAgent): """Deep Q-Network algorithm. Args: q_function (StateQFunction): Q-function optimizer (Optimizer): Optimizer that is already setup replay_buffer (ReplayBuffer): Replay buffer gamma (float): Discount factor explorer (Explorer): Explorer that specifies an exploration strategy. gpu (int): GPU device id if not None nor negative. replay_start_size (int): if the replay buffer's size is less than replay_start_size, skip update minibatch_size (int): Minibatch size update_interval (int): Model update interval in step target_update_interval (int): Target model update interval in step clip_delta (bool): Clip delta if set True phi (callable): Feature extractor applied to observations target_update_method (str): 'hard' or 'soft'. soft_update_tau (float): Tau of soft target update. n_times_update (int): Number of repetition of update batch_accumulator (str): 'mean' or 'sum' episodic_update_len (int or None): Subsequences of this length are used for update if set int and episodic_update=True logger (Logger): Logger used batch_states (callable): method which makes a batch of observations. default is `pfrl.utils.batch_states.batch_states` recurrent (bool): If set to True, `model` is assumed to implement `pfrl.nn.Recurrent` and is updated in a recurrent manner. max_grad_norm (float or None): Maximum L2 norm of the gradient used for gradient clipping. If set to None, the gradient is not clipped. """ saved_attributes = ("model", "target_model", "optimizer") def __init__( self, q_function, optimizer, replay_buffer, gamma, explorer, gpu=None, replay_start_size=50000, minibatch_size=32, update_interval=1, target_update_interval=10000, clip_delta=True, phi=lambda x: x, target_update_method="hard", soft_update_tau=1e-2, n_times_update=1, batch_accumulator="mean", episodic_update_len=None, logger=getLogger(__name__), batch_states=batch_states, recurrent=False, max_grad_norm=None, ): self.rnd_reward = 0 self.ngu_reward = 0 self.model = q_function if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) self.model.to(self.device) else: self.device = torch.device("cpu") self.replay_buffer = replay_buffer self.optimizer = optimizer self.gamma = gamma self.explorer = explorer self.gpu = gpu self.target_update_interval = target_update_interval self.clip_delta = clip_delta self.phi = phi self.target_update_method = target_update_method self.soft_update_tau = soft_update_tau self.batch_accumulator = batch_accumulator assert batch_accumulator in ("mean", "sum") self.logger = logger self.batch_states = batch_states self.recurrent = recurrent if self.recurrent: update_func = self.update_from_episodes else: update_func = self.update self.replay_updater = ReplayUpdater( replay_buffer=replay_buffer, update_func=update_func, batchsize=minibatch_size, episodic_update=recurrent, episodic_update_len=episodic_update_len, n_times_update=n_times_update, replay_start_size=replay_start_size, update_interval=update_interval, ) self.minibatch_size = minibatch_size self.episodic_update_len = episodic_update_len self.replay_start_size = replay_start_size self.update_interval = update_interval self.max_grad_norm = max_grad_norm assert ( target_update_interval % update_interval == 0 ), "target_update_interval should be a multiple of update_interval" self.t = 0 self.optim_t = 0 # Compensate pytorch optim not having `t` self._cumulative_steps = 0 self.last_state = None self.last_action = None self.target_model = None 
self.sync_target_network() # Statistics self.q_record = collections.deque(maxlen=1000) self.loss_record = collections.deque(maxlen=100) # Recurrent states of the model self.train_recurrent_states = None self.train_prev_recurrent_states = None self.test_recurrent_states = None self.replay_buffer_lock = None # Error checking if (self.replay_buffer.capacity is not None and self.replay_buffer.capacity < self.replay_updater.replay_start_size): raise ValueError( "Replay start size cannot exceed replay buffer capacity.") def set_rnd_module(self, rnd_module): self.rnd_module = rnd_module self.rnd_reward = 1 def set_ngu_module(self, ngu_module): self.ngu_module = ngu_module self.ngu_reward = 1 @property def cumulative_steps(self): # cumulative_steps counts the overall steps during the training. return self._cumulative_steps def _setup_actor_learner_training(self, n_actors, actor_update_interval, update_counter): assert actor_update_interval > 0 self.actor_update_interval = actor_update_interval self.update_counter = update_counter # Make a copy on shared memory and share among actors and the poller shared_model = copy.deepcopy(self.model).cpu() shared_model.share_memory() # Pipes are used for infrequent communication learner_pipes, actor_pipes = list( zip(*[mp.Pipe() for _ in range(n_actors)])) return (shared_model, learner_pipes, actor_pipes) def sync_target_network(self): """Synchronize target network with current network.""" if self.target_model is None: self.target_model = copy.deepcopy(self.model) def flatten_parameters(mod): if isinstance(mod, torch.nn.RNNBase): mod.flatten_parameters() # RNNBase.flatten_parameters must be called again after deep-copy. # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506 # NOQA self.target_model.apply(flatten_parameters) # set target n/w to evaluate only. self.target_model.eval() else: synchronize_parameters( src=self.model, dst=self.target_model, method=self.target_update_method, tau=self.soft_update_tau, ) def update(self, experiences, errors_out=None): """Update the model from experiences Args: experiences (list): List of lists of dicts. For DQN, each dict must contains: - state (object): State - action (object): Action - reward (float): Reward - is_state_terminal (bool): True iff next state is terminal - next_state (object): Next state - weight (float, optional): Weight coefficient. It can be used for importance sampling. errors_out (list or None): If set to a list, then TD-errors computed from the given experiences are appended to the list. 
Returns: None """ has_weight = "weight" in experiences[0][0] exp_batch = batch_experiences( experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) if self.rnd_reward: self.rnd_module.train(exp_batch) # if self.ngu_reward: # self.ngu_module.train(exp_batch) if has_weight: exp_batch["weights"] = torch.tensor( # pylint: disable=not-callable [elem[0]["weight"] for elem in experiences], device=self.device, dtype=torch.float32, ) if errors_out is None: errors_out = [] loss = self._compute_loss(exp_batch, errors_out=errors_out) if has_weight: self.replay_buffer.update_errors(errors_out) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() self.optim_t += 1 def update_from_episodes(self, episodes, errors_out=None): assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer" episodes = sorted(episodes, key=len, reverse=True) exp_batch = batch_recurrent_experiences( episodes, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) loss = self._compute_loss(exp_batch, errors_out=None) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() self.optim_t += 1 def _compute_target_values(self, exp_batch): batch_next_state = exp_batch["next_state"] if self.recurrent: target_next_qout, _ = pack_and_forward( self.target_model, batch_next_state, exp_batch["next_recurrent_state"], ) else: target_next_qout = self.target_model(batch_next_state) next_q_max = target_next_qout.max batch_terminal = exp_batch["is_state_terminal"] discount = exp_batch["discount"] batch_rewards = exp_batch["reward"] return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max def _compute_y_and_t(self, exp_batch): batch_size = exp_batch["reward"].shape[0] # Compute Q-values for current states batch_state = exp_batch["state"] if self.recurrent: qout, _ = pack_and_forward(self.model, batch_state, exp_batch["recurrent_state"]) else: qout = self.model(batch_state) batch_actions = exp_batch["action"] batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1)) with torch.no_grad(): batch_q_target = torch.reshape( self._compute_target_values(exp_batch), (batch_size, 1)) return batch_q, batch_q_target def _compute_loss(self, exp_batch, errors_out=None): """Compute the Q-learning loss for a batch of experiences Args: exp_batch (dict): A dict of batched arrays of transitions Returns: Computed loss from the minibatch of experiences """ y, t = self._compute_y_and_t(exp_batch) self.q_record.extend(y.detach().cpu().numpy().ravel()) if errors_out is not None: del errors_out[:] delta = torch.abs(y - t) if delta.ndim == 2: delta = torch.sum(delta, dim=1) delta = delta.detach().cpu().numpy() for e in delta: errors_out.append(e) if "weights" in exp_batch: return compute_weighted_value_loss( y, t, exp_batch["weights"], clip_delta=self.clip_delta, batch_accumulator=self.batch_accumulator, ) else: return compute_value_loss( y, t, clip_delta=self.clip_delta, batch_accumulator=self.batch_accumulator, ) def _evaluate_model_and_update_recurrent_states(self, batch_obs): batch_xs = self.batch_states(batch_obs, self.device, self.phi) if self.recurrent: if self.training: 
self.train_prev_recurrent_states = self.train_recurrent_states batch_av, self.train_recurrent_states = one_step_forward( self.model, batch_xs, self.train_recurrent_states) else: batch_av, self.test_recurrent_states = one_step_forward( self.model, batch_xs, self.test_recurrent_states) else: batch_av = self.model(batch_xs) return batch_av def batch_act(self, batch_obs): with torch.no_grad(), evaluating(self.model): batch_av = self._evaluate_model_and_update_recurrent_states( batch_obs) batch_argmax = batch_av.greedy_actions.cpu().numpy() if self.training: batch_action = [ self.explorer.select_action( self.t, lambda: batch_argmax[i], action_value=batch_av[i:i + 1], ) for i in range(len(batch_obs)) ] self.batch_last_obs = list(batch_obs) self.batch_last_action = list(batch_action) else: batch_action = batch_argmax return batch_action def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset): for i in range(len(batch_obs)): self.t += 1 self._cumulative_steps += 1 # Update the target network if self.t % self.target_update_interval == 0: self.sync_target_network() if self.batch_last_obs[i] is not None: assert self.batch_last_action[i] is not None # Add a transition to the replay buffer transition = { "state": self.batch_last_obs[i], "action": self.batch_last_action[i], "reward": batch_reward[i], "next_state": batch_obs[i], "next_action": None, "is_state_terminal": batch_done[i], } if self.recurrent: transition["recurrent_state"] = recurrent_state_as_numpy( get_recurrent_state_at( self.train_prev_recurrent_states, i, detach=True)) transition[ "next_recurrent_state"] = recurrent_state_as_numpy( get_recurrent_state_at(self.train_recurrent_states, i, detach=True)) self.replay_buffer.append(env_id=i, **transition) if batch_reset[i] or batch_done[i]: self.batch_last_obs[i] = None self.batch_last_action[i] = None self.replay_buffer.stop_current_episode(env_id=i) self.replay_updater.update_if_necessary(self.t) if self.recurrent: # Reset recurrent states when episodes end self.train_prev_recurrent_states = None self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end( # NOQA batch_done=batch_done, batch_reset=batch_reset, recurrent_states=self.train_recurrent_states, ) def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset): if self.recurrent: # Reset recurrent states when episodes end self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end( # NOQA batch_done=batch_done, batch_reset=batch_reset, recurrent_states=self.test_recurrent_states, ) def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): if self.training: return self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) else: return self._batch_observe_eval(batch_obs, batch_reward, batch_done, batch_reset) def _can_start_replay(self): if len(self.replay_buffer) < self.replay_start_size: return False if self.recurrent and self.replay_buffer.n_episodes < self.minibatch_size: return False return True def _poll_pipe(self, actor_idx, pipe, replay_buffer_lock, exception_event): if pipe.closed: return try: while pipe.poll() and not exception_event.is_set(): cmd, data = pipe.recv() if cmd == "get_statistics": assert data is None with replay_buffer_lock: stats = self.get_statistics() pipe.send(stats) elif cmd == "load": self.load(data) pipe.send(None) elif cmd == "save": self.save(data) pipe.send(None) elif cmd == "transition": with replay_buffer_lock: if "env_id" not in data: data["env_id"] = actor_idx 
self.replay_buffer.append(**data) self._cumulative_steps += 1 elif cmd == "stop_episode": idx = actor_idx if data is None else data with replay_buffer_lock: self.replay_buffer.stop_current_episode(env_id=idx) stats = self.get_statistics() pipe.send(stats) else: raise RuntimeError( "Unknown command from actor: {}".format(cmd)) except EOFError: pipe.close() except Exception: self.logger.exception("Poller loop failed. Exiting") exception_event.set() def _learner_loop( self, shared_model, pipes, replay_buffer_lock, stop_event, exception_event, n_updates=None, ): try: update_counter = 0 # To stop this loop, call stop_event.set() while not stop_event.is_set(): # Update model if possible if not self._can_start_replay(): continue if n_updates is not None: assert self.optim_t <= n_updates if self.optim_t == n_updates: stop_event.set() break if self.recurrent: with replay_buffer_lock: episodes = self.replay_buffer.sample_episodes( self.minibatch_size, self.episodic_update_len) self.update_from_episodes(episodes) else: with replay_buffer_lock: transitions = self.replay_buffer.sample( self.minibatch_size) self.update(transitions) # Update the shared model. This can be expensive if GPU is used # since this is a DtoH copy, so it is updated only at regular # intervals. update_counter += 1 if update_counter % self.actor_update_interval == 0: with self.update_counter.get_lock(): self.update_counter.value += 1 shared_model.load_state_dict(self.model.state_dict()) # To keep the ratio of target updates to model updates, # here we calculate back the effective current timestep # from update_interval and number of updates so far. effective_timestep = self.optim_t * self.update_interval # We can safely assign self.t since in the learner # it isn't updated by any other method self.t = effective_timestep if effective_timestep % self.target_update_interval == 0: self.sync_target_network() except Exception: self.logger.exception("Learner loop failed. 
Exiting") exception_event.set() def _poller_loop(self, shared_model, pipes, replay_buffer_lock, stop_event, exception_event): # To stop this loop, call stop_event.set() while not stop_event.is_set() and not exception_event.is_set(): time.sleep(1e-6) # Poll actors for messages for i, pipe in enumerate(pipes): self._poll_pipe(i, pipe, replay_buffer_lock, exception_event) def setup_actor_learner_training(self, n_actors, update_counter=None, n_updates=None, actor_update_interval=8): if update_counter is None: update_counter = mp.Value(ctypes.c_ulong) (shared_model, learner_pipes, actor_pipes) = self._setup_actor_learner_training( n_actors, actor_update_interval, update_counter) exception_event = mp.Event() def make_actor(i): return pfrl.agents.StateQFunctionActor( pipe=actor_pipes[i], model=shared_model, explorer=self.explorer, phi=self.phi, batch_states=self.batch_states, logger=self.logger, recurrent=self.recurrent, ) replay_buffer_lock = mp.Lock() self.replay_buffer_lock = replay_buffer_lock poller_stop_event = mp.Event() poller = pfrl.utils.StoppableThread( target=self._poller_loop, kwargs=dict( shared_model=shared_model, pipes=learner_pipes, replay_buffer_lock=replay_buffer_lock, stop_event=poller_stop_event, exception_event=exception_event, ), stop_event=poller_stop_event, ) learner_stop_event = mp.Event() learner = pfrl.utils.StoppableThread( target=self._learner_loop, kwargs=dict( shared_model=shared_model, pipes=learner_pipes, replay_buffer_lock=replay_buffer_lock, stop_event=learner_stop_event, n_updates=n_updates, exception_event=exception_event, ), stop_event=learner_stop_event, ) return make_actor, learner, poller, exception_event def stop_episode(self): if self.recurrent: self.test_recurrent_states = None def get_statistics(self): return [ ("average_q", _mean_or_nan(self.q_record)), ("average_loss", _mean_or_nan(self.loss_record)), ("cumulative_steps", self.cumulative_steps), ("n_updates", self.optim_t), ("rlen", len(self.replay_buffer)), ]
class SoftActorCritic(AttributeSavingMixin, BatchAgent): """Soft Actor-Critic (SAC). See https://arxiv.org/abs/1812.05905 Args: policy (Policy): Policy. q_func1 (Module): First Q-function that takes state-action pairs as input and outputs predicted Q-values. q_func2 (Module): Second Q-function that takes state-action pairs as input and outputs predicted Q-values. policy_optimizer (Optimizer): Optimizer setup with the policy q_func1_optimizer (Optimizer): Optimizer setup with the first Q-function. q_func2_optimizer (Optimizer): Optimizer setup with the second Q-function. replay_buffer (ReplayBuffer): Replay buffer gamma (float): Discount factor gpu (int): GPU device id if not None nor negative. replay_start_size (int): if the replay buffer's size is less than replay_start_size, skip update minibatch_size (int): Minibatch size update_interval (int): Model update interval in step phi (callable): Feature extractor applied to observations soft_update_tau (float): Tau of soft target update. logger (Logger): Logger used batch_states (callable): method which makes a batch of observations. default is `pfrl.utils.batch_states.batch_states` burnin_action_func (callable or None): If not None, this callable object is used to select actions before the model is updated one or more times during training. initial_temperature (float): Initial temperature value. If `entropy_target` is set to None, the temperature is fixed to it. entropy_target (float or None): If set to a float, the temperature is adjusted during training to match the policy's entropy to it. temperature_optimizer_lr (float): Learning rate of the temperature optimizer. If set to None, Adam with default hyperparameters is used. act_deterministically (bool): If set to True, choose most probable actions in the act method instead of sampling from distributions. 
""" saved_attributes = ( "policy", "q_func1", "q_func2", "target_q_func1", "target_q_func2", "policy_optimizer", "q_func1_optimizer", "q_func2_optimizer", "temperature_holder", "temperature_optimizer", ) def __init__( self, policy, q_func1, q_func2, policy_optimizer, q_func1_optimizer, q_func2_optimizer, replay_buffer, gamma, gpu=None, replay_start_size=10000, minibatch_size=100, update_interval=1, phi=lambda x: x, soft_update_tau=5e-3, max_grad_norm=None, logger=getLogger(__name__), batch_states=batch_states, burnin_action_func=None, initial_temperature=1.0, entropy_target=None, temperature_optimizer_lr=None, act_deterministically=True, ): self.policy = policy self.q_func1 = q_func1 self.q_func2 = q_func2 if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) self.policy.to(self.device) self.q_func1.to(self.device) self.q_func2.to(self.device) else: self.device = torch.device("cpu") self.replay_buffer = replay_buffer self.gamma = gamma self.gpu = gpu self.phi = phi self.soft_update_tau = soft_update_tau self.logger = logger self.policy_optimizer = policy_optimizer self.q_func1_optimizer = q_func1_optimizer self.q_func2_optimizer = q_func2_optimizer self.replay_updater = ReplayUpdater( replay_buffer=replay_buffer, update_func=self.update, batchsize=minibatch_size, n_times_update=1, replay_start_size=replay_start_size, update_interval=update_interval, episodic_update=False, ) self.max_grad_norm = max_grad_norm self.batch_states = batch_states self.burnin_action_func = burnin_action_func self.initial_temperature = initial_temperature self.entropy_target = entropy_target if self.entropy_target is not None: self.temperature_holder = TemperatureHolder( initial_log_temperature=np.log(initial_temperature) ) if temperature_optimizer_lr is not None: self.temperature_optimizer = torch.optim.Adam( self.temperature_holder.parameters(), lr=temperature_optimizer_lr ) else: self.temperature_optimizer = torch.optim.Adam( self.temperature_holder.parameters() ) if gpu is not None and gpu >= 0: self.temperature_holder.to(self.device) else: self.temperature_holder = None self.temperature_optimizer = None self.act_deterministically = act_deterministically self.t = 0 # Target model self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False) self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False) # Statistics self.q1_record = collections.deque(maxlen=1000) self.q2_record = collections.deque(maxlen=1000) self.entropy_record = collections.deque(maxlen=1000) self.q_func1_loss_record = collections.deque(maxlen=100) self.q_func2_loss_record = collections.deque(maxlen=100) self.n_policy_updates = 0 @property def temperature(self): if self.entropy_target is None: return self.initial_temperature else: with torch.no_grad(): return float(self.temperature_holder()) def sync_target_network(self): """Synchronize target network with current network.""" synchronize_parameters( src=self.q_func1, dst=self.target_q_func1, method="soft", tau=self.soft_update_tau, ) synchronize_parameters( src=self.q_func2, dst=self.target_q_func2, method="soft", tau=self.soft_update_tau, ) def update_q_func(self, batch): """Compute loss for a given Q-function.""" batch_next_state = batch["next_state"] batch_rewards = batch["reward"] batch_terminal = batch["is_state_terminal"] batch_state = batch["state"] batch_actions = batch["action"] batch_discount = batch["discount"] with torch.no_grad(), pfrl.utils.evaluating(self.policy), 
pfrl.utils.evaluating( self.target_q_func1 ), pfrl.utils.evaluating(self.target_q_func2): next_action_distrib = self.policy(batch_next_state) next_actions = next_action_distrib.sample() next_log_prob = next_action_distrib.log_prob(next_actions) next_q1 = self.target_q_func1((batch_next_state, next_actions)) next_q2 = self.target_q_func2((batch_next_state, next_actions)) next_q = torch.min(next_q1, next_q2) entropy_term = self.temperature * next_log_prob[..., None] assert next_q.shape == entropy_term.shape target_q = batch_rewards + batch_discount * ( 1.0 - batch_terminal ) * torch.flatten(next_q - entropy_term) predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions))) predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions))) loss1 = 0.5 * F.mse_loss(target_q, predict_q1) loss2 = 0.5 * F.mse_loss(target_q, predict_q2) # Update stats self.q1_record.extend(predict_q1.detach().cpu().numpy()) self.q2_record.extend(predict_q2.detach().cpu().numpy()) self.q_func1_loss_record.append(float(loss1)) self.q_func2_loss_record.append(float(loss2)) self.q_func1_optimizer.zero_grad() loss1.backward() if self.max_grad_norm is not None: clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm) self.q_func1_optimizer.step() self.q_func2_optimizer.zero_grad() loss2.backward() if self.max_grad_norm is not None: clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm) self.q_func2_optimizer.step() def update_temperature(self, log_prob): assert not log_prob.requires_grad loss = -torch.mean(self.temperature_holder() * (log_prob + self.entropy_target)) self.temperature_optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: clip_l2_grad_norm_(self.temperature_holder.parameters(), self.max_grad_norm) self.temperature_optimizer.step() def update_policy_and_temperature(self, batch): """Compute loss for actor.""" batch_state = batch["state"] action_distrib = self.policy(batch_state) actions = action_distrib.rsample() log_prob = action_distrib.log_prob(actions) q1 = self.q_func1((batch_state, actions)) q2 = self.q_func2((batch_state, actions)) q = torch.min(q1, q2) entropy_term = self.temperature * log_prob[..., None] assert q.shape == entropy_term.shape loss = torch.mean(entropy_term - q) self.policy_optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy_optimizer.step() self.n_policy_updates += 1 if self.entropy_target is not None: self.update_temperature(log_prob.detach()) # Record entropy with torch.no_grad(): try: self.entropy_record.extend( action_distrib.entropy().detach().cpu().numpy() ) except NotImplementedError: # Record - log p(x) instead self.entropy_record.extend(-log_prob.detach().cpu().numpy()) def update(self, experiences, errors_out=None): """Update the model from experiences""" batch = batch_experiences(experiences, self.device, self.phi, self.gamma) self.update_q_func(batch) self.update_policy_and_temperature(batch) self.sync_target_network() def batch_select_greedy_action(self, batch_obs, deterministic=False): with torch.no_grad(), pfrl.utils.evaluating(self.policy): batch_xs = self.batch_states(batch_obs, self.device, self.phi) policy_out = self.policy(batch_xs) if deterministic: batch_action = mode_of_distribution(policy_out).cpu().numpy() else: batch_action = policy_out.sample().cpu().numpy() return batch_action def batch_act(self, batch_obs): if self.training: return self._batch_act_train(batch_obs) else: return 
self._batch_act_eval(batch_obs) def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): if self.training: self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) def _batch_act_eval(self, batch_obs): assert not self.training return self.batch_select_greedy_action( batch_obs, deterministic=self.act_deterministically ) def _batch_act_train(self, batch_obs): assert self.training if self.burnin_action_func is not None and self.n_policy_updates == 0: batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))] else: batch_action = self.batch_select_greedy_action(batch_obs) self.batch_last_obs = list(batch_obs) self.batch_last_action = list(batch_action) return batch_action def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset): assert self.training for i in range(len(batch_obs)): self.t += 1 if self.batch_last_obs[i] is not None: assert self.batch_last_action[i] is not None # Add a transition to the replay buffer self.replay_buffer.append( state=self.batch_last_obs[i], action=self.batch_last_action[i], reward=batch_reward[i], next_state=batch_obs[i], next_action=None, is_state_terminal=batch_done[i], env_id=i, ) if batch_reset[i] or batch_done[i]: self.batch_last_obs[i] = None self.batch_last_action[i] = None self.replay_buffer.stop_current_episode(env_id=i) self.replay_updater.update_if_necessary(self.t) def get_statistics(self): return [ ("average_q1", _mean_or_nan(self.q1_record)), ("average_q2", _mean_or_nan(self.q2_record)), ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)), ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)), ("n_updates", self.n_policy_updates), ("average_entropy", _mean_or_nan(self.entropy_record)), ("temperature", self.temperature), ]
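# Example (not part of the library): a self-contained sketch of one temperature
# adjustment step as performed by `update_temperature` above.
# `LogTemperatureExample` is a hypothetical stand-in for the library's
# `TemperatureHolder`: it stores log(alpha) so that alpha stays positive. With
# the loss -E[alpha * (log_prob + entropy_target)], alpha grows when the policy
# entropy (-log_prob) falls below the target and shrinks otherwise.
class LogTemperatureExample(torch.nn.Module):
    def __init__(self, initial_log_temperature=0.0):
        super().__init__()
        self.log_alpha = torch.nn.Parameter(
            torch.tensor(initial_log_temperature, dtype=torch.float32))

    def forward(self):
        return torch.exp(self.log_alpha)


def temperature_update_example():
    alpha_holder = LogTemperatureExample()
    optimizer = torch.optim.Adam(alpha_holder.parameters())
    # Detached log pi(a|s) of sampled actions and a target entropy,
    # e.g. -|A| for a 3-dimensional action space.
    log_prob = torch.tensor([-0.5, -1.2, -0.8])
    entropy_target = -3.0
    loss = -(alpha_holder() * (log_prob + entropy_target)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(alpha_holder())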
class EVA(agent.AttributeSavingMixin, agent.BatchAgent): """Ephemeral Value Adjusments Args: q_function (StateQFunction): Q-function optimizer (Optimizer): Optimizer that is already setup replay_buffer (ReplayBuffer): Replay buffer gamma (float): Discount factor explorer (Explorer): Explorer that specifies an exploration strategy. gpu (int): GPU device id if not None nor negative. replay_start_size (int): if the replay buffer's size is less than replay_start_size, skip update minibatch_size (int): Minibatch size update_interval (int): Model update interval in step target_update_interval (int): Target model update interval in step clip_delta (bool): Clip delta if set True phi (callable): Feature extractor applied to observations target_update_method (str): 'hard' or 'soft'. soft_update_tau (float): Tau of soft target update. n_times_update (int): Number of repetition of update batch_accumulator (str): 'mean' or 'sum' episodic_update_len (int or None): Subsequences of this length are used for update if set int and episodic_update=True logger (Logger): Logger used batch_states (callable): method which makes a batch of observations. default is `pfrl.utils.batch_states.batch_states` recurrent (bool): If set to True, `model` is assumed to implement `pfrl.nn.Recurrent` and is updated in a recurrent manner. max_grad_norm (float or None): Maximum L2 norm of the gradient used for gradient clipping. If set to None, the gradient is not clipped. """ saved_attributes = ("model", "target_model", "optimizer") def __init__( self, q_function: QNetworkWithValuebuffer, # torch.nn.Module, optimizer: torch.optim. Optimizer, # type: ignore # somehow mypy complains replay_buffer: EVAReplayBuffer, gamma: float, explorer: Explorer, gpu: Optional[int] = None, replay_start_size: int = 50000, minibatch_size: int = 32, update_interval: int = 1, target_update_interval: int = 10000, clip_delta: bool = True, phi: Callable[[Any], Any] = lambda x: x, target_update_method: str = "hard", soft_update_tau: float = 1e-2, n_times_update: int = 1, batch_accumulator: str = "mean", episodic_update_len: Optional[int] = None, interval_tcp=20, n_trj_step=50, use_eva=True, # If False, This Agent become DQN. 
logger: Logger = getLogger(__name__), batch_states: Callable[ [Sequence[Any], torch.device, Callable[[Any], Any]], Any] = batch_states, recurrent: bool = False, max_grad_norm: Optional[float] = None, ): self.model = q_function if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) self.model.to(self.device) else: self.device = torch.device("cpu") self.replay_buffer = replay_buffer self.optimizer = optimizer self.gamma = gamma self.explorer = explorer self.gpu = gpu self.target_update_interval = target_update_interval self.clip_delta = clip_delta self.phi = phi self.target_update_method = target_update_method self.soft_update_tau = soft_update_tau self.batch_accumulator = batch_accumulator assert batch_accumulator in ("mean", "sum") self.logger = logger self.batch_states = batch_states # self.recurrent = recurrent self.recurrent = False self.n_actions = self.model.n_actions self.value_buffer = self.model.v_buffer self.interval_tcp = interval_tcp self.n_trj_step = n_trj_step self.use_eva = use_eva update_func: Callable[..., None] if self.recurrent: assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer) update_func = self.update_from_episodes else: update_func = self.update self.replay_updater = ReplayUpdater( replay_buffer=replay_buffer, update_func=update_func, batchsize=minibatch_size, episodic_update=recurrent, episodic_update_len=episodic_update_len, n_times_update=n_times_update, replay_start_size=replay_start_size, update_interval=update_interval, ) self.minibatch_size = minibatch_size self.episodic_update_len = episodic_update_len self.replay_start_size = replay_start_size self.update_interval = update_interval self.max_grad_norm = max_grad_norm assert ( target_update_interval % update_interval == 0 ), "target_update_interval should be a multiple of update_interval" self.t = 0 self.eval_t = 0 self.optim_t = 0 # Compensate pytorch optim not having `t` self._cumulative_steps = 0 self.target_model = make_target_model_as_copy(self.model.q_function) # Statistics self.q_record: collections.deque = collections.deque(maxlen=1000) self.loss_record: collections.deque = collections.deque(maxlen=100) # Recurrent states of the model self.train_recurrent_states: Any = None self.train_prev_recurrent_states: Any = None self.test_recurrent_states: Any = None # Error checking if (self.replay_buffer.capacity is not None and self.replay_buffer.capacity < self.replay_updater.replay_start_size): raise ValueError( "Replay start size cannot exceed replay buffer capacity.") @property def cumulative_steps(self) -> int: # cumulative_steps counts the overall steps during the training. 
return self._cumulative_steps def _setup_actor_learner_training( self, n_actors: int, actor_update_interval: int, update_counter: Any, ) -> Tuple[torch.nn.Module, Sequence[mp.connection.Connection], Sequence[mp.connection.Connection], ]: assert actor_update_interval > 0 self.actor_update_interval = actor_update_interval self.update_counter = update_counter # Make a copy on shared memory and share among actors and the poller shared_model = copy.deepcopy(self.model).cpu() shared_model.share_memory() # Pipes are used for infrequent communication learner_pipes, actor_pipes = list( zip(*[mp.Pipe() for _ in range(n_actors)])) return (shared_model, learner_pipes, actor_pipes) def sync_target_network(self) -> None: """Synchronize target network with current network.""" synchronize_parameters( src=self.model.q_function, dst=self.target_model, method=self.target_update_method, tau=self.soft_update_tau, ) def update(self, experiences: List[List[Dict[str, Any]]], errors_out: Optional[list] = None) -> None: """Update the model from experiences Args: experiences (list): List of lists of dicts. For DQN, each dict must contains: - state (object): State - action (object): Action - reward (float): Reward - is_state_terminal (bool): True iff next state is terminal - next_state (object): Next state - weight (float, optional): Weight coefficient. It can be used for importance sampling. errors_out (list or None): If set to a list, then TD-errors computed from the given experiences are appended to the list. Returns: None """ has_weight = "weight" in experiences[0][0] exp_batch = batch_experiences( experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) if has_weight: exp_batch["weights"] = torch.tensor( [elem[0]["weight"] for elem in experiences], device=self.device, dtype=torch.float32, ) if errors_out is None: errors_out = [] loss = self._compute_loss(exp_batch, errors_out=errors_out) if has_weight: assert isinstance(self.replay_buffer, PrioritizedReplayBuffer) self.replay_buffer.update_errors(errors_out) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() self.optim_t += 1 def update_from_episodes(self, episodes: List[List[Dict[str, Any]]], errors_out: Optional[list] = None) -> None: assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer" episodes = sorted(episodes, key=len, reverse=True) exp_batch = batch_recurrent_experiences( episodes, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) loss = self._compute_loss(exp_batch, errors_out=None) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() self.optim_t += 1 def _compute_target_values(self, exp_batch: Dict[str, Any]) -> torch.Tensor: batch_next_state = exp_batch["next_state"] if self.recurrent: target_next_qout, _ = pack_and_forward( self.target_model, batch_next_state, exp_batch["next_recurrent_state"], ) else: target_next_qout, _ = self.target_model(batch_next_state) next_q_max = target_next_qout.max batch_rewards = exp_batch["reward"] batch_terminal = exp_batch["is_state_terminal"] discount = exp_batch["discount"] return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max def 
_compute_y_and_t( self, exp_batch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = exp_batch["reward"].shape[0] # Compute Q-values for current states batch_state = exp_batch["state"] if self.recurrent: qout, _ = pack_and_forward(self.model, batch_state, exp_batch["recurrent_state"]) else: qout, _ = self.model(batch_state) batch_actions = exp_batch["action"] batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1)) with torch.no_grad(): batch_q_target = torch.reshape( self._compute_target_values(exp_batch), (batch_size, 1)) return batch_q, batch_q_target def _compute_loss(self, exp_batch: Dict[str, Any], errors_out: Optional[list] = None) -> torch.Tensor: """Compute the Q-learning loss for a batch of experiences Args: exp_batch (dict): A dict of batched arrays of transitions Returns: Computed loss from the minibatch of experiences """ y, t = self._compute_y_and_t(exp_batch) self.q_record.extend(y.detach().cpu().numpy().ravel()) if errors_out is not None: del errors_out[:] delta = torch.abs(y - t) if delta.ndim == 2: delta = torch.sum(delta, dim=1) delta = delta.detach().cpu().numpy() for e in delta: errors_out.append(e) if "weights" in exp_batch: return compute_weighted_value_loss( y, t, exp_batch["weights"], clip_delta=self.clip_delta, batch_accumulator=self.batch_accumulator, ) else: return compute_value_loss( y, t, clip_delta=self.clip_delta, batch_accumulator=self.batch_accumulator, ) def _evaluate_model_and_update_recurrent_states(self, batch_obs: Sequence[Any]): batch_xs = self.batch_states(batch_obs, self.device, self.phi) batch_h = None if self.recurrent: if self.training: self.train_prev_recurrent_states = self.train_recurrent_states batch_av, self.train_recurrent_states = one_step_forward( self.model, batch_xs, self.train_recurrent_states) else: batch_av, self.test_recurrent_states = one_step_forward( self.model, batch_xs, self.test_recurrent_states) else: batch_av, batch_h = self.model(batch_xs, self.use_eva) return batch_av, batch_h def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]: with torch.no_grad(), evaluating(self.model): batch_av, self.batch_h = self._evaluate_model_and_update_recurrent_states( batch_obs) batch_argmax = batch_av.greedy_actions.detach().cpu().numpy() if self.training: batch_action = [ self.explorer.select_action( self.t, lambda: batch_argmax[i], action_value=batch_av[i:i + 1], ) for i in range(len(batch_obs)) ] self.batch_last_obs = list(batch_obs) self.batch_last_action = list(batch_action) else: batch_action = batch_argmax return batch_action def _batch_observe_train( self, batch_obs: Sequence[Any], batch_reward: Sequence[float], batch_done: Sequence[bool], batch_reset: Sequence[bool], ) -> None: for i in range(len(batch_obs)): self.t += 1 self._cumulative_steps += 1 # Update the target network if self.t % self.target_update_interval == 0: self.sync_target_network() if self.batch_last_obs[i] is not None: assert self.batch_last_action[i] is not None # Add a transition to the replay buffer transition = { "state": self.batch_last_obs[i], "action": self.batch_last_action[i], "reward": batch_reward[i], "feature": self.batch_h[i], "next_state": batch_obs[i], "next_action": None, "is_state_terminal": batch_done[i], } if self.recurrent: transition["recurrent_state"] = recurrent_state_as_numpy( get_recurrent_state_at( self.train_prev_recurrent_states, i, detach=True)) transition[ "next_recurrent_state"] = recurrent_state_as_numpy( get_recurrent_state_at(self.train_recurrent_states, i, detach=True)) 
self.replay_buffer.append(env_id=i, **transition) self._backup_if_necessary(self.t, self.batch_h[i]) if batch_reset[i] or batch_done[i]: self.batch_last_obs[i] = None self.batch_last_action[i] = None self.replay_buffer.stop_current_episode(env_id=i) self.replay_updater.update_if_necessary(self.t) if self.recurrent: # Reset recurrent states when episodes end self.train_prev_recurrent_states = None self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end( # NOQA batch_done=batch_done, batch_reset=batch_reset, recurrent_states=self.train_recurrent_states, ) def _batch_observe_eval( self, batch_obs: Sequence[Any], batch_reward: Sequence[float], batch_done: Sequence[bool], batch_reset: Sequence[bool], ) -> None: for i in range(len(batch_obs)): self._backup_if_necessary(self.eval_t, self.batch_h[i]) if batch_reset[i] or batch_done[i]: self.eval_t = 0 else: self.eval_t += 1 if self.recurrent: # Reset recurrent states when episodes end self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end( # NOQA batch_done=batch_done, batch_reset=batch_reset, recurrent_states=self.test_recurrent_states, ) def batch_observe( self, batch_obs: Sequence[Any], batch_reward: Sequence[float], batch_done: Sequence[bool], batch_reset: Sequence[bool], ) -> None: if self.training: return self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) else: return self._batch_observe_eval(batch_obs, batch_reward, batch_done, batch_reset) def _backup_if_necessary(self, t, feature): if (t % self.interval_tcp == 0 and len(self.replay_buffer) >= self.replay_buffer.capacity and self.use_eva): trajectory_list = self.replay_buffer.lookup( feature, self.n_trj_step) batch_trj = [ batch_trajectory(trajectory, self.device, self.phi, batch_states=batch_states) for trajectory in trajectory_list ] q_np_arr = self._trajectory_centric_planning(batch_trj) batch_feature = [ elem for trj in batch_trj for elem in trj['feature'] ] batch_feature = torch.tensor(np.asarray(batch_feature), dtype=torch.float32) self.value_buffer.store(batch_feature, q_np_arr) def _trajectory_centric_planning(self, trajectories): state_shape = tuple( trajectories[0]["state"].shape)[1:] # torch.Size -> tuple # Aligning Shapes for Parallel Processing with GPUs # If Atari, it will be (0, 4, 84, 84) batch_states = torch.empty((0, ) + state_shape, dtype=torch.float32) for trajectory in trajectories: bs = torch.empty((self.n_trj_step, ) + state_shape, dtype=torch.float32) bs[:len(trajectory["state"])] = trajectory["state"] # numpy.vstack batch_states = torch.cat((batch_states, bs), dim=0) batch_states = batch_states.to(self.device) with torch.no_grad(), evaluating(self.model): batch_q, _ = self.model(batch_states) q_theta_arr = batch_q.q_values.cpu() q_theta_arr = q_theta_arr.reshape( (len(trajectories), self.n_trj_step, self.n_actions)) q_np_arr = torch.empty((0, self.n_actions), dtype=torch.float32) for q_np, trajectory in zip(q_theta_arr, trajectories): # batch_state = trajectory['state'] batch_action = trajectory['action'] batch_reward = trajectory['reward'] q_np = q_np[:len(batch_action)] for t in range(len(batch_action) - 2, -1, -1): # t:= T-2, 0 V_np = torch.max(q_np[t + 1]) # V_NP(s_t+1) := max_a Q(s_t+1, a) q_np[t, batch_action[t]] = batch_reward[t] + self.gamma * V_np q_np_arr = torch.cat((q_np_arr, q_np.reshape(-1, self.n_actions)), dim=0) return q_np_arr.to(self.device) def _can_start_replay(self) -> bool: if len(self.replay_buffer) < self.replay_start_size: return False if self.recurrent: assert 
isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer) if self.replay_buffer.n_episodes < self.minibatch_size: return False return True def _poll_pipe( self, actor_idx: int, pipe: mp.connection.Connection, replay_buffer_lock: mp.synchronize.Lock, exception_event: mp.synchronize.Event, ) -> None: if pipe.closed: return try: while pipe.poll() and not exception_event.is_set(): cmd, data = pipe.recv() if cmd == "get_statistics": assert data is None with replay_buffer_lock: stats = self.get_statistics() pipe.send(stats) elif cmd == "load": self.load(data) pipe.send(None) elif cmd == "save": self.save(data) pipe.send(None) elif cmd == "transition": with replay_buffer_lock: if "env_id" not in data: data["env_id"] = actor_idx self.replay_buffer.append(**data) self._cumulative_steps += 1 elif cmd == "stop_episode": idx = actor_idx if data is None else data with replay_buffer_lock: self.replay_buffer.stop_current_episode(env_id=idx) stats = self.get_statistics() pipe.send(stats) else: raise RuntimeError( "Unknown command from actor: {}".format(cmd)) except EOFError: pipe.close() except Exception: self.logger.exception("Poller loop failed. Exiting") exception_event.set() def _learner_loop( self, shared_model: torch.nn.Module, pipes: Sequence[mp.connection.Connection], replay_buffer_lock: mp.synchronize.Lock, stop_event: mp.synchronize.Event, exception_event: mp.synchronize.Event, n_updates: Optional[int] = None, ) -> None: try: update_counter = 0 # To stop this loop, call stop_event.set() while not stop_event.is_set(): # Update model if possible if not self._can_start_replay(): continue if n_updates is not None: assert self.optim_t <= n_updates if self.optim_t == n_updates: stop_event.set() break if self.recurrent: assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer) with replay_buffer_lock: episodes = self.replay_buffer.sample_episodes( self.minibatch_size, self.episodic_update_len) self.update_from_episodes(episodes) else: with replay_buffer_lock: transitions = self.replay_buffer.sample( self.minibatch_size) self.update(transitions) # Update the shared model. This can be expensive if GPU is used # since this is a DtoH copy, so it is updated only at regular # intervals. update_counter += 1 if update_counter % self.actor_update_interval == 0: with self.update_counter.get_lock(): self.update_counter.value += 1 shared_model.load_state_dict(self.model.state_dict()) # To keep the ratio of target updates to model updates, # here we calculate back the effective current timestep # from update_interval and number of updates so far. effective_timestep = self.optim_t * self.update_interval # We can safely assign self.t since in the learner # it isn't updated by any other method self.t = effective_timestep if effective_timestep % self.target_update_interval == 0: self.sync_target_network() except Exception: self.logger.exception("Learner loop failed. 
Exiting") exception_event.set() def _poller_loop( self, shared_model: torch.nn.Module, pipes: Sequence[mp.connection.Connection], replay_buffer_lock: mp.synchronize.Lock, stop_event: mp.synchronize.Event, exception_event: mp.synchronize.Event, ) -> None: # To stop this loop, call stop_event.set() while not stop_event.is_set() and not exception_event.is_set(): time.sleep(1e-6) # Poll actors for messages for i, pipe in enumerate(pipes): self._poll_pipe(i, pipe, replay_buffer_lock, exception_event) def setup_actor_learner_training( self, n_actors: int, update_counter: Optional[Any] = None, n_updates: Optional[int] = None, actor_update_interval: int = 8, ): if update_counter is None: update_counter = mp.Value(ctypes.c_ulong) (shared_model, learner_pipes, actor_pipes) = self._setup_actor_learner_training( n_actors, actor_update_interval, update_counter) exception_event = mp.Event() def make_actor(i): return pfrl.agents.StateQFunctionActor( pipe=actor_pipes[i], model=shared_model, explorer=self.explorer, phi=self.phi, batch_states=self.batch_states, logger=self.logger, recurrent=self.recurrent, ) replay_buffer_lock = mp.Lock() poller_stop_event = mp.Event() poller = pfrl.utils.StoppableThread( target=self._poller_loop, kwargs=dict( shared_model=shared_model, pipes=learner_pipes, replay_buffer_lock=replay_buffer_lock, stop_event=poller_stop_event, exception_event=exception_event, ), stop_event=poller_stop_event, ) learner_stop_event = mp.Event() learner = pfrl.utils.StoppableThread( target=self._learner_loop, kwargs=dict( shared_model=shared_model, pipes=learner_pipes, replay_buffer_lock=replay_buffer_lock, stop_event=learner_stop_event, n_updates=n_updates, exception_event=exception_event, ), stop_event=learner_stop_event, ) return make_actor, learner, poller, exception_event def stop_episode(self) -> None: if self.recurrent: self.test_recurrent_states = None def get_statistics(self): return [ ("average_q", _mean_or_nan(self.q_record)), ("average_loss", _mean_or_nan(self.loss_record)), ("cumulative_steps", self.cumulative_steps), ("n_updates", self.optim_t), ("rlen", len(self.replay_buffer)), ]
class DDPG(AttributeSavingMixin, BatchAgent): """Deep Deterministic Policy Gradients. This can be used as SVG(0) by specifying a Gaussian policy instead of a deterministic policy. Args: policy (torch.nn.Module): Policy q_func (torch.nn.Module): Q-function actor_optimizer (Optimizer): Optimizer setup with the policy critic_optimizer (Optimizer): Optimizer setup with the Q-function replay_buffer (ReplayBuffer): Replay buffer gamma (float): Discount factor explorer (Explorer): Explorer that specifies an exploration strategy. gpu (int): GPU device id if not None nor negative. replay_start_size (int): if the replay buffer's size is less than replay_start_size, skip update minibatch_size (int): Minibatch size update_interval (int): Model update interval in step target_update_interval (int): Target model update interval in step phi (callable): Feature extractor applied to observations target_update_method (str): 'hard' or 'soft'. soft_update_tau (float): Tau of soft target update. n_times_update (int): Number of repetition of update batch_accumulator (str): 'mean' or 'sum' episodic_update (bool): Use full episodes for update if set True episodic_update_len (int or None): Subsequences of this length are used for update if set int and episodic_update=True logger (Logger): Logger used batch_states (callable): method which makes a batch of observations. default is `pfrl.utils.batch_states.batch_states` burnin_action_func (callable or None): If not None, this callable object is used to select actions before the model is updated one or more times during training. """ saved_attributes = ("model", "target_model", "actor_optimizer", "critic_optimizer") def __init__( self, policy, q_func, actor_optimizer, critic_optimizer, replay_buffer, gamma, explorer, gpu=None, replay_start_size=50000, minibatch_size=32, update_interval=1, target_update_interval=10000, phi=lambda x: x, target_update_method="hard", soft_update_tau=1e-2, n_times_update=1, recurrent=False, episodic_update_len=None, logger=getLogger(__name__), batch_states=batch_states, burnin_action_func=None, ): self.model = nn.ModuleList([policy, q_func]) if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) self.model.to(self.device) else: self.device = torch.device("cpu") self.replay_buffer = replay_buffer self.gamma = gamma self.explorer = explorer self.gpu = gpu self.target_update_interval = target_update_interval self.phi = phi self.target_update_method = target_update_method self.soft_update_tau = soft_update_tau self.logger = logger self.actor_optimizer = actor_optimizer self.critic_optimizer = critic_optimizer self.recurrent = recurrent assert not self.recurrent, "recurrent=True is not yet implemented" if self.recurrent: update_func = self.update_from_episodes else: update_func = self.update self.replay_updater = ReplayUpdater( replay_buffer=replay_buffer, update_func=update_func, batchsize=minibatch_size, episodic_update=recurrent, episodic_update_len=episodic_update_len, n_times_update=n_times_update, replay_start_size=replay_start_size, update_interval=update_interval, ) self.batch_states = batch_states self.burnin_action_func = burnin_action_func self.t = 0 self.last_state = None self.last_action = None self.target_model = copy.deepcopy(self.model) self.target_model.eval() self.q_record = collections.deque(maxlen=1000) self.actor_loss_record = collections.deque(maxlen=100) self.critic_loss_record = collections.deque(maxlen=100) self.n_updates = 0 # Aliases for convenience 
self.policy, self.q_function = self.model self.target_policy, self.target_q_function = self.target_model self.sync_target_network() def sync_target_network(self): """Synchronize target network with current network.""" synchronize_parameters( src=self.model, dst=self.target_model, method=self.target_update_method, tau=self.soft_update_tau, ) # Update Q-function def compute_critic_loss(self, batch): """Compute loss for critic.""" batch_next_state = batch["next_state"] batch_rewards = batch["reward"] batch_terminal = batch["is_state_terminal"] batch_state = batch["state"] batch_actions = batch["action"] batchsize = len(batch_rewards) with torch.no_grad(): assert not self.recurrent next_actions = self.target_policy(batch_next_state).sample() next_q = self.target_q_function((batch_next_state, next_actions)) target_q = batch_rewards + self.gamma * ( 1.0 - batch_terminal) * next_q.reshape((batchsize, )) predict_q = self.q_function((batch_state, batch_actions)).reshape( (batchsize, )) loss = F.mse_loss(target_q, predict_q) # Update stats self.critic_loss_record.append(float(loss.detach().cpu().numpy())) return loss def compute_actor_loss(self, batch): """Compute loss for actor.""" batch_state = batch["state"] onpolicy_actions = self.policy(batch_state).rsample() q = self.q_function((batch_state, onpolicy_actions)) loss = -q.mean() # Update stats self.q_record.extend(q.detach().cpu().numpy()) self.actor_loss_record.append(float(loss.detach().cpu().numpy())) return loss def update(self, experiences, errors_out=None): """Update the model from experiences""" batch = batch_experiences(experiences, self.device, self.phi, self.gamma) self.critic_optimizer.zero_grad() self.compute_critic_loss(batch).backward() self.critic_optimizer.step() self.actor_optimizer.zero_grad() self.compute_actor_loss(batch).backward() self.actor_optimizer.step() self.n_updates += 1 def update_from_episodes(self, episodes, errors_out=None): raise NotImplementedError # Sort episodes desc by their lengths sorted_episodes = list(reversed(sorted(episodes, key=len))) max_epi_len = len(sorted_episodes[0]) # Precompute all the input batches batches = [] for i in range(max_epi_len): transitions = [] for ep in sorted_episodes: if len(ep) <= i: break transitions.append([ep[i]]) batch = batch_experiences(transitions, xp=self.device, phi=self.phi, gamma=self.gamma) batches.append(batch) with self.model.state_reset(), self.target_model.state_reset(): # Since the target model is evaluated one-step ahead, # its internal states need to be updated self.target_q_function.update_state(batches[0]["state"], batches[0]["action"]) self.target_policy(batches[0]["state"]) # Update critic through time critic_loss = 0 for batch in batches: critic_loss += self.compute_critic_loss(batch) self.critic_optimizer.update(lambda: critic_loss / max_epi_len) with self.model.state_reset(): # Update actor through time actor_loss = 0 for batch in batches: actor_loss += self.compute_actor_loss(batch) self.actor_optimizer.update(lambda: actor_loss / max_epi_len) def batch_act(self, batch_obs): if self.training: return self._batch_act_train(batch_obs) else: return self._batch_act_eval(batch_obs) def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): if self.training: self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) def _batch_select_greedy_actions(self, batch_obs): with torch.no_grad(), evaluating(self.policy): batch_xs = self.batch_states(batch_obs, self.device, self.phi) batch_action = self.policy(batch_xs).sample() return 
batch_action.cpu().numpy() def _batch_act_eval(self, batch_obs): assert not self.training return self._batch_select_greedy_actions(batch_obs) def _batch_act_train(self, batch_obs): assert self.training if self.burnin_action_func is not None and self.n_updates == 0: batch_action = [ self.burnin_action_func() for _ in range(len(batch_obs)) ] else: batch_greedy_action = self._batch_select_greedy_actions(batch_obs) batch_action = [ self.explorer.select_action(self.t, lambda: batch_greedy_action[i]) for i in range(len(batch_greedy_action)) ] self.batch_last_obs = list(batch_obs) self.batch_last_action = list(batch_action) return batch_action def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset): assert self.training for i in range(len(batch_obs)): self.t += 1 # Update the target network if self.t % self.target_update_interval == 0: self.sync_target_network() if self.batch_last_obs[i] is not None: assert self.batch_last_action[i] is not None # Add a transition to the replay buffer self.replay_buffer.append( state=self.batch_last_obs[i], action=self.batch_last_action[i], reward=batch_reward[i], next_state=batch_obs[i], next_action=None, is_state_terminal=batch_done[i], env_id=i, ) if batch_reset[i] or batch_done[i]: self.batch_last_obs[i] = None self.batch_last_action[i] = None self.replay_buffer.stop_current_episode(env_id=i) self.replay_updater.update_if_necessary(self.t) def get_statistics(self): return [ ("average_q", _mean_or_nan(self.q_record)), ("average_actor_loss", _mean_or_nan(self.actor_loss_record)), ("average_critic_loss", _mean_or_nan(self.critic_loss_record)), ("n_updates", self.n_updates), ]
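# --- Illustrative sketch (not part of the library code above) ------------------
# The two objectives that DDPG.compute_critic_loss and DDPG.compute_actor_loss
# implement, written against throw-away MLPs so the tensor shapes are explicit.
# All names here (pi, q, target_pi, target_q, the toy dimensions) are assumptions
# for this sketch only; the agent above instead uses its own policy/Q-function
# modules and batched transition dicts.
import copy

import torch
import torch.nn.functional as F
from torch import nn


def _ddpg_loss_sketch(obs_dim=4, act_dim=2, batch_size=8, gamma=0.99):
    pi = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU(),
                       nn.Linear(32, act_dim), nn.Tanh())
    q = nn.Sequential(nn.Linear(obs_dim + act_dim, 32), nn.ReLU(),
                      nn.Linear(32, 1))
    target_pi, target_q = copy.deepcopy(pi), copy.deepcopy(q)

    # A fake minibatch of transitions (s, a, r, s', done).
    s = torch.randn(batch_size, obs_dim)
    a = torch.randn(batch_size, act_dim)
    r = torch.randn(batch_size)
    s_next = torch.randn(batch_size, obs_dim)
    done = torch.zeros(batch_size)

    # Critic: regress Q(s, a) onto r + gamma * (1 - done) * Q'(s', pi'(s')).
    with torch.no_grad():
        next_a = target_pi(s_next)
        target = r + gamma * (1.0 - done) * target_q(
            torch.cat([s_next, next_a], dim=1)).squeeze(1)
    critic_loss = F.mse_loss(q(torch.cat([s, a], dim=1)).squeeze(1), target)

    # Actor: maximize Q(s, pi(s)), i.e. minimize its negation.
    actor_loss = -q(torch.cat([s, pi(s)], dim=1)).mean()
    return critic_loss, actor_loss
# --------------------------------------------------------------------------------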
class SQIL(agent.AttributeSavingMixin, agent.BatchAgent): """Soft Q Imitation Learning (SQIL). A variant of DQN that additionally learns from an expert demonstration dataset using a soft Bellman backup. Args: q_function (StateQFunction): Q-function optimizer (Optimizer): Optimizer that is already setup replay_buffer (ReplayBuffer): Replay buffer gamma (float): Discount factor explorer (Explorer): Explorer that specifies an exploration strategy. gpu (int): GPU device id if not None nor negative. replay_start_size (int): if the replay buffer's size is less than replay_start_size, skip update minibatch_size (int): Minibatch size update_interval (int): Model update interval in step target_update_interval (int): Target model update interval in step clip_delta (bool): Clip delta if set True phi (callable): Feature extractor applied to observations target_update_method (str): 'hard' or 'soft'. soft_update_tau (float): Tau of soft target update. n_times_update (int): Number of repetition of update batch_accumulator (str): 'mean' or 'sum' episodic_update_len (int or None): Subsequences of this length are used for update if set int and episodic_update=True logger (Logger): Logger used batch_states (callable): method which makes a batch of observations. default is `pfrl.utils.batch_states.batch_states` recurrent (bool): If set to True, `model` is assumed to implement `pfrl.nn.Recurrent` and is updated in a recurrent manner. Changes from DQN: remove recurrent support; add expert dataset. """ saved_attributes = ("model", "target_model", "optimizer") def __init__( self, q_function, optimizer, replay_buffer, gamma, explorer, gpu=None, replay_start_size=50000, minibatch_size=32, update_interval=1, target_update_interval=10000, clip_delta=True, phi=lambda x: x, target_update_method="hard", soft_update_tau=1e-2, n_times_update=1, batch_accumulator="mean", episodic_update_len=None, logger=getLogger(__name__), batch_states=batch_states, expert_dataset=None, reward_scale=1.0, experience_lambda=1.0, recurrent=False, reward_boundaries=None, # specific to options ): self.expert_dataset = expert_dataset self.model = q_function if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) self.model.to(self.device) else: self.device = torch.device("cpu") self.replay_buffer = replay_buffer self.optimizer = optimizer self.gamma = gamma self.explorer = explorer self.gpu = gpu self.target_update_interval = target_update_interval self.clip_delta = clip_delta self.phi = phi self.target_update_method = target_update_method self.soft_update_tau = soft_update_tau self.batch_accumulator = batch_accumulator assert batch_accumulator in ("mean", "sum") self.logger = logger self.batch_states = batch_states self.recurrent = recurrent if self.recurrent: update_func = self.update_from_episodes else: update_func = self.update self.replay_updater = ReplayUpdater( replay_buffer=replay_buffer, update_func=update_func, batchsize=minibatch_size, episodic_update=recurrent, episodic_update_len=episodic_update_len, n_times_update=n_times_update, replay_start_size=replay_start_size, update_interval=update_interval, ) self.minibatch_size = minibatch_size self.episodic_update_len = episodic_update_len self.replay_start_size = replay_start_size self.update_interval = update_interval assert ( target_update_interval % update_interval == 0 ), "target_update_interval should be a multiple of update_interval" # For imitation self.reward_scale = reward_scale self.experience_lambda = experience_lambda if reward_boundaries is not None and self.expert_dataset is not None: self.reward_based_sampler =
RewardBasedSampler(self.expert_dataset, reward_boundaries, reward=reward_scale) else: self.reward_based_sampler = None self.t = 0 self.optim_t = 0 # Compensate pytorch optim not having `t` self._cumulative_steps = 0 self.last_state = None self.last_action = None self.target_model = None self.sync_target_network() # Statistics self.q_record = collections.deque(maxlen=1000) self.loss_record = collections.deque(maxlen=100) # Recurrent states of the model self.train_recurrent_states = None self.train_prev_recurrent_states = None self.test_recurrent_states = None # Error checking if (self.replay_buffer.capacity is not None and self.replay_buffer.capacity < self.replay_updater.replay_start_size): raise ValueError( "Replay start size cannot exceed replay buffer capacity.") @property def cumulative_steps(self): # cumulative_steps counts the overall steps during the training. return self._cumulative_steps def sync_target_network(self): """Synchronize target network with current network.""" if self.target_model is None: self.target_model = copy.deepcopy(self.model) def flatten_parameters(mod): if isinstance(mod, torch.nn.RNNBase): mod.flatten_parameters() # RNNBase.flatten_parameters must be called again after deep-copy. # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506 # NOQA self.target_model.apply(flatten_parameters) # set target n/w to evaluate only. self.target_model.eval() else: synchronize_parameters( src=self.model, dst=self.target_model, method=self.target_update_method, tau=self.soft_update_tau, ) def update(self, experiences, errors_out=None): """Update the model from experiences Args: experiences (list): List of lists of dicts. For DQN, each dict must contains: - state (object): State - action (object): Action - reward (float): Reward - is_state_terminal (bool): True iff next state is terminal - next_state (object): Next state - weight (float, optional): Weight coefficient. It can be used for importance sampling. errors_out (list or None): If set to a list, then TD-errors computed from the given experiences are appended to the list. 
Returns: None Changes from DQN: Learned from demonstrations """ has_weight = "weight" in experiences[0][0] exp_batch = batch_experiences( experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) if has_weight: exp_batch["weights"] = torch.tensor( [elem[0]["weight"] for elem in experiences], device=self.device, dtype=torch.float32, ) if errors_out is None: errors_out = [] if self.reward_based_sampler is not None: demo_experiences = self.reward_based_sampler.sample(experiences) else: demo_experiences = load_experiences_from_demonstrations( self.expert_dataset, self.replay_updater.batchsize, self.reward_scale) demo_batch = batch_experiences( demo_experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) loss = self._compute_loss(exp_batch, demo_batch, errors_out=errors_out) if has_weight: self.replay_buffer.update_errors(errors_out) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() self.optimizer.step() self.optim_t += 1 def update_from_episodes(self, episodes, errors_out=None): assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer" episodes = sorted(episodes, key=len, reverse=True) exp_batch = batch_recurrent_experiences( episodes, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) demo_experiences = load_experiences_from_demonstrations( self.expert_dataset, self.replay_updater.batchsize, self.reward_scale) demo_batch = batch_experiences( demo_experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) loss = self._compute_loss(exp_batch, demo_batch, errors_out=None) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() self.optimizer.step() self.optim_t += 1 def _compute_target_values(self, exp_batch): """ Changes from DQN: Consider soft Bellman error """ batch_next_state = exp_batch["next_state"] target_next_qout = self.target_model(batch_next_state) next_q_max = torch.broadcast_tensors( target_next_qout.q_values.max(dim=-1, keepdim=True)[0], target_next_qout.q_values)[0] next_q_soft = ( next_q_max[:, 0] + (target_next_qout.q_values - next_q_max).exp().sum(dim=-1).log()) batch_rewards = exp_batch["reward"] batch_terminal = exp_batch["is_state_terminal"] discount = exp_batch["discount"] # return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max return batch_rewards + discount * (1.0 - batch_terminal) * next_q_soft def _compute_y_and_t(self, exp_batch): batch_size = exp_batch["reward"].shape[0] # Compute Q-values for current states batch_state = exp_batch["state"] if self.recurrent: qout, _ = pack_and_forward(self.model, batch_state, exp_batch["recurrent_state"]) else: qout = self.model(batch_state) batch_actions = exp_batch["action"] batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1)) with torch.no_grad(): batch_q_target = torch.reshape( self._compute_target_values(exp_batch), (batch_size, 1)) return batch_q, batch_q_target def __compute_loss(self, exp_batch, errors_out): y, t = self._compute_y_and_t(exp_batch) self.q_record.extend(y.detach().cpu().numpy().ravel()) if errors_out is not None: del errors_out[:] delta = torch.abs(y - t) if delta.ndim == 2: delta = torch.sum(delta, dim=1) delta = delta.detach().cpu().numpy() for e in delta: errors_out.append(e) if "weights" in exp_batch: return compute_weighted_value_loss( y, t, exp_batch["weights"], clip_delta=self.clip_delta, 
batch_accumulator=self.batch_accumulator, ) else: return compute_value_loss( y, t, clip_delta=self.clip_delta, batch_accumulator=self.batch_accumulator, ) def _compute_loss(self, exp_batch, demo_batch, errors_out=None): """Compute the Q-learning loss for a batch of experiences Args: exp_batch (dict): A dict of batched arrays of transitions Returns: Computed loss from the minibatch of experiences Changes from DQN: Learned from demonstrations """ exp_loss = self.__compute_loss(exp_batch, errors_out=errors_out) demo_loss = self.__compute_loss(demo_batch, errors_out=None) return (exp_loss * self.experience_lambda + demo_loss) / 2 def _evaluate_model_and_update_recurrent_states(self, batch_obs): batch_xs = self.batch_states(batch_obs, self.device, self.phi) if self.recurrent: if self.training: self.train_prev_recurrent_states = self.train_recurrent_states batch_av, self.train_recurrent_states = one_step_forward( self.model, batch_xs, self.train_recurrent_states) else: batch_av, self.test_recurrent_states = one_step_forward( self.model, batch_xs, self.test_recurrent_states) else: batch_av = self.model(batch_xs) return batch_av def batch_act(self, batch_obs): with torch.no_grad(), evaluating(self.model): batch_av = self._evaluate_model_and_update_recurrent_states( batch_obs) batch_argmax = batch_av.greedy_actions.cpu().numpy() if self.training: batch_action = [ self.explorer.select_action( self.t, lambda: batch_argmax[i], action_value=batch_av[i:i + 1], ) for i in range(len(batch_obs)) ] self.batch_last_obs = list(batch_obs) self.batch_last_action = list(batch_action) else: # stochastic batch_action = [ self.explorer.select_action( self.t, lambda: batch_argmax[i], action_value=batch_av[i:i + 1], ) for i in range(len(batch_obs)) ] # deterministic # batch_action = batch_argmax return batch_action def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset): for i in range(len(batch_obs)): self.t += 1 self._cumulative_steps += 1 # Update the target network if self.t % self.target_update_interval == 0: self.sync_target_network() if self.batch_last_obs[i] is not None: assert self.batch_last_action[i] is not None # Add a transition to the replay buffer transition = { "state": self.batch_last_obs[i], "action": self.batch_last_action[i], "reward": batch_reward[i], "next_state": batch_obs[i], "next_action": None, "is_state_terminal": batch_done[i], } if self.recurrent: transition["recurrent_state"] = recurrent_state_as_numpy( get_recurrent_state_at( self.train_prev_recurrent_states, i, detach=True)) transition[ "next_recurrent_state"] = recurrent_state_as_numpy( get_recurrent_state_at(self.train_recurrent_states, i, detach=True)) self.replay_buffer.append(env_id=i, **transition) if batch_reset[i] or batch_done[i]: self.batch_last_obs[i] = None self.batch_last_action[i] = None self.replay_buffer.stop_current_episode(env_id=i) self.replay_updater.update_if_necessary(self.t) if self.recurrent: # Reset recurrent states when episodes end self.train_prev_recurrent_states = None self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end( # NOQA batch_done=batch_done, batch_reset=batch_reset, recurrent_states=self.train_recurrent_states, ) def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset): if self.recurrent: # Reset recurrent states when episodes end self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end( # NOQA batch_done=batch_done, batch_reset=batch_reset, recurrent_states=self.test_recurrent_states, ) def 
batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): if self.training: return self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset) else: return self._batch_observe_eval(batch_obs, batch_reward, batch_done, batch_reset) def _can_start_replay(self): if len(self.replay_buffer) < self.replay_start_size: return False if self.recurrent and self.replay_buffer.n_episodes < self.minibatch_size: return False return True def stop_episode(self): if self.recurrent: self.test_recurrent_states = None def get_statistics(self): return [ ("average_q", _mean_or_nan(self.q_record)), ("average_loss", _mean_or_nan(self.loss_record)), ("cumulative_steps", self.cumulative_steps), ("n_updates", self.optim_t), ("rlen", len(self.replay_buffer)), ]
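# --- Illustrative sketch (not part of the library code above) ------------------
# SQIL._compute_target_values replaces DQN's max backup with a soft backup over
# next-state Q-values. The numerically stable form used above,
# max_a Q + log(sum_a exp(Q - max_a Q)), is algebraically identical to
# torch.logsumexp; this stand-alone check makes that explicit. Tensor names are
# assumptions for this sketch only.
import torch


def _soft_bellman_target_sketch(next_q, rewards, terminal, discount=0.99):
    """next_q: (batch, n_actions) target-network Q-values for the next states."""
    q_max = next_q.max(dim=-1, keepdim=True)[0]
    soft_v = q_max[:, 0] + (next_q - q_max).exp().sum(dim=-1).log()
    # Sanity check: the stable form equals a plain log-sum-exp over actions.
    assert torch.allclose(soft_v, torch.logsumexp(next_q, dim=-1))
    return rewards + discount * (1.0 - terminal) * soft_v


# Example usage:
# _soft_bellman_target_sketch(torch.randn(5, 3), torch.randn(5), torch.zeros(5))
# --------------------------------------------------------------------------------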