def __call__(
    self,
    model: Brain,
    target_model: Brain,
    experiences: Tuple[torch.Tensor, ...],
    gamma: float,
    head_cfg: ConfigDict,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return element-wise C51 loss and Q-values."""
    states, actions, rewards, next_states, dones = experiences[:5]
    batch_size = states.shape[0]

    support = torch.linspace(
        head_cfg.configs.v_min, head_cfg.configs.v_max, head_cfg.configs.atom_size
    ).to(device)
    delta_z = float(head_cfg.configs.v_max - head_cfg.configs.v_min) / (
        head_cfg.configs.atom_size - 1
    )

    with torch.no_grad():
        # According to the NoisyNet paper, the noise parameters of the online network
        # should be resampled when selecting the double-Q action, but we skip this
        # because it makes no noticeable difference in performance.
        next_actions = model.forward_(next_states)[1].argmax(1)

        next_dist = target_model.forward_(next_states)[0]
        next_dist = next_dist[range(batch_size), next_actions]

        t_z = rewards + (1 - dones) * gamma * support
        t_z = t_z.clamp(min=head_cfg.configs.v_min, max=head_cfg.configs.v_max)
        b = (t_z - head_cfg.configs.v_min) / delta_z
        l = b.floor().long()  # noqa: E741
        u = b.ceil().long()

        offset = (
            torch.linspace(
                0, (batch_size - 1) * head_cfg.configs.atom_size, batch_size
            )
            .long()
            .unsqueeze(1)
            .expand(batch_size, head_cfg.configs.atom_size)
            .to(device)
        )

        proj_dist = torch.zeros(next_dist.size(), device=device)
        proj_dist.view(-1).index_add_(
            0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
        )
        proj_dist.view(-1).index_add_(
            0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
        )

    dist, q_values = model.forward_(states)
    log_p = torch.log(
        torch.clamp(dist[range(batch_size), actions.long()], min=1e-7)
    )
    dq_loss_element_wise = -(proj_dist * log_p).sum(1, keepdim=True)

    return dq_loss_element_wise, q_values
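For intuition, here is a small, self-contained sketch of the categorical projection that the two index_add_ calls above implement, for a single transition instead of a batch (so the per-row offset is unnecessary). The support and transition values are illustrative, not taken from any config in this repo.

import torch

v_min, v_max, atom_size = -1.0, 1.0, 5
gamma = 0.9
support = torch.linspace(v_min, v_max, atom_size)  # [-1.0, -0.5, 0.0, 0.5, 1.0]
delta_z = (v_max - v_min) / (atom_size - 1)        # 0.5

# one transition with a uniform predicted next-state distribution
next_dist = torch.full((atom_size,), 1.0 / atom_size)
reward, done = 0.07, 0.0

t_z = (reward + (1 - done) * gamma * support).clamp(v_min, v_max)
b = (t_z - v_min) / delta_z          # fractional atom indices of the shifted support
l, u = b.floor().long(), b.ceil().long()

proj = torch.zeros(atom_size)
proj.index_add_(0, l, next_dist * (u.float() - b))  # mass sent to the lower neighbor
proj.index_add_(0, u, next_dist * (b - l.float()))  # mass sent to the upper neighbor
print(proj, proj.sum())  # the projected distribution still sums to 1 here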
def test_brain(): """Test wheter brain make fc layer based on backbone's output size.""" head_cfg.configs.state_size = test_state_dim head_cfg.configs.output_size = 8 model = Brain(resnet_cfg, head_cfg) assert model.head.input_size == 16384
def _synchronize(self, network: Brain, new_state_dict: Dict[str, np.ndarray]):
    """Copy parameters from numpy arrays."""
    param_name_list = list(new_state_dict.keys())
    for worker_named_param in network.named_parameters():
        worker_param_name = worker_named_param[0]
        if worker_param_name in param_name_list:
            new_param = numpy2floattensor(
                new_state_dict[worker_param_name], self.device
            )
            worker_named_param[1].data.copy_(new_param)
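_synchronize expects the sending side to ship parameters as a {name: np.ndarray} dict. A minimal sketch of that producing side and of the copy it enables is below; the helper name state_dict_to_numpy and the toy networks are illustrative, not part of the library.

import numpy as np
import torch
import torch.nn as nn

def state_dict_to_numpy(model: nn.Module) -> dict:
    """Detach parameters to CPU numpy arrays for cheap serialization."""
    return {name: p.detach().cpu().numpy() for name, p in model.named_parameters()}

learner_net = nn.Linear(4, 2)
worker_net = nn.Linear(4, 2)

new_state_dict = state_dict_to_numpy(learner_net)
for name, param in worker_net.named_parameters():
    if name in new_state_dict:
        param.data.copy_(torch.from_numpy(new_state_dict[name]).float())

assert torch.equal(worker_net.weight, learner_net.weight)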
def __init__(
    self,
    args: argparse.Namespace,
    env_info: ConfigDict,
    log_cfg: ConfigDict,
    comm_cfg: ConfigDict,
    backbone: ConfigDict,
    head: ConfigDict,
):
    self.args = args
    self.env_info = env_info
    self.log_cfg = log_cfg
    self.comm_cfg = comm_cfg
    self.device = torch.device("cpu")  # Logger only runs on cpu
    self.brain = Brain(backbone, head).to(self.device)

    self.update_step = 0
    self.log_info_queue = deque(maxlen=100)

    self._init_env()
def _init_network(self):
    """Initialize networks and optimizers."""
    self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
    self.dqn_target = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
    self.loss_fn = build_loss(self.loss_type)
    self.dqn_target.load_state_dict(self.dqn.state_dict())

    # create optimizer
    self.dqn_optim = optim.Adam(
        self.dqn.parameters(),
        lr=self.optim_cfg.lr_dqn,
        weight_decay=self.optim_cfg.weight_decay,
        eps=self.optim_cfg.adam_eps,
    )

    # load the optimizer and model parameters
    if self.args.load_from is not None:
        self.load_params(self.args.load_from)
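The optimizer settings read here are lr_dqn, weight_decay, and adam_eps. A minimal sketch of such a config and its use follows; the field names come from the code above, but the values are made up and SimpleNamespace merely stands in for the project's ConfigDict.

from types import SimpleNamespace
import torch.nn as nn
import torch.optim as optim

optim_cfg = SimpleNamespace(lr_dqn=1e-4, weight_decay=1e-7, adam_eps=1e-8)

dqn = nn.Linear(4, 2)  # stand-in for Brain(backbone_cfg, head_cfg)
dqn_optim = optim.Adam(
    dqn.parameters(),
    lr=optim_cfg.lr_dqn,
    weight_decay=optim_cfg.weight_decay,
    eps=optim_cfg.adam_eps,
)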
class DQNWorker(DistributedWorker): """DQN worker for distributed training. Attributes: backbone (ConfigDict): backbone configs for building network head (ConfigDict): head configs for building network state_dict (ConfigDict): initial network state dict received form learner device (str): literal to indicate cpu/cuda use """ def __init__( self, rank: int, args: argparse.Namespace, env_info: ConfigDict, hyper_params: ConfigDict, backbone: ConfigDict, head: ConfigDict, state_dict: OrderedDict, device: str, loss_type: ConfigDict, ): DistributedWorker.__init__(self, rank, args, env_info, hyper_params, device) self.loss_fn = build_loss(loss_type) self.backbone_cfg = backbone self.head_cfg = head self.head_cfg.configs.state_size = self.env_info.observation_space.shape self.head_cfg.configs.output_size = self.env_info.action_space.n self.use_n_step = self.hyper_params.n_step > 1 self.max_epsilon = self.hyper_params.max_epsilon self.min_epsilon = self.hyper_params.min_epsilon self.epsilon = self.hyper_params.max_epsilon self._init_networks(state_dict) # pylint: disable=attribute-defined-outside-init def _init_networks(self, state_dict: OrderedDict): """Initialize DQN policy with learner state dict.""" self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device) self.dqn.load_state_dict(state_dict) self.dqn.eval() def load_params(self, path: str): """Load model and optimizer parameters.""" DistributedWorker.load_params(self, path) params = torch.load(path) self.dqn.load_state_dict(params["dqn_state_dict"]) print("[INFO] loaded the model and optimizer from", path) def select_action(self, state: np.ndarray) -> np.ndarray: """Select an action from the input space.""" # epsilon greedy policy # pylint: disable=comparison-with-callable if self.epsilon > np.random.random(): selected_action = np.array(self.env.action_space.sample()) else: with torch.no_grad(): state = self._preprocess_state(state, self.device) selected_action = self.dqn(state).argmax() selected_action = selected_action.cpu().numpy() # Decay epsilon self.epsilon = max( self.epsilon - (self.max_epsilon - self.min_epsilon) * self.hyper_params.epsilon_decay, self.min_epsilon, ) return selected_action def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]: """Take an action and return the response of the env.""" next_state, reward, done, info = self.env.step(action) return next_state, reward, done, info def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray: """Compute initial priority values of experiences in local memory.""" states = numpy2floattensor(memory["states"], self.device) actions = numpy2floattensor(memory["actions"], self.device).long() rewards = numpy2floattensor(memory["rewards"].reshape(-1, 1), self.device) next_states = numpy2floattensor(memory["next_states"], self.device) dones = numpy2floattensor(memory["dones"].reshape(-1, 1), self.device) memory_tensors = (states, actions, rewards, next_states, dones) with torch.no_grad(): dq_loss_element_wise, _ = self.loss_fn( self.dqn, self.dqn, memory_tensors, self.hyper_params.gamma, self.head_cfg, ) loss_for_prior = dq_loss_element_wise.detach().cpu().numpy() new_priorities = loss_for_prior + self.hyper_params.per_eps return new_priorities def synchronize(self, new_state_dict: Dict[str, np.ndarray]): """Synchronize worker dqn with learner dqn.""" self._synchronize(self.dqn, new_state_dict)
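For a rough sense of the schedule applied in select_action above, here is a standalone check of the linear epsilon decay. The hyper-parameter values are illustrative, not taken from a bundled config.

max_epsilon, min_epsilon, epsilon_decay = 1.0, 0.01, 1e-5

epsilon = max_epsilon
for step in range(1, 200_001):
    epsilon = max(epsilon - (max_epsilon - min_epsilon) * epsilon_decay, min_epsilon)
print(round(epsilon, 4))  # with these values, the 0.01 floor is reached after ~100k steps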
class DDPGLearner(Learner): """Learner for DDPG Agent. Attributes: args (argparse.Namespace): arguments including hyperparameters and training settings hyper_params (ConfigDict): hyper-parameters optim_cfg (ConfigDict): config of optimizer log_cfg (ConfigDict): configuration for saving log and checkpoint actor (nn.Module): actor model to select actions actor_target (nn.Module): target actor model to select actions critic (nn.Module): critic model to predict state values critic_target (nn.Module): target critic model to predict state values actor_optim (Optimizer): optimizer for training actor critic_optim (Optimizer): optimizer for training critic """ def __init__( self, args: argparse.Namespace, env_info: ConfigDict, hyper_params: ConfigDict, log_cfg: ConfigDict, backbone: ConfigDict, head: ConfigDict, optim_cfg: ConfigDict, noise_cfg: ConfigDict, device: torch.device, ): Learner.__init__(self, args, env_info, hyper_params, log_cfg, device) self.backbone_cfg = backbone self.head_cfg = head self.head_cfg.critic.configs.state_size = ( self.env_info.observation_space.shape[0] + self.env_info.action_space.shape[0], ) self.head_cfg.actor.configs.state_size = self.env_info.observation_space.shape self.head_cfg.actor.configs.output_size = self.env_info.action_space.shape[ 0] self.optim_cfg = optim_cfg self.noise_cfg = noise_cfg self._init_network() def _init_network(self): """Initialize networks and optimizers.""" # create actor self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) self.actor_target.load_state_dict(self.actor.state_dict()) # create critic self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.critic_target = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.critic_target.load_state_dict(self.critic.state_dict()) # create optimizer self.actor_optim = optim.Adam( self.actor.parameters(), lr=self.optim_cfg.lr_actor, weight_decay=self.optim_cfg.weight_decay, ) self.critic_optim = optim.Adam( self.critic.parameters(), lr=self.optim_cfg.lr_critic, weight_decay=self.optim_cfg.weight_decay, ) # load the optimizer and model parameters if self.args.load_from is not None: self.load_params(self.args.load_from) def update_model( self, experience: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: """Update actor and critic networks.""" states, actions, rewards, next_states, dones = experience # G_t = r + gamma * v(s_{t+1}) if state != Terminal # = r otherwise masks = 1 - dones next_actions = self.actor_target(next_states) next_values = self.critic_target( torch.cat((next_states, next_actions), dim=-1)) curr_returns = rewards + self.hyper_params.gamma * next_values * masks curr_returns = curr_returns.to(self.device) # train critic gradient_clip_ac = self.hyper_params.gradient_clip_ac gradient_clip_cr = self.hyper_params.gradient_clip_cr values = self.critic(torch.cat((states, actions), dim=-1)) critic_loss = F.mse_loss(values, curr_returns) self.critic_optim.zero_grad() critic_loss.backward() clip_grad_norm_(self.critic.parameters(), gradient_clip_cr) self.critic_optim.step() # train actor actions = self.actor(states) actor_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean() self.actor_optim.zero_grad() actor_loss.backward() clip_grad_norm_(self.actor.parameters(), gradient_clip_ac) self.actor_optim.step() # update target networks common_utils.soft_update(self.actor, self.actor_target, 
self.hyper_params.tau) common_utils.soft_update(self.critic, self.critic_target, self.hyper_params.tau) return actor_loss.item(), critic_loss.item() def save_params(self, n_episode: int): """Save model and optimizer parameters.""" params = { "actor_state_dict": self.actor.state_dict(), "actor_target_state_dict": self.actor_target.state_dict(), "critic_state_dict": self.critic.state_dict(), "critic_target_state_dict": self.critic_target.state_dict(), "actor_optim_state_dict": self.actor_optim.state_dict(), "critic_optim_state_dict": self.critic_optim.state_dict(), } Learner._save_params(self, params, n_episode) def load_params(self, path: str): """Load model and optimizer parameters.""" Learner.load_params(self, path) params = torch.load(path) self.actor.load_state_dict(params["actor_state_dict"]) self.actor_target.load_state_dict(params["actor_target_state_dict"]) self.critic.load_state_dict(params["critic_state_dict"]) self.critic_target.load_state_dict(params["critic_target_state_dict"]) self.actor_optim.load_state_dict(params["actor_optim_state_dict"]) self.critic_optim.load_state_dict(params["critic_optim_state_dict"]) print("[INFO] loaded the model and optimizer from", path) def get_state_dict(self) -> Tuple[OrderedDict]: """Return state dicts, mainly for distributed worker.""" return (self.critic_target.state_dict(), self.actor.state_dict()) def get_policy(self) -> nn.Module: """Return model (policy) used for action selection.""" return self.actor
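common_utils.soft_update is assumed here to be the standard Polyak averaging step, target <- tau * local + (1 - tau) * target. A standalone sketch of that behavior, with toy networks:

import torch.nn as nn

def soft_update_sketch(local: nn.Module, target: nn.Module, tau: float):
    """Move target parameters a small step toward the local network."""
    for t_param, l_param in zip(target.parameters(), local.parameters()):
        t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)

actor, actor_target = nn.Linear(3, 1), nn.Linear(3, 1)
actor_target.load_state_dict(actor.state_dict())
# ... after a gradient step on `actor`, let the target slowly track it:
soft_update_sketch(actor, actor_target, tau=5e-3)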
class DistributedLogger(ABC): """Base class for loggers use in distributed training. Attributes: log_cfg (ConfigDict): configuration for saving log and checkpoint comm_config (ConfigDict): configs for communication backbone (ConfigDict): backbone configs for building network head (ConfigDict): head configs for building network brain (Brain): logger brain for evaluation update_step (int): tracker for learner update step device (torch.device): device, cpu by default log_info_queue (deque): queue for storing log info received from learner env (gym.Env): gym environment for running test """ def __init__( self, log_cfg: ConfigDict, comm_cfg: ConfigDict, backbone: ConfigDict, head: ConfigDict, env_name: str, is_atari: bool, state_size: int, output_size: int, max_update_step: int, episode_num: int, max_episode_steps: int, interim_test_num: int, is_log: bool, is_render: bool, ): self.log_cfg = log_cfg self.comm_cfg = comm_cfg self.device = torch.device("cpu") # Logger only runs on cpu head.configs.state_size = state_size head.configs.output_size = output_size self.brain = Brain(backbone, head).to(self.device) self.env_name = env_name self.is_atari = is_atari self.max_update_step = max_update_step self.episode_num = episode_num self.max_episode_steps = max_episode_steps self.interim_test_num = interim_test_num self.is_log = is_log self.is_render = is_render self.update_step = 0 self.log_info_queue = deque(maxlen=100) self._init_env() # pylint: disable=attribute-defined-outside-init def _init_env(self): """Initialize gym environment.""" if self.is_atari: self.env = atari_env_generator(self.env_name, self.max_episode_steps) else: self.env = gym.make(self.env_name) self.env, self.max_episode_steps = env_utils.set_env( self.env, self.max_episode_steps) @abstractmethod def load_params(self, path: str): if not os.path.exists(path): raise Exception( f"[ERROR] the input path does not exist. 
Wrong path: {path}") # pylint: disable=attribute-defined-outside-init def init_communication(self): """Initialize inter-process communication sockets.""" ctx = zmq.Context() self.pull_socket = ctx.socket(zmq.PULL) self.pull_socket.bind( f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}") @abstractmethod def select_action(self, state: np.ndarray): pass @abstractmethod def write_log(self, log_value: dict): pass # pylint: disable=no-self-use @staticmethod def _preprocess_state(state: np.ndarray, device: torch.device) -> torch.Tensor: state = numpy2floattensor(state, device) return state def set_wandb(self): """Set configuration for wandb logging.""" wandb.init( project=self.env_name, name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}", ) additional_log = dict( episode_num=self.episode_num, max_episode_steps=self.max_episode_steps, ) wandb.config.update(additional_log) shutil.copy(self.log_cfg.cfg_path, os.path.join(wandb.run.dir, "config.yaml")) def recv_log_info(self): """Receive info from learner.""" received = False try: log_info_id = self.pull_socket.recv(zmq.DONTWAIT) received = True except zmq.Again: pass if received: self.log_info_queue.append(log_info_id) def run(self): """Run main logging loop; continuously receive data and log.""" if self.is_log: self.set_wandb() while self.update_step < self.max_update_step: self.recv_log_info() if self.log_info_queue: # if non-empty log_info_id = self.log_info_queue.pop() log_info = pa.deserialize(log_info_id) state_dict = log_info["state_dict"] log_value = log_info["log_value"] self.update_step = log_value["update_step"] self.synchronize(state_dict) avg_score = self.test(self.update_step) log_value["avg_score"] = avg_score self.write_log(log_value) def write_worker_log(self, worker_logs: List[dict], worker_update_interval: int): """Log the mean scores of each episode per update step to wandb.""" # NOTE: Worker plots are passed onto wandb.log as matplotlib.pyplot # since wandb doesn't support logging multiple lines to single plot self.set_wandb() # Plot individual workers fig = go.Figure() worker_id = 0 for worker_log in worker_logs: fig.add_trace( go.Scatter( x=list(worker_log.keys()), y=smoothen_graph(list(worker_log.values())), mode="lines", name=f"Worker {worker_id}", line=dict(width=2), )) worker_id = worker_id + 1 # Plot mean scores logged_update_steps = list( range(0, self.max_update_step + 1, worker_update_interval)) mean_scores = [] try: for step in logged_update_steps: scores_for_step = [] for worker_log in worker_logs: if step in list(worker_log): scores_for_step.append(worker_log[step]) mean_scores.append(np.mean(scores_for_step)) except Exception as e: print(f"[Error] {e}") fig.add_trace( go.Scatter( x=logged_update_steps, y=mean_scores, mode="lines+markers", name="Mean scores", line=dict(width=5), )) # Write to wandb wandb.log({"Worker scores": fig}) def test(self, update_step: int, interim_test: bool = True): """Test the agent.""" avg_score = self._test(update_step, interim_test) # termination self.env.close() return avg_score def _test(self, update_step: int, interim_test: bool) -> float: """Common test routine.""" if interim_test: test_num = self.interim_test_num else: test_num = self.episode_num self.brain.eval() scores = [] for i_episode in range(test_num): state = self.env.reset() done = False score = 0 step = 0 while not done: if self.is_render: self.env.render() action = self.select_action(state) next_state, reward, done, _ = self.env.step(action) state = next_state score += reward step += 1 scores.append(score) if 
interim_test: print( "[INFO] update step: %d\ttest %d\tstep: %d\ttotal score: %d" % (update_step, i_episode, step, score)) else: print("[INFO] test %d\tstep: %d\ttotal score: %d" % (i_episode, step, score)) return np.mean(scores) def synchronize(self, state_dict: Dict[str, np.ndarray]): """Copy parameters from numpy arrays.""" param_name_list = list(state_dict.keys()) for logger_named_param in self.brain.named_parameters(): logger_param_name = logger_named_param[0] if logger_param_name in param_name_list: new_param = numpy2floattensor(state_dict[logger_param_name], self.device) logger_named_param[1].data.copy_(new_param)
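A hypothetical learner-side counterpart to recv_log_info/run above: push a state_dict snapshot and log values over a ZMQ PUSH socket, serialized with pyarrow to match pa.deserialize on the logger side (pa.serialize is deprecated in recent pyarrow releases). The "state_dict"/"log_value"/"update_step" field names follow the logger code; the port value and payload contents are invented.

import numpy as np
import pyarrow as pa
import zmq

ctx = zmq.Context()
push_socket = ctx.socket(zmq.PUSH)
# the logger binds on comm_cfg.learner_logger_port, so the learner connects to it
push_socket.connect("tcp://127.0.0.1:6554")  # example port

log_info = {
    "state_dict": {"head.fc.weight": np.zeros((8, 16), dtype=np.float32)},
    "log_value": {"update_step": 1000, "loss": 0.42},
}
push_socket.send(pa.serialize(log_info).to_buffer())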
class ACERLearner(Learner): """Learner for ACER Agent. Attributes: args (argparse.Namespace): arguments including hyperparameters and training settings hyper_params (ConfigDict): hyper-parameters log_cfg (ConfigDict): configuration for saving log and checkpoint model (nn.Module): model to select actions and predict values model_optim (Optimizer): optimizer for training model """ def __init__( self, backbone: ConfigDict, head: ConfigDict, optim_cfg: ConfigDict, trust_region: ConfigDict, hyper_params: ConfigDict, log_cfg: ConfigDict, env_info: ConfigDict, is_test: bool, load_from: str, ): Learner.__init__(self, hyper_params, log_cfg, env_info.name, is_test) self.backbone_cfg = backbone self.head_cfg = head self.load_from = load_from self.head_cfg.actor.configs.state_size = env_info.observation_space.shape self.head_cfg.critic.configs.state_size = env_info.observation_space.shape self.head_cfg.actor.configs.output_size = env_info.action_space.n self.head_cfg.critic.configs.output_size = env_info.action_space.n self.optim_cfg = optim_cfg self.gradient_clip = hyper_params.gradient_clip self.trust_region = trust_region self._init_network() def _init_network(self): """Initialize network and optimizer.""" self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to( self.device ) # create optimizer self.actor_optim = optim.Adam( self.actor.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps ) self.critic_optim = optim.Adam( self.critic.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps ) self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to( self.device ) self.actor_target.load_state_dict(self.actor.state_dict()) if self.load_from is not None: self.load_params(self.load_from) def update_model(self, experience: Tuple) -> torch.Tensor: state, action, reward, prob, done = experience state = state.to(self.device) reward = reward.to(self.device) action = action.to(self.device) prob = prob.to(self.device).squeeze() done = done.to(self.device) pi = F.softmax(self.actor(state), 1) q = self.critic(state) q_i = q.gather(1, action) pi_i = pi.gather(1, action) with torch.no_grad(): v = (q * pi).sum(1).unsqueeze(1) rho = pi / (prob + 1e-8) rho_i = rho.gather(1, action) rho_bar = rho_i.clamp(max=self.hyper_params.c) q_ret = self.q_retrace( reward, done, q_i, v, rho_bar, self.hyper_params.gamma ).to(self.device) loss_f = -rho_bar * torch.log(pi_i + 1e-8) * (q_ret - v) loss_bc = ( -(1 - (self.hyper_params.c / rho)).clamp(min=0) * pi.detach() * torch.log(pi + 1e-8) * (q.detach() - v) ) value_loss = torch.sqrt((q_i - q_ret).pow(2)).mean() * 0.5 if self.trust_region.use_trust_region: g = loss_f + loss_bc pi_target = F.softmax(self.actor_target(state), 1) # gradient of partial Q KL(P || Q) = - P / Q k = -pi_target / (pi + 1e-8) k_dot_g = k * g tr = ( g - ((k_dot_g - self.trust_region.delta) / torch.norm(k)).clamp(max=0) * k ) loss = tr.mean() + value_loss else: loss = loss_f.mean() + loss_bc.sum(1).mean() + value_loss self.actor_optim.zero_grad() self.critic_optim.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.actor.parameters(), self.gradient_clip) nn.utils.clip_grad_norm_(self.critic.parameters(), self.gradient_clip) for name, param in self.actor.named_parameters(): if not torch.isfinite(param.grad).all(): print(name, torch.isfinite(param.grad).all()) print("Warning : Gradient is infinite. 
Do not update gradient.") return loss for name, param in self.critic.named_parameters(): if not torch.isfinite(param.grad).all(): print(name, torch.isfinite(param.grad).all()) print("Warning : Gradient is infinite. Do not update gradient.") return loss self.actor_optim.step() self.critic_optim.step() common_utils.soft_update(self.actor, self.actor_target, self.hyper_params.tau) return loss @staticmethod def q_retrace( reward: torch.Tensor, done: torch.Tensor, q_a: torch.Tensor, v: torch.Tensor, rho_bar: torch.Tensor, gamma: float, ): """Calculate Q retrace.""" q_ret = v[-1] q_ret_lst = [] for i in reversed(range(len(reward))): q_ret = reward[i] + gamma * q_ret * done[i] q_ret_lst.append(q_ret.item()) q_ret = rho_bar[i] * (q_ret - q_a[i]) + v[i] q_ret_lst.reverse() q_ret = torch.FloatTensor(q_ret_lst).unsqueeze(1) return q_ret def save_params(self, n_episode: int): params = { "actor_state_dict": self.actor.state_dict(), "actor_optim_state_dict": self.actor_optim.state_dict(), "critic_state_dict": self.critic.state_dict(), "critic_optim_state_dict": self.critic_optim.state_dict(), } Learner._save_params(self, params, n_episode) def load_params(self, path: str): Learner.load_params(self, path) params = torch.load(path) self.actor.load_state_dict(params["actor_state_dict"]) self.critic.load_state_dict(params["critic_state_dict"]) self.actor_optim.load_state_dict(params["actor_optim_state_dict"]) self.critic_optim.load_state_dict(params["critic_optim_state_dict"]) print("[INFO] Loaded the model and optimizer from", path) def get_state_dict(self) -> Tuple[OrderedDict]: """Return state dicts, mainly for distributed worker.""" return (self.model.state_dict(), self.optim.state_dict()) def get_policy(self) -> nn.Module: """Return model (policy) used for action selection.""" return self.actor
class SACLearner(Learner): """Learner for SAC Agent. Attributes: args (argparse.Namespace): arguments including hyperparameters and training settings hyper_params (ConfigDict): hyper-parameters log_cfg (ConfigDict): configuration for saving log and checkpoint update_step (int): step number of updates target_entropy (int): desired entropy used for the inequality constraint log_alpha (torch.Tensor): weight for entropy alpha_optim (Optimizer): optimizer for alpha actor (nn.Module): actor model to select actions actor_optim (Optimizer): optimizer for training actor critic_1 (nn.Module): critic model to predict state values critic_2 (nn.Module): critic model to predict state values critic_target1 (nn.Module): target critic model to predict state values critic_target2 (nn.Module): target critic model to predict state values critic_optim1 (Optimizer): optimizer for training critic_1 critic_optim2 (Optimizer): optimizer for training critic_2 """ def __init__( self, args: argparse.Namespace, env_info: ConfigDict, hyper_params: ConfigDict, log_cfg: ConfigDict, backbone: ConfigDict, head: ConfigDict, optim_cfg: ConfigDict, ): Learner.__init__(self, args, env_info, hyper_params, log_cfg) self.backbone_cfg = backbone self.head_cfg = head self.head_cfg.actor.configs.state_size = ( self.head_cfg.critic_vf.configs.state_size ) = self.env_info.observation_space.shape self.head_cfg.critic_qf.configs.state_size = ( self.env_info.observation_space.shape[0] + self.env_info.action_space.shape[0], ) self.head_cfg.actor.configs.output_size = self.env_info.action_space.shape[ 0] self.optim_cfg = optim_cfg self.update_step = 0 if self.hyper_params.auto_entropy_tuning: self.target_entropy = -np.prod( (self.env_info.action_space.shape[0], )).item() self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) self.alpha_optim = optim.Adam([self.log_alpha], lr=optim_cfg.lr_entropy) self._init_network() # pylint: disable=attribute-defined-outside-init def _init_network(self): """Initialize networks and optimizers.""" # create actor self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) # create v_critic self.vf = Brain(self.backbone_cfg.critic_vf, self.head_cfg.critic_vf).to(self.device) self.vf_target = Brain(self.backbone_cfg.critic_vf, self.head_cfg.critic_vf).to(self.device) self.vf_target.load_state_dict(self.vf.state_dict()) # create q_critic self.qf_1 = Brain(self.backbone_cfg.critic_qf, self.head_cfg.critic_qf).to(self.device) self.qf_2 = Brain(self.backbone_cfg.critic_qf, self.head_cfg.critic_qf).to(self.device) # create optimizers self.actor_optim = optim.Adam( self.actor.parameters(), lr=self.optim_cfg.lr_actor, weight_decay=self.optim_cfg.weight_decay, ) self.vf_optim = optim.Adam( self.vf.parameters(), lr=self.optim_cfg.lr_vf, weight_decay=self.optim_cfg.weight_decay, ) self.qf_1_optim = optim.Adam( self.qf_1.parameters(), lr=self.optim_cfg.lr_qf1, weight_decay=self.optim_cfg.weight_decay, ) self.qf_2_optim = optim.Adam( self.qf_2.parameters(), lr=self.optim_cfg.lr_qf2, weight_decay=self.optim_cfg.weight_decay, ) # load the optimizer and model parameters if self.args.load_from is not None: self.load_params(self.args.load_from) def update_model( self, experience: Union[TensorTuple, Tuple[TensorTuple]] ) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]: # type: ignore """Update actor and critic networks.""" self.update_step += 1 states, actions, rewards, next_states, dones = experience new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states) # train alpha 
if self.hyper_params.auto_entropy_tuning: alpha_loss = (-self.log_alpha * (log_prob + self.target_entropy).detach()).mean() self.alpha_optim.zero_grad() alpha_loss.backward() self.alpha_optim.step() alpha = self.log_alpha.exp() else: alpha_loss = torch.zeros(1) alpha = self.hyper_params.w_entropy # Q function loss masks = 1 - dones states_actions = torch.cat((states, actions), dim=-1) q_1_pred = self.qf_1(states_actions) q_2_pred = self.qf_2(states_actions) v_target = self.vf_target(next_states) q_target = rewards + self.hyper_params.gamma * v_target * masks qf_1_loss = F.mse_loss(q_1_pred, q_target.detach()) qf_2_loss = F.mse_loss(q_2_pred, q_target.detach()) # V function loss states_actions = torch.cat((states, new_actions), dim=-1) v_pred = self.vf(states) q_pred = torch.min(self.qf_1(states_actions), self.qf_2(states_actions)) v_target = q_pred - alpha * log_prob vf_loss = F.mse_loss(v_pred, v_target.detach()) if self.update_step % self.hyper_params.policy_update_freq == 0: # actor loss advantage = q_pred - v_pred.detach() actor_loss = (alpha * log_prob - advantage).mean() # regularization mean_reg = self.hyper_params.w_mean_reg * mu.pow(2).mean() std_reg = self.hyper_params.w_std_reg * std.pow(2).mean() pre_activation_reg = self.hyper_params.w_pre_activation_reg * ( pre_tanh_value.pow(2).sum(dim=-1).mean()) actor_reg = mean_reg + std_reg + pre_activation_reg # actor loss + regularization actor_loss += actor_reg # train actor self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() # update target networks common_utils.soft_update(self.vf, self.vf_target, self.hyper_params.tau) else: actor_loss = torch.zeros(1) # train Q functions self.qf_1_optim.zero_grad() qf_1_loss.backward() self.qf_1_optim.step() self.qf_2_optim.zero_grad() qf_2_loss.backward() self.qf_2_optim.step() # train V function self.vf_optim.zero_grad() vf_loss.backward() self.vf_optim.step() return ( actor_loss.item(), qf_1_loss.item(), qf_2_loss.item(), vf_loss.item(), alpha_loss.item(), ) def save_params(self, n_episode: int): """Save model and optimizer parameters.""" params = { "actor": self.actor.state_dict(), "qf_1": self.qf_1.state_dict(), "qf_2": self.qf_2.state_dict(), "vf": self.vf.state_dict(), "vf_target": self.vf_target.state_dict(), "actor_optim": self.actor_optim.state_dict(), "qf_1_optim": self.qf_1_optim.state_dict(), "qf_2_optim": self.qf_2_optim.state_dict(), "vf_optim": self.vf_optim.state_dict(), } if self.hyper_params.auto_entropy_tuning: params["alpha_optim"] = self.alpha_optim.state_dict() Learner._save_params(self, params, n_episode) def load_params(self, path: str): """Load model and optimizer parameters.""" Learner.load_params(self, path) params = torch.load(path) self.actor.load_state_dict(params["actor"]) self.qf_1.load_state_dict(params["qf_1"]) self.qf_2.load_state_dict(params["qf_2"]) self.vf.load_state_dict(params["vf"]) self.vf_target.load_state_dict(params["vf_target"]) self.actor_optim.load_state_dict(params["actor_optim"]) self.qf_1_optim.load_state_dict(params["qf_1_optim"]) self.qf_2_optim.load_state_dict(params["qf_2_optim"]) self.vf_optim.load_state_dict(params["vf_optim"]) if self.hyper_params.auto_entropy_tuning: self.alpha_optim.load_state_dict(params["alpha_optim"]) print("[INFO] loaded the model and optimizer from", path) def get_state_dict(self) -> Tuple[OrderedDict]: """Return state dicts, mainly for distributed worker.""" return (self.qf_1.state_dict(), self.qf_2.state_dict(), self.actor.state_dict()) def get_policy(self) -> nn.Module: """Return model 
(policy) used for action selection.""" return self.actor
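A standalone sketch of the automatic entropy-temperature update set up in __init__ and applied in update_model above: alpha is adjusted so the policy entropy stays near target_entropy = -|A|. Batch size, action dimension, and the pretend log-probabilities are illustrative.

import numpy as np
import torch
import torch.optim as optim

action_dim = 2
target_entropy = -np.prod((action_dim,)).item()        # -2.0
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = optim.Adam([log_alpha], lr=3e-4)

log_prob = torch.full((32, 1), -1.5)                   # pretend policy log-probs
alpha_loss = (-log_alpha * (log_prob + target_entropy).detach()).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.exp()                                # temperature used in the V/actor losses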
class TD3Learner(Learner): """Learner for DDPG Agent. Attributes: hyper_params (ConfigDict): hyper-parameters network_cfg (ConfigDict): config of network for training agent optim_cfg (ConfigDict): config of optimizer noise_cfg (ConfigDict): config of noise target_policy_noise (GaussianNoise): random noise for target values actor (nn.Module): actor model to select actions critic1 (nn.Module): critic model to predict state values critic2 (nn.Module): critic model to predict state values critic_target1 (nn.Module): target critic model to predict state values critic_target2 (nn.Module): target critic model to predict state values actor_target (nn.Module): target actor model to select actions critic_optim (Optimizer): optimizer for training critic actor_optim (Optimizer): optimizer for training actor """ def __init__( self, hyper_params: ConfigDict, log_cfg: ConfigDict, backbone: ConfigDict, head: ConfigDict, optim_cfg: ConfigDict, noise_cfg: ConfigDict, env_name: str, state_size: tuple, output_size: int, is_test: bool, load_from: str, ): Learner.__init__(self, hyper_params, log_cfg, env_name, is_test) self.backbone_cfg = backbone self.head_cfg = head self.head_cfg.critic.configs.state_size = (state_size[0] + output_size, ) self.head_cfg.actor.configs.state_size = state_size self.head_cfg.actor.configs.output_size = output_size self.optim_cfg = optim_cfg self.noise_cfg = noise_cfg self.load_from = load_from self.target_policy_noise = GaussianNoise( self.head_cfg.actor.configs.output_size, self.noise_cfg.target_policy_noise, self.noise_cfg.target_policy_noise, ) self.update_step = 0 self._init_network() def _init_network(self): """Initialize networks and optimizers.""" # create actor self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) self.actor_target.load_state_dict(self.actor.state_dict()) # create critic self.critic1 = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.critic2 = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.critic_target1 = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.critic_target2 = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.critic_target1.load_state_dict(self.critic1.state_dict()) self.critic_target2.load_state_dict(self.critic2.state_dict()) # concat critic parameters to use one optim critic_parameters = list(self.critic1.parameters()) + list( self.critic2.parameters()) # create optimizers self.actor_optim = optim.Adam( self.actor.parameters(), lr=self.optim_cfg.lr_actor, weight_decay=self.optim_cfg.weight_decay, ) self.critic_optim = optim.Adam( critic_parameters, lr=self.optim_cfg.lr_critic, weight_decay=self.optim_cfg.weight_decay, ) # load the optimizer and model parameters if self.load_from is not None: self.load_params(self.load_from) def update_model( self, experience: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: """Update TD3 actor and critic networks.""" self.update_step += 1 states, actions, rewards, next_states, dones = experience masks = 1 - dones # get actions with noise noise = common_utils.numpy2floattensor( self.target_policy_noise.sample(), self.device) clipped_noise = torch.clamp( noise, -self.noise_cfg.target_policy_noise_clip, self.noise_cfg.target_policy_noise_clip, ) next_actions = (self.actor_target(next_states) + clipped_noise).clamp( -1.0, 1.0) # min (Q_1', Q_2') next_states_actions = torch.cat((next_states, 
next_actions), dim=-1) next_values1 = self.critic_target1(next_states_actions) next_values2 = self.critic_target2(next_states_actions) next_values = torch.min(next_values1, next_values2) # G_t = r + gamma * v(s_{t+1}) if state != Terminal # = r otherwise curr_returns = rewards + self.hyper_params.gamma * next_values * masks curr_returns = curr_returns.detach() # critic loss state_actions = torch.cat((states, actions), dim=-1) values1 = self.critic1(state_actions) values2 = self.critic2(state_actions) critic1_loss = F.mse_loss(values1, curr_returns) critic2_loss = F.mse_loss(values2, curr_returns) # train critic critic_loss = critic1_loss + critic2_loss self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() if self.update_step % self.hyper_params.policy_update_freq == 0: # policy loss actions = self.actor(states) state_actions = torch.cat((states, actions), dim=-1) actor_loss = -self.critic1(state_actions).mean() # train actor self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() # update target networks tau = self.hyper_params.tau common_utils.soft_update(self.critic1, self.critic_target1, tau) common_utils.soft_update(self.critic2, self.critic_target2, tau) common_utils.soft_update(self.actor, self.actor_target, tau) else: actor_loss = torch.zeros(1) return actor_loss.item(), critic1_loss.item(), critic2_loss.item() def save_params(self, n_episode: int): """Save model and optimizer parameters.""" params = { "actor": self.actor.state_dict(), "actor_target": self.actor_target.state_dict(), "actor_optim": self.actor_optim.state_dict(), "critic1": self.critic1.state_dict(), "critic2": self.critic2.state_dict(), "critic_target1": self.critic_target1.state_dict(), "critic_target2": self.critic_target2.state_dict(), "critic_optim": self.critic_optim.state_dict(), } Learner._save_params(self, params, n_episode) def load_params(self, path: str): """Load model and optimizer parameters.""" Learner.load_params(self, path) params = torch.load(path) self.critic1.load_state_dict(params["critic1"]) self.critic2.load_state_dict(params["critic2"]) self.critic_target1.load_state_dict(params["critic_target1"]) self.critic_target2.load_state_dict(params["critic_target2"]) self.critic_optim.load_state_dict(params["critic_optim"]) self.actor.load_state_dict(params["actor"]) self.actor_target.load_state_dict(params["actor_target"]) self.actor_optim.load_state_dict(params["actor_optim"]) print("[INFO] loaded the model and optimizer from", path) def get_state_dict(self) -> Tuple[OrderedDict]: """Return state dicts, mainly for distributed worker.""" return ( self.critic_target1.state_dict(), self.critic_target2.state_dict(), self.actor.state_dict(), ) def get_policy(self) -> nn.Module: """Return model (policy) used for action selection.""" return self.actor
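GaussianNoise is assumed to return zero-mean Gaussian samples of action size with the configured sigma. This standalone sketch reproduces the target policy smoothing step from update_model with made-up sizes and clip values.

import numpy as np
import torch

action_dim, sigma, noise_clip = 2, 0.2, 0.5
noise = torch.from_numpy(np.random.normal(0.0, sigma, size=action_dim)).float()
clipped_noise = torch.clamp(noise, -noise_clip, noise_clip)

target_action = torch.tensor([0.9, -0.7])  # stand-in for actor_target(next_state)
smoothed = (target_action + clipped_noise).clamp(-1.0, 1.0)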
def __call__( self, model: Brain, target_model: Brain, experiences: Tuple[torch.Tensor, ...], gamma: float, head_cfg: ConfigDict, ) -> Tuple[torch.Tensor, torch.Tensor]: """Return element-wise IQN loss and Q-values. Reference: https://github.com/google/dopamine """ states, actions, rewards, next_states, dones = experiences[:5] batch_size = states.shape[0] # size of rewards: (n_tau_prime_samples x batch_size) x 1. rewards = rewards.repeat(head_cfg.configs.n_tau_prime_samples, 1) # size of gamma_with_terminal: (n_tau_prime_samples x batch_size) x 1. masks = 1 - dones gamma_with_terminal = masks * gamma gamma_with_terminal = gamma_with_terminal.repeat( head_cfg.configs.n_tau_prime_samples, 1) # Get the indices of the maximium Q-value across the action dimension. # Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1. next_actions = model(next_states).argmax(dim=1) # double Q next_actions = next_actions[:, None] next_actions = next_actions.repeat( head_cfg.configs.n_tau_prime_samples, 1) # Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1. target_quantile_values, _ = target_model.forward_( next_states, head_cfg.configs.n_tau_prime_samples) target_quantile_values = target_quantile_values.gather(1, next_actions) target_quantile_values = rewards + gamma_with_terminal * target_quantile_values target_quantile_values = target_quantile_values.detach() # Reshape to n_tau_prime_samples x batch_size x 1 since this is # the manner in which the target_quantile_values are tiled. target_quantile_values = target_quantile_values.view( head_cfg.configs.n_tau_prime_samples, batch_size, 1) # Transpose dimensions so that the dimensionality is batch_size x # n_tau_prime_samples x 1 to prepare for computation of Bellman errors. target_quantile_values = torch.transpose(target_quantile_values, 0, 1) # Get quantile values: (n_tau_samples x batch_size) x action_dim. quantile_values, quantiles = model.forward_( states, head_cfg.configs.n_tau_samples) reshaped_actions = actions[:, None].repeat(head_cfg.configs.n_tau_samples, 1) chosen_action_quantile_values = quantile_values.gather( 1, reshaped_actions.long()) chosen_action_quantile_values = chosen_action_quantile_values.view( head_cfg.configs.n_tau_samples, batch_size, 1) # Transpose dimensions so that the dimensionality is batch_size x # n_tau_prime_samples x 1 to prepare for computation of Bellman errors. chosen_action_quantile_values = torch.transpose( chosen_action_quantile_values, 0, 1) # Shape of bellman_erors and huber_loss: # batch_size x num_tau_prime_samples x num_tau_samples x 1. bellman_errors = (target_quantile_values[:, :, None, :] - chosen_action_quantile_values[:, None, :, :]) # The huber loss (introduced in QR-DQN) is defined via two cases: # case_one: |bellman_errors| <= kappa # case_two: |bellman_errors| > kappa huber_loss_case_one = ( (torch.abs(bellman_errors) <= head_cfg.configs.kappa).float() * 0.5 * bellman_errors**2) huber_loss_case_two = ( (torch.abs(bellman_errors) > head_cfg.configs.kappa).float() * head_cfg.configs.kappa * (torch.abs(bellman_errors) - 0.5 * head_cfg.configs.kappa)) huber_loss = huber_loss_case_one + huber_loss_case_two # Reshape quantiles to batch_size x num_tau_samples x 1 quantiles = quantiles.view(head_cfg.configs.n_tau_samples, batch_size, 1) quantiles = torch.transpose(quantiles, 0, 1) # Tile by num_tau_prime_samples along a new dimension. Shape is now # batch_size x num_tau_prime_samples x num_tau_samples x 1. 
# These quantiles will be used for computation of the quantile huber loss # below (see section 2.3 of the paper). quantiles = quantiles[:, None, :, :].repeat( 1, head_cfg.configs.n_tau_prime_samples, 1, 1) quantiles = quantiles.to(device) # Shape: batch_size x n_tau_prime_samples x n_tau_samples x 1. quantile_huber_loss = ( torch.abs(quantiles - (bellman_errors < 0).float().detach()) * huber_loss / head_cfg.configs.kappa) # Sum over current quantile value (n_tau_samples) dimension, # average over target quantile value (n_tau_prime_samples) dimension. # Shape: batch_size x n_tau_prime_samples x 1. loss = torch.sum(quantile_huber_loss, dim=2) # Shape: batch_size x 1. iqn_loss_element_wise = torch.mean(loss, dim=1) # q values for regularization. q_values = model(states) return iqn_loss_element_wise, q_values
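A toy-sized check (not from the library) of the quantile Huber weighting used above: each Bellman error u is weighted by |tau - 1{u < 0}|, so over- and under-estimates are penalized asymmetrically per quantile. The quantile values, errors, and kappa are illustrative.

import torch

kappa = 1.0
quantiles = torch.tensor([0.1, 0.5, 0.9]).view(1, 1, 3, 1)          # n_tau_samples = 3
bellman_errors = torch.tensor([-0.4, 0.0, 0.7]).view(1, 3, 1, 1)    # n_tau_prime_samples = 3

abs_err = torch.abs(bellman_errors)
huber = torch.where(
    abs_err <= kappa, 0.5 * bellman_errors ** 2, kappa * (abs_err - 0.5 * kappa)
)
quantile_huber = torch.abs(quantiles - (bellman_errors < 0).float()) * huber / kappa
loss = quantile_huber.sum(dim=2).mean(dim=1)  # same reduction order as above
print(loss.shape)  # torch.Size([1, 1]), i.e. one element-wise loss per batch entry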
class GAILPPOLearner(PPOLearner): """PPO-based GAILLearner for GAIL Agent. Attributes: hyper_params (ConfigDict): hyper-parameters log_cfg (ConfigDict): configuration for saving log and checkpoint actor (nn.Module): actor model to select actions critic (nn.Module): critic model to predict state values discriminator (nn.Module): discriminator model to classify data actor_optim (Optimizer): optimizer for training actor critic_optim (Optimizer): optimizer for training critic discriminator_optim (Optimizer): optimizer for training discriminator """ def __init__( self, hyper_params: ConfigDict, log_cfg: ConfigDict, backbone: ConfigDict, head: ConfigDict, optim_cfg: ConfigDict, env_name: str, state_size: tuple, output_size: int, is_test: bool, load_from: str, ): head.discriminator.configs.state_size = state_size head.discriminator.configs.action_size = output_size super().__init__( hyper_params, log_cfg, backbone, head, optim_cfg, env_name, state_size, output_size, is_test, load_from, ) self.demo_memory = None def _init_network(self): """Initialize networks and optimizers.""" # create actor if self.backbone_cfg.shared_actor_critic: shared_backbone = build_backbone( self.backbone_cfg.shared_actor_critic) self.actor = Brain( self.backbone_cfg.shared_actor_critic, self.head_cfg.actor, shared_backbone, ) self.critic = Brain( self.backbone_cfg.shared_actor_critic, self.head_cfg.critic, shared_backbone, ) self.actor = self.actor.to(self.device) self.critic = self.critic.to(self.device) else: self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device) self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(self.device) self.discriminator = Discriminator( self.backbone_cfg.discriminator, self.head_cfg.discriminator, self.head_cfg.aciton_embedder, ).to(self.device) # create optimizer self.actor_optim = optim.Adam( self.actor.parameters(), lr=self.optim_cfg.lr_actor, weight_decay=self.optim_cfg.weight_decay, ) self.critic_optim = optim.Adam( self.critic.parameters(), lr=self.optim_cfg.lr_critic, weight_decay=self.optim_cfg.weight_decay, ) self.discriminator_optim = optim.Adam( self.discriminator.parameters(), lr=self.optim_cfg.lr_discriminator, weight_decay=self.optim_cfg.weight_decay, ) # load model parameters if self.load_from is not None: self.load_params(self.load_from) def update_model(self, experience: TensorTuple, epsilon: float) -> TensorTuple: """Update generator(actor), critic and discriminator networks.""" states, actions, rewards, values, log_probs, next_state, masks = experience next_state = numpy2floattensor(next_state, self.device) with torch.no_grad(): next_value = self.critic(next_state) returns = ppo_utils.compute_gae( next_value, rewards, masks, values, self.hyper_params.gamma, self.hyper_params.tau, ) states = torch.cat(states) actions = torch.cat(actions) returns = torch.cat(returns).detach() values = torch.cat(values).detach() log_probs = torch.cat(log_probs).detach() advantages = (returns - values).detach() if self.hyper_params.standardize_advantage: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-7) actor_losses, critic_losses, total_losses, discriminator_losses = [], [], [], [] for ( state, action, old_value, old_log_prob, return_, adv, epoch, ) in ppo_utils.ppo_iter( self.hyper_params.epoch, self.hyper_params.batch_size, states, actions, values, log_probs, returns, advantages, ): # critic_loss value = self.critic(state) if self.hyper_params.use_clipped_value_loss: value_pred_clipped = old_value + torch.clamp( (value - 
old_value), -epsilon, epsilon) value_loss_clipped = (return_ - value_pred_clipped).pow(2) value_loss = (return_ - value).pow(2) critic_loss = 0.5 * torch.max(value_loss, value_loss_clipped).mean() else: critic_loss = 0.5 * (return_ - value).pow(2).mean() critic_loss_ = self.hyper_params.w_value * critic_loss # train critic self.critic_optim.zero_grad() critic_loss_.backward() clip_grad_norm_(self.critic.parameters(), self.hyper_params.gradient_clip_cr) self.critic_optim.step() # calculate ratios _, dist = self.actor(state) log_prob = dist.log_prob(action) ratio = (log_prob - old_log_prob).exp() # actor_loss surr_loss = ratio * adv clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean() # entropy entropy = dist.entropy().mean() actor_loss_ = actor_loss - self.hyper_params.w_entropy * entropy # train actor self.actor_optim.zero_grad() actor_loss_.backward() clip_grad_norm_(self.actor.parameters(), self.hyper_params.gradient_clip_ac) self.actor_optim.step() # total_loss total_loss = critic_loss_ + actor_loss_ # discriminator loss demo_state, demo_action = self.demo_memory.sample(len(state)) exp_score = torch.sigmoid( self.discriminator.forward((state, action))) demo_score = torch.sigmoid( self.discriminator.forward((demo_state, demo_action))) discriminator_exp_acc = (exp_score > 0.5).float().mean().item() discriminator_demo_acc = (demo_score <= 0.5).float().mean().item() discriminator_loss = F.binary_cross_entropy( exp_score, torch.ones_like(exp_score)) + F.binary_cross_entropy( demo_score, torch.zeros_like(demo_score)) # train discriminator if (discriminator_exp_acc < self.optim_cfg.discriminator_acc_threshold or discriminator_demo_acc < self.optim_cfg.discriminator_acc_threshold and epoch == 0): self.discriminator_optim.zero_grad() discriminator_loss.backward() self.discriminator_optim.step() actor_losses.append(actor_loss.item()) critic_losses.append(critic_loss.item()) total_losses.append(total_loss.item()) discriminator_losses.append(discriminator_loss.item()) actor_loss = sum(actor_losses) / len(actor_losses) critic_loss = sum(critic_losses) / len(critic_losses) total_loss = sum(total_losses) / len(total_losses) discriminator_loss = sum(discriminator_losses) / len( discriminator_losses) return ( (actor_loss, critic_loss, total_loss, discriminator_loss), (discriminator_exp_acc, discriminator_demo_acc), ) def save_params(self, n_episode: int): """Save model and optimizer parameters.""" params = { "actor_state_dict": self.actor.state_dict(), "critic_state_dict": self.critic.state_dict(), "discriminator_state_dict": self.discriminator.state_dict(), "actor_optim_state_dict": self.actor_optim.state_dict(), "critic_optim_state_dict": self.critic_optim.state_dict(), "discriminator_optim_state_dict": self.discriminator_optim.state_dict(), } PPOLearner._save_params(self, params, n_episode) def load_params(self, path: str): """Load model and optimizer parameters.""" PPOLearner.load_params(self, path) params = torch.load(path) self.actor.load_state_dict(params["actor_state_dict"]) self.critic.load_state_dict(params["critic_state_dict"]) self.actor_optim.load_state_dict(params["actor_optim_state_dict"]) self.critic_optim.load_state_dict(params["critic_optim_state_dict"]) print("[INFO] loaded the model and optimizer from", path) def get_state_dict(self) -> Tuple[OrderedDict]: """Return state dicts, mainly for distributed worker.""" return ( self.actor.state_dict(), self.critic.state_dict(), 
self.discriminator.state_dict(), ) def set_demo_memory(self, demo_memory): self.demo_memory = demo_memory
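ppo_utils.compute_gae is assumed to implement standard GAE(lambda); the library passes hyper_params.tau as the last argument, presumably the GAE lambda. A standalone sketch of that computation, with lists of per-step tensors as in update_model:

from typing import List
import torch

def compute_gae_sketch(
    next_value: torch.Tensor,
    rewards: List[torch.Tensor],
    masks: List[torch.Tensor],
    values: List[torch.Tensor],
    gamma: float = 0.99,
    lam: float = 0.95,
) -> List[torch.Tensor]:
    """Generalized Advantage Estimation returns, computed backwards over the rollout."""
    values = values + [next_value]
    gae, returns = 0.0, []
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * lam * masks[t] * gae
        returns.insert(0, gae + values[t])
    return returns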
    def __call__(
        self,
        model: Brain,
        target_model: Brain,
        experiences: Tuple[torch.Tensor, ...],
        gamma: float,
        head_cfg: ConfigDict,
        burn_in_step: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return R2D1 loss and Q-values."""
        # TODO: Combine with IQNLoss
        output_size = head_cfg.configs.output_size
        (
            burnin_states_tuple,
            states_tuple,
            burnin_prev_actions_tuple,
            agent_actions,
            prev_actions_tuple,
            burnin_prev_rewards_tuple,
            agent_rewards,
            prev_rewards_tuple,
            burnin_dones_tuple,
            agent_dones,
            init_rnn_state,
        ) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

        batch_size = states_tuple[0].shape[0]
        sequence_size = states_tuple[0].shape[1]

        with torch.no_grad():
            _, target_rnn_state = target_model(
                burnin_states_tuple[1],
                init_rnn_state,
                burnin_prev_actions_tuple[1],
                burnin_prev_rewards_tuple[1],
            )
            _, init_rnn_state = model(
                burnin_states_tuple[0],
                init_rnn_state,
                burnin_prev_actions_tuple[0],
                burnin_prev_rewards_tuple[0],
            )

        init_rnn_state = torch.transpose(init_rnn_state, 0, 1)
        target_rnn_state = torch.transpose(target_rnn_state, 0, 1)

        burnin_invalid_mask = valid_from_done(burnin_dones_tuple[0].transpose(0, 1))
        burnin_target_invalid_mask = valid_from_done(
            burnin_dones_tuple[1].transpose(0, 1)
        )
        init_rnn_state[burnin_invalid_mask] = 0
        target_rnn_state[burnin_target_invalid_mask] = 0

        # size of rewards: (n_tau_prime_samples x batch_size) x 1.
        agent_rewards = agent_rewards.repeat(head_cfg.configs.n_tau_prime_samples, 1, 1)

        # size of gamma_with_terminal: (n_tau_prime_samples x batch_size) x 1.
        masks = 1 - agent_dones
        gamma_with_terminal = masks * gamma
        gamma_with_terminal = gamma_with_terminal.repeat(
            head_cfg.configs.n_tau_prime_samples, 1, 1
        )

        # Get the indices of the maximum Q-value across the action dimension.
        # Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1.
        next_actions, _ = model(
            states_tuple[1],
            target_rnn_state,
            prev_actions_tuple[1],
            prev_rewards_tuple[1],
        )
        next_actions = next_actions.argmax(dim=-1)
        next_actions = next_actions[:, :, None]
        next_actions = next_actions.repeat(head_cfg.configs.n_tau_prime_samples, 1, 1)

        with torch.no_grad():
            # Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1.
            target_quantile_values, _, _ = target_model.forward_(
                states_tuple[1],
                target_rnn_state,
                prev_actions_tuple[1],
                prev_rewards_tuple[1],
                head_cfg.configs.n_tau_prime_samples,
            )
            target_quantile_values = target_quantile_values.gather(-1, next_actions)
            target_quantile_values = (
                agent_rewards + gamma_with_terminal * target_quantile_values
            )
            target_quantile_values = target_quantile_values.detach()

            # Reshape to n_tau_prime_samples x batch_size x 1 since this is
            # the manner in which the target_quantile_values are tiled.
            target_quantile_values = target_quantile_values.view(
                head_cfg.configs.n_tau_prime_samples, batch_size, sequence_size, 1
            )

            # Transpose dimensions so that the dimensionality is batch_size x
            # n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
            target_quantile_values = torch.transpose(target_quantile_values, 0, 1)

        # Get quantile values: (n_tau_samples x batch_size) x action_dim.
        quantile_values, quantiles, _ = model.forward_(
            states_tuple[0],
            init_rnn_state,
            prev_actions_tuple[0],
            prev_rewards_tuple[0],
            head_cfg.configs.n_tau_samples,
        )
        reshaped_actions = agent_actions.repeat(head_cfg.configs.n_tau_samples, 1, 1)
        chosen_action_quantile_values = quantile_values.gather(
            -1, reshaped_actions.long()
        )
        chosen_action_quantile_values = chosen_action_quantile_values.view(
            head_cfg.configs.n_tau_samples, batch_size, sequence_size, 1
        )

        # Transpose dimensions so that the dimensionality is batch_size x
        # n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
        chosen_action_quantile_values = torch.transpose(
            chosen_action_quantile_values, 0, 1
        )

        # Shape of bellman_errors and huber_loss:
        # batch_size x num_tau_prime_samples x num_tau_samples x 1.
        bellman_errors = (
            target_quantile_values[:, :, None, :]
            - chosen_action_quantile_values[:, None, :, :]
        )

        # The huber loss (introduced in QR-DQN) is defined via two cases:
        # case_one: |bellman_errors| <= kappa
        # case_two: |bellman_errors| > kappa
        huber_loss_case_one = (
            (torch.abs(bellman_errors) <= head_cfg.configs.kappa).float()
            * 0.5
            * bellman_errors ** 2
        )
        huber_loss_case_two = (
            (torch.abs(bellman_errors) > head_cfg.configs.kappa).float()
            * head_cfg.configs.kappa
            * (torch.abs(bellman_errors) - 0.5 * head_cfg.configs.kappa)
        )
        huber_loss = huber_loss_case_one + huber_loss_case_two

        # Reshape quantiles to batch_size x num_tau_samples x 1
        quantiles = quantiles.view(
            head_cfg.configs.n_tau_samples, batch_size, sequence_size, 1
        )
        quantiles = torch.transpose(quantiles, 0, 1)

        # Tile by num_tau_prime_samples along a new dimension. Shape is now
        # batch_size x num_tau_prime_samples x num_tau_samples x sequence_length x 1.
        # These quantiles will be used for computation of the quantile huber loss
        # below (see section 2.3 of the paper).
        quantiles = quantiles[:, None, :, :, :].repeat(
            1, head_cfg.configs.n_tau_prime_samples, 1, 1, 1
        )

        # Shape: batch_size x n_tau_prime_samples x n_tau_samples x sequence_length x 1.
        quantile_huber_loss = (
            torch.abs(quantiles - (bellman_errors < 0).float().detach())
            * huber_loss
            / head_cfg.configs.kappa
        )

        # Sum over current quantile value (n_tau_samples) dimension,
        # average over target quantile value (n_tau_prime_samples) dimension.
        # Shape: batch_size x n_tau_prime_samples x 1.
        loss = torch.sum(quantile_huber_loss, dim=2)
        # Shape: batch_size x sequence_length x 1.
        iqn_loss_element_wise = torch.mean(loss, dim=1)
        # Shape: batch_size x 1.
        iqn_loss_element_wise = abs(torch.mean(iqn_loss_element_wise, dim=1))

        # q values for regularization.
        q_values, _ = model(
            states_tuple[0],
            init_rnn_state,
            prev_actions_tuple[0],
            prev_rewards_tuple[0],
        )

        return iqn_loss_element_wise, q_values
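For reference, the quantile Huber term computed above reduces, for a single pair of target sample and quantile sample, to abs(tau - 1{delta < 0}) * huber_kappa(delta) / kappa. The toy sketch below reproduces that computation on made-up tensors; the shapes (batch of 2, three target samples, two quantile fractions) and `kappa` value are illustrative assumptions only.

import torch

kappa = 1.0
target_q = torch.randn(2, 3, 1, 1)  # batch x n_tau_prime x 1 x 1
chosen_q = torch.randn(2, 1, 2, 1)  # batch x 1 x n_tau x 1
taus = torch.rand(2, 1, 2, 1)       # sampled quantile fractions for chosen_q

bellman_errors = target_q - chosen_q  # broadcasts to batch x n_tau_prime x n_tau x 1
abs_err = bellman_errors.abs()
huber = torch.where(
    abs_err <= kappa, 0.5 * bellman_errors ** 2, kappa * (abs_err - 0.5 * kappa)
)
# abs(tau - 1{delta < 0}) weights under- and over-estimation asymmetrically.
quantile_huber = (taus - (bellman_errors < 0).float()).abs() * huber / kappa

# Sum over the n_tau dimension, average over the n_tau_prime dimension.
loss = quantile_huber.sum(dim=2).mean(dim=1)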
class A2CLearner(Learner):
    """Learner for A2C Agent.

    Attributes:
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        actor (nn.Module): actor model to select actions
        critic (nn.Module): critic model to predict state values
        actor_optim (Optimizer): optimizer for training actor
        critic_optim (Optimizer): optimizer for training critic
    """

    def __init__(
        self,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        env_name: str,
        state_size: tuple,
        output_size: int,
        is_test: bool,
        load_from: str,
    ):
        Learner.__init__(self, hyper_params, log_cfg, env_name, is_test)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.actor.configs.state_size = (
            self.head_cfg.critic.configs.state_size
        ) = state_size
        self.head_cfg.actor.configs.output_size = output_size
        self.optim_cfg = optim_cfg
        self.load_from = load_from

        self._init_network()

    def _init_network(self):
        """Initialize networks and optimizers."""
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        if self.load_from is not None:
            self.load_params(self.load_from)

    def update_model(self, experience: TensorTuple) -> TensorTuple:
        """Update A2C actor and critic networks."""
        log_prob, pred_value, next_state, reward, done = experience
        next_state = numpy2floattensor(next_state, self.device)

        # Q_t   = r + gamma * V(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        mask = 1 - done
        next_value = self.critic(next_state).detach()
        q_value = reward + self.hyper_params.gamma * next_value * mask
        q_value = q_value.to(self.device)

        # advantage = Q_t - V(s_t)
        advantage = q_value - pred_value

        # calculate loss at the current step
        policy_loss = -advantage.detach() * log_prob  # adv. is not backpropagated
        policy_loss += self.hyper_params.w_entropy * -log_prob  # entropy
        value_loss = F.smooth_l1_loss(pred_value, q_value.detach())

        # train
        gradient_clip_ac = self.hyper_params.gradient_clip_ac
        gradient_clip_cr = self.hyper_params.gradient_clip_cr

        self.actor_optim.zero_grad()
        policy_loss.backward()
        clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optim.step()

        self.critic_optim.zero_grad()
        value_loss.backward()
        clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
        self.critic_optim.step()

        return policy_loss.item(), value_loss.item()

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
        }
        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        print("[INFO] Loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.critic.state_dict(), self.actor.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
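As a quick sanity check of the bootstrapped target used in `update_model` above: with gamma = 0.99, reward = 1.0, V(s_{t+1}) = 2.0 and a non-terminal transition, Q_t = 1.0 + 0.99 * 2.0 = 2.98, so the advantage is 2.98 - V(s_t). A tiny numeric sketch (all values are made up):

import torch

gamma = 0.99
reward = torch.tensor([1.0])
done = torch.tensor([0.0])        # non-terminal transition
pred_value = torch.tensor([2.5])  # V(s_t)
next_value = torch.tensor([2.0])  # V(s_{t+1}), detached in the learner

mask = 1 - done
q_value = reward + gamma * next_value * mask  # -> 2.98
advantage = q_value - pred_value              # -> 0.48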
class PPOLearner(Learner):
    """Learner for PPO Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        actor (nn.Module): actor model to select actions
        critic (nn.Module): critic model to predict state values
        actor_optim (Optimizer): optimizer for training actor
        critic_optim (Optimizer): optimizer for training critic
    """

    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
    ):
        Learner.__init__(self, args, env_info, hyper_params, log_cfg)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.actor.configs.state_size = (
            self.head_cfg.critic.configs.state_size
        ) = self.env_info.observation_space.shape
        self.head_cfg.actor.configs.output_size = self.env_info.action_space.shape[0]
        self.optim_cfg = optim_cfg
        self.is_discrete = self.hyper_params.is_discrete

        self._init_network()

    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)

    def update_model(self, experience: TensorTuple, epsilon: float) -> TensorTuple:
        """Update PPO actor and critic networks."""
        states, actions, rewards, values, log_probs, next_state, masks = experience
        next_state = numpy2floattensor(next_state, self.device)
        next_value = self.critic(next_state)

        returns = ppo_utils.compute_gae(
            next_value,
            rewards,
            masks,
            values,
            self.hyper_params.gamma,
            self.hyper_params.tau,
        )

        states = torch.cat(states)
        actions = torch.cat(actions)
        returns = torch.cat(returns).detach()
        values = torch.cat(values).detach()
        log_probs = torch.cat(log_probs).detach()
        advantages = returns - values

        if self.is_discrete:
            actions = actions.unsqueeze(1)
            log_probs = log_probs.unsqueeze(1)

        if self.hyper_params.standardize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-7)

        actor_losses, critic_losses, total_losses = [], [], []

        for state, action, old_value, old_log_prob, return_, adv in ppo_utils.ppo_iter(
            self.hyper_params.epoch,
            self.hyper_params.batch_size,
            states,
            actions,
            values,
            log_probs,
            returns,
            advantages,
        ):
            # calculate ratios
            _, dist = self.actor(state)
            log_prob = dist.log_prob(action)
            ratio = (log_prob - old_log_prob).exp()

            # actor_loss
            surr_loss = ratio * adv
            clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv
            actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean()

            # critic_loss
            value = self.critic(state)
            if self.hyper_params.use_clipped_value_loss:
                value_pred_clipped = old_value + torch.clamp(
                    (value - old_value), -epsilon, epsilon
                )
                value_loss_clipped = (return_ - value_pred_clipped).pow(2)
                value_loss = (return_ - value).pow(2)
                critic_loss = 0.5 * torch.max(value_loss, value_loss_clipped).mean()
            else:
                critic_loss = 0.5 * (return_ - value).pow(2).mean()

            # entropy
            entropy = dist.entropy().mean()

            # total_loss
            w_value = self.hyper_params.w_value
            w_entropy = self.hyper_params.w_entropy

            critic_loss_ = w_value * critic_loss
            actor_loss_ = actor_loss - w_entropy * entropy
            total_loss = critic_loss_ + actor_loss_

            # train critic
            gradient_clip_ac = self.hyper_params.gradient_clip_ac
            gradient_clip_cr = self.hyper_params.gradient_clip_cr

            self.critic_optim.zero_grad()
            critic_loss_.backward(retain_graph=True)
            clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
            self.critic_optim.step()

            # train actor
            self.actor_optim.zero_grad()
            actor_loss_.backward()
            clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
            self.actor_optim.step()

            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())
            total_losses.append(total_loss.item())

        actor_loss = sum(actor_losses) / len(actor_losses)
        critic_loss = sum(critic_losses) / len(critic_losses)
        total_loss = sum(total_losses) / len(total_losses)

        return actor_loss, critic_loss, total_loss

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
        }
        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.actor.state_dict(), self.critic.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
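`ppo_utils.compute_gae` is called above but not shown in this section. A common implementation of generalized advantage estimation that matches the call signature used here (next_value, rewards, masks, values, gamma, tau) is sketched below; treat it as an assumption about the helper, not its actual source.

def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    """Return per-step returns (GAE advantage + value) computed backwards in time."""
    values = values + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        # One-step TD error, masked at episode boundaries.
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        # Exponentially weighted sum of TD errors (GAE).
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns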
class DQNLearner(Learner):
    """Learner for DQN Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        dqn (nn.Module): dqn model to predict state Q values
        dqn_target (nn.Module): target dqn model to predict state Q values
        dqn_optim (Optimizer): optimizer for training dqn
    """

    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        loss_type: ConfigDict,
    ):
        Learner.__init__(self, args, env_info, hyper_params, log_cfg)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.configs.state_size = self.env_info.observation_space.shape
        self.head_cfg.configs.output_size = self.env_info.action_space.n
        self.optim_cfg = optim_cfg
        self.use_n_step = self.hyper_params.n_step > 1
        self.loss_type = loss_type

        self._init_network()

    # pylint: disable=attribute-defined-outside-init
    def _init_network(self):
        """Initialize networks and optimizers."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn_target = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.loss_fn = build_loss(self.loss_type)
        self.dqn_target.load_state_dict(self.dqn.state_dict())

        # create optimizer
        self.dqn_optim = optim.Adam(
            self.dqn.parameters(),
            lr=self.optim_cfg.lr_dqn,
            weight_decay=self.optim_cfg.weight_decay,
            eps=self.optim_cfg.adam_eps,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)

    def update_model(
        self, experience: Union[TensorTuple, Tuple[TensorTuple]]
    ) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]:  # type: ignore
        """Update dqn and dqn target."""
        if self.use_n_step:
            experience_1, experience_n = experience
        else:
            experience_1 = experience

        weights, indices = experience_1[-3:-1]
        gamma = self.hyper_params.gamma

        dq_loss_element_wise, q_values = self.loss_fn(
            self.dqn, self.dqn_target, experience_1, gamma, self.head_cfg
        )
        dq_loss = torch.mean(dq_loss_element_wise * weights)

        # n step loss
        if self.use_n_step:
            gamma = self.hyper_params.gamma ** self.hyper_params.n_step
            dq_loss_n_element_wise, q_values_n = self.loss_fn(
                self.dqn, self.dqn_target, experience_n, gamma, self.head_cfg
            )

            # to update loss and priorities
            q_values = 0.5 * (q_values + q_values_n)
            dq_loss_element_wise += dq_loss_n_element_wise * self.hyper_params.w_n_step
            dq_loss = torch.mean(dq_loss_element_wise * weights)

        # q_value regularization
        q_regular = torch.norm(q_values, 2).mean() * self.hyper_params.w_q_reg

        # total loss
        loss = dq_loss + q_regular

        self.dqn_optim.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), self.hyper_params.gradient_clip)
        self.dqn_optim.step()

        # update target networks
        common_utils.soft_update(self.dqn, self.dqn_target, self.hyper_params.tau)

        # update priorities in PER
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.hyper_params.per_eps

        if self.head_cfg.configs.use_noisy_net:
            self.dqn.head.reset_noise()
            self.dqn_target.head.reset_noise()

        return (
            loss.item(),
            q_values.mean().item(),
            indices,
            new_priorities,
        )

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "dqn_state_dict": self.dqn.state_dict(),
            "dqn_target_state_dict": self.dqn_target.state_dict(),
            "dqn_optim_state_dict": self.dqn_optim.state_dict(),
        }
        Learner._save_params(self, params, n_episode)

    # pylint: disable=attribute-defined-outside-init
    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.dqn.load_state_dict(params["dqn_state_dict"])
        self.dqn_target.load_state_dict(params["dqn_target_state_dict"])
        self.dqn_optim.load_state_dict(params["dqn_optim_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> OrderedDict:
        """Return state dicts, mainly for distributed worker."""
        dqn = deepcopy(self.dqn)
        return dqn.cpu().state_dict()

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection, used only in grad cam."""
        return self.dqn
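`common_utils.soft_update` is referenced above but not shown. The usual Polyak-averaging form consistent with how it is called (online net, target net, tau) is sketched below as an assumption about that helper:

def soft_update(local_model, target_model, tau: float):
    """Blend target parameters toward the online parameters:
    theta_target <- tau * theta_local + (1 - tau) * theta_target."""
    for target_param, local_param in zip(
        target_model.parameters(), local_model.parameters()
    ):
        target_param.data.copy_(
            tau * local_param.data + (1.0 - tau) * target_param.data
        )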
    def _init_networks(self, state_dict: OrderedDict):
        """Initialize DQN policy with learner state dict."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn.load_state_dict(state_dict)
        self.dqn.eval()
    def __call__(
        self,
        model: Brain,
        target_model: Brain,
        experiences: Tuple[torch.Tensor, ...],
        gamma: float,
        head_cfg: ConfigDict,
        burn_in_step: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return element-wise C51 loss and Q-values."""
        # TODO: Combine with IQNLoss
        output_size = head_cfg.configs.output_size
        (
            burnin_states_tuple,
            states_tuple,
            burnin_prev_actions_tuple,
            agent_actions,
            prev_actions_tuple,
            burnin_prev_rewards_tuple,
            agent_rewards,
            prev_rewards_tuple,
            burnin_dones_tuple,
            agent_dones,
            init_rnn_state,
        ) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

        batch_size = states_tuple[0].shape[0]
        sequence_size = states_tuple[0].shape[1]

        with torch.no_grad():
            _, target_rnn_state = target_model(
                burnin_states_tuple[1],
                init_rnn_state,
                burnin_prev_actions_tuple[1],
                burnin_prev_rewards_tuple[1],
            )
            _, init_rnn_state = model(
                burnin_states_tuple[0],
                init_rnn_state,
                burnin_prev_actions_tuple[0],
                burnin_prev_rewards_tuple[0],
            )

        init_rnn_state = torch.transpose(init_rnn_state, 0, 1)
        target_rnn_state = torch.transpose(target_rnn_state, 0, 1)

        burnin_invalid_mask = valid_from_done(burnin_dones_tuple[0].transpose(0, 1))
        burnin_target_invalid_mask = valid_from_done(
            burnin_dones_tuple[1].transpose(0, 1)
        )
        init_rnn_state[burnin_invalid_mask] = 0
        target_rnn_state[burnin_target_invalid_mask] = 0

        support = torch.linspace(
            head_cfg.configs.v_min, head_cfg.configs.v_max, head_cfg.configs.atom_size
        ).to(device)
        delta_z = float(head_cfg.configs.v_max - head_cfg.configs.v_min) / (
            head_cfg.configs.atom_size - 1
        )

        with torch.no_grad():
            # According to noisynet paper,
            # it resamples noisynet parameters on online network when using double q
            # but we don't because there is no remarkable difference in performance.
            next_actions, _ = model.forward_(
                states_tuple[1],
                target_rnn_state,
                prev_actions_tuple[1],
                prev_rewards_tuple[1],
            )
            next_actions = next_actions[1].argmax(-1)

            next_dist, _ = target_model.forward_(
                states_tuple[1],
                target_rnn_state,
                prev_actions_tuple[1],
                prev_rewards_tuple[1],
            )
            next_dist = next_dist[0][range(batch_size * sequence_size), next_actions]

            t_z = agent_rewards + (1 - agent_dones) * gamma * support
            t_z = t_z.clamp(min=head_cfg.configs.v_min, max=head_cfg.configs.v_max)
            b = (t_z - head_cfg.configs.v_min) / delta_z
            b = b.view(batch_size * sequence_size, -1)
            l = b.floor().long()  # noqa: E741
            u = b.ceil().long()

            offset = (
                torch.linspace(
                    0,
                    (batch_size * sequence_size - 1) * head_cfg.configs.atom_size,
                    batch_size * sequence_size,
                )
                .long()
                .unsqueeze(1)
            )
            offset = offset.expand(
                batch_size * sequence_size, head_cfg.configs.atom_size
            ).to(device)

            proj_dist = torch.zeros(next_dist.size(), device=device)
            proj_dist.view(-1).index_add_(
                0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
            )
            proj_dist.view(-1).index_add_(
                0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
            )

        (dist, q_values), _ = model.forward_(
            states_tuple[0],
            init_rnn_state,
            prev_actions_tuple[0],
            prev_rewards_tuple[0],
        )
        log_p = dist[
            range(batch_size * sequence_size),
            agent_actions.contiguous().view(batch_size * sequence_size).long(),
        ]
        log_p = torch.log(log_p.clamp(min=1e-5))
        log_p = log_p.view(batch_size, sequence_size, -1)
        proj_dist = proj_dist.view(batch_size, sequence_size, -1)

        dq_loss_element_wise = -(proj_dist * log_p).sum(-1).mean(1)

        return dq_loss_element_wise, q_values
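The `l`/`u`/`offset`/`index_add_` block above implements the categorical (C51) projection of the shifted target distribution back onto the fixed support. The minimal standalone sketch below shows the same floor/ceil mass splitting on a single toy distribution; the support, reward, and probabilities are made-up values chosen so that no shifted atom lands exactly on a support point.

import torch

v_min, v_max, atom_size = 0.0, 2.0, 3       # toy support {0.0, 1.0, 2.0}
support = torch.linspace(v_min, v_max, atom_size)
delta_z = (v_max - v_min) / (atom_size - 1)

next_dist = torch.tensor([[0.2, 0.5, 0.3]])  # target distribution, batch of 1
reward, done, gamma = 0.25, 0.0, 0.5

t_z = (reward + (1 - done) * gamma * support).clamp(v_min, v_max)  # shifted atoms
b = (t_z - v_min) / delta_z                  # fractional index into the support
l, u = b.floor().long(), b.ceil().long()

proj_dist = torch.zeros_like(next_dist)
# Each shifted atom's probability mass is split between its lower and upper neighbours.
proj_dist.view(-1).index_add_(0, l, (next_dist * (u.float() - b)).view(-1))
proj_dist.view(-1).index_add_(0, u, (next_dist * (b - l.float())).view(-1))
# proj_dist -> [[0.275, 0.65, 0.075]], which still sums to 1.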