class Agent(SACAgent):
    """SAC agent interacting with environment.

    Attributes:
        memory (ReplayBuffer): replay memory
        demo_memory (ReplayBuffer): replay memory for demo
        her (HER): hindsight experience replay
        transitions_epi (list): transitions per episode (for HER)
        goal_state (np.ndarray): goal state to generate concatenated states
        total_step (int): total step numbers
        episode_step (int): step number of the current episode

    """

    # pylint: disable=attribute-defined-outside-init
    def _initialize(self):
        """Initialize non-common things."""
        # load demo replay memory
        with open(self.args.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params["USE_HER"]:
            self.her = HER(self.args.demo_path)
            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1,))
            self.hook_transition = True
            demo = self.her.generate_demo_transitions(demo)

        if not self.args.test:
            # Replay buffers
            self.demo_memory = ReplayBuffer(
                len(demo), self.hyper_params["DEMO_BATCH_SIZE"]
            )
            self.demo_memory.extend(demo)

            self.memory = ReplayBuffer(
                self.hyper_params["BUFFER_SIZE"], self.hyper_params["BATCH_SIZE"]
            )

            # set hyper parameters
            self.lambda1 = self.hyper_params["LAMBDA1"]
            self.lambda2 = (
                self.hyper_params["LAMBDA2"] / self.hyper_params["DEMO_BATCH_SIZE"]
            )

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input space."""
        state_ = state

        # HER
        if self.hyper_params["USE_HER"]:
            self.desired_state = self.her.sample_desired_state()
            state = np.concatenate((state, self.desired_state), axis=-1)

        selected_action = SACAgent.select_action(self, state)
        self.curr_state = state_

        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return the response of the env."""
        next_state, reward, done = SACAgent.step(self, action)

        if not self.args.test and self.hyper_params["USE_HER"]:
            self.transitions_epi.append(self.hooked_transition)
            if done:
                # insert generated transitions if the episode is done
                transitions = self.her.generate_transitions(
                    self.transitions_epi, self.desired_state
                )
                self.memory.extend(transitions)

        return next_state, reward, done

    def update_model(self) -> Tuple[torch.Tensor, ...]:
        """Train the model after each episode."""
        experiences = self.memory.sample()
        demos = self.demo_memory.sample()

        states, actions, rewards, next_states, dones = experiences
        demo_states, demo_actions, _, _, _ = demos
        new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)
        pred_actions, _, _, _, _ = self.actor(demo_states)

        # train alpha
        if self.hyper_params["AUTO_ENTROPY_TUNING"]:
            alpha_loss = (
                -self.log_alpha * (log_prob + self.target_entropy).detach()
            ).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.zeros(1)
            alpha = self.hyper_params["W_ENTROPY"]

        # Q function loss
        masks = 1 - dones
        q_1_pred = self.qf_1(states, actions)
        q_2_pred = self.qf_2(states, actions)
        v_target = self.vf_target(next_states)
        q_target = rewards + self.hyper_params["GAMMA"] * v_target * masks
        qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())
        qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())

        # V function loss
        v_pred = self.vf(states)
        q_pred = torch.min(
            self.qf_1(states, new_actions), self.qf_2(states, new_actions)
        )
        v_target = q_pred - alpha * log_prob
        vf_loss = F.mse_loss(v_pred, v_target.detach())

        # train Q functions
        self.qf_1_optimizer.zero_grad()
        qf_1_loss.backward()
        self.qf_1_optimizer.step()

        self.qf_2_optimizer.zero_grad()
        qf_2_loss.backward()
        self.qf_2_optimizer.step()

        # train V function
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
            # bc loss
            qf_mask = torch.gt(
                self.qf_1(demo_states, demo_actions),
                self.qf_1(demo_states, pred_actions),
            ).to(device)
            qf_mask = qf_mask.float()
            n_qf_mask = int(qf_mask.sum().item())

            if n_qf_mask == 0:
                bc_loss = torch.zeros(1, device=device)
            else:
                bc_loss = (
                    torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
                ).pow(2).sum() / n_qf_mask

            # actor loss
            advantage = q_pred - v_pred.detach()
            actor_loss = (alpha * log_prob - advantage).mean()
            actor_loss = self.lambda1 * actor_loss + self.lambda2 * bc_loss

            # regularization
            if not self.is_discrete:  # if the action is continuous
                mean_reg = self.hyper_params["W_MEAN_REG"] * mu.pow(2).mean()
                std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
                pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                    pre_tanh_value.pow(2).sum(dim=-1).mean()
                )
                actor_reg = mean_reg + std_reg + pre_activation_reg

                # actor loss + regularization
                actor_loss += actor_reg

            # train actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # update target networks
            common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])
        else:
            actor_loss = torch.zeros(1)

        return (
            actor_loss.data,
            qf_1_loss.data,
            qf_2_loss.data,
            vf_loss.data,
            alpha_loss.data,
        )
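
# --------------------------------------------------------------------------
# Hedged sketch (not part of the original module): a minimal, self-contained
# illustration of the Q-filtered behavior-cloning loss used in update_model()
# above. Only `torch` is assumed; tensor values and the helper name are
# hypothetical. A demo sample contributes to the BC loss only when the critic
# rates the demonstrator's action higher than the current policy's action, so
# imitation never pulls against a policy that already beats the demo.
import torch


def _sketch_q_filtered_bc_loss() -> torch.Tensor:
    """Mask demo samples with Q(s, a_demo) > Q(s, pi(s)), then average the masked MSE."""
    q_demo = torch.tensor([[1.0], [0.2], [0.7]])        # Q(s, a_demo)
    q_policy = torch.tensor([[0.5], [0.9], [0.1]])      # Q(s, pi(s))
    demo_actions = torch.tensor([[0.3], [0.8], [-0.2]])
    pred_actions = torch.tensor([[0.1], [0.4], [0.5]])

    qf_mask = torch.gt(q_demo, q_policy).float()        # -> [[1.], [0.], [1.]]
    n_qf_mask = int(qf_mask.sum().item())
    if n_qf_mask == 0:
        return torch.zeros(1)
    return (
        torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
    ).pow(2).sum() / n_qf_mask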
class Agent(DDPGAgent):
    """ActorCritic interacting with environment.

    Attributes:
        memory (ReplayBuffer): replay memory
        demo_memory (ReplayBuffer): replay memory for demo
        her (HER): hindsight experience replay
        transitions_epi (list): transitions per episode (for HER)
        goal_state (np.ndarray): goal state to generate concatenated states
        total_step (int): total step numbers
        episode_step (int): step number of the current episode

    """

    # pylint: disable=attribute-defined-outside-init
    def _initialize(self):
        """Initialize non-common things."""
        # load demo replay memory
        with open(self.args.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params["USE_HER"]:
            self.her = HER(self.args.demo_path)
            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1,))
            self.hook_transition = True
            demo = self.her.generate_demo_transitions(demo)

        if not self.args.test:
            # Replay buffers
            demo_batch_size = self.hyper_params["DEMO_BATCH_SIZE"]
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)
            self.memory = ReplayBuffer(
                self.hyper_params["BUFFER_SIZE"], self.hyper_params["BATCH_SIZE"]
            )

            # set hyper parameters
            self.lambda1 = self.hyper_params["LAMBDA1"]
            self.lambda2 = self.hyper_params["LAMBDA2"] / demo_batch_size

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input space."""
        state_ = state

        # HER
        if self.hyper_params["USE_HER"]:
            self.desired_state = self.her.sample_desired_state()
            state = np.concatenate((state, self.desired_state), axis=-1)

        selected_action = DDPGAgent.select_action(self, state)
        self.curr_state = state_

        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return the response of the env."""
        next_state, reward, done = DDPGAgent.step(self, action)

        if not self.args.test and self.hyper_params["USE_HER"]:
            self.transitions_epi.append(self.hooked_transition)
            if done:
                # insert generated transitions if the episode is done
                transitions = self.her.generate_transitions(
                    self.transitions_epi, self.desired_state
                )
                self.memory.extend(transitions)

        return next_state, reward, done

    def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Train the model after each episode."""
        experiences = self.memory.sample()
        demos = self.demo_memory.sample()

        exp_states, exp_actions, exp_rewards, exp_next_states, exp_dones = experiences
        demo_states, demo_actions, demo_rewards, demo_next_states, demo_dones = demos

        states = torch.cat((exp_states, demo_states), dim=0)
        actions = torch.cat((exp_actions, demo_actions), dim=0)
        rewards = torch.cat((exp_rewards, demo_rewards), dim=0)
        next_states = torch.cat((exp_next_states, demo_next_states), dim=0)
        dones = torch.cat((exp_dones, demo_dones), dim=0)

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        masks = 1 - dones
        next_actions = self.actor_target(next_states)
        next_values = self.critic_target(
            torch.cat((next_states, next_actions), dim=-1)
        )
        curr_returns = rewards + (self.hyper_params["GAMMA"] * next_values * masks)
        curr_returns = curr_returns.to(device)

        # critic loss
        values = self.critic(torch.cat((states, actions), dim=-1))
        critic_loss = F.mse_loss(values, curr_returns)

        # train critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # policy loss
        actions = self.actor(states)
        policy_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()

        # bc loss
        pred_actions = self.actor(demo_states)
        qf_mask = torch.gt(
            self.critic(torch.cat((demo_states, demo_actions), dim=-1)),
            self.critic(torch.cat((demo_states, pred_actions), dim=-1)),
        ).to(device)
        qf_mask = qf_mask.float()
        n_qf_mask = int(qf_mask.sum().item())

        if n_qf_mask == 0:
            bc_loss = torch.zeros(1, device=device)
        else:
            bc_loss = (
                torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
            ).pow(2).sum() / n_qf_mask

        # train actor: pg loss + BC loss
        actor_loss = self.lambda1 * policy_loss + self.lambda2 * bc_loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.actor, self.actor_target, tau)
        common_utils.soft_update(self.critic, self.critic_target, tau)

        return actor_loss.data, critic_loss.data
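
# --------------------------------------------------------------------------
# Hedged sketch (not from the original module): `common_utils.soft_update` is
# called above but not defined in this excerpt. A soft (Polyak) target update
# conventionally has the form below; the exact name and signature in
# `common_utils` may differ.
import torch.nn as nn


def soft_update_sketch(local: nn.Module, target: nn.Module, tau: float) -> None:
    """Blend target parameters toward local ones: theta_t <- tau*theta + (1 - tau)*theta_t."""
    for t_param, l_param in zip(target.parameters(), local.parameters()):
        t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)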
class Agent(AbstractAgent):
    """SAC agent interacting with environment.

    Attributes:
        memory (ReplayBuffer): replay memory
        actor (nn.Module): actor model to select actions
        actor_optimizer (Optimizer): optimizer for training actor
        vf (nn.Module): critic model to predict state values
        vf_target (nn.Module): target critic model to predict state values
        vf_optimizer (Optimizer): optimizer for training vf
        qf_1 (nn.Module): critic model to predict state-action values
        qf_2 (nn.Module): critic model to predict state-action values
        qf_1_optimizer (Optimizer): optimizer for training qf_1
        qf_2_optimizer (Optimizer): optimizer for training qf_2
        curr_state (np.ndarray): temporary storage of the current state
        target_entropy (float): desired entropy used for the inequality constraint
        log_alpha (torch.Tensor): weight for entropy
        alpha_optimizer (Optimizer): optimizer for alpha
        hyper_params (dict): hyper-parameters
        total_step (int): total step numbers
        episode_step (int): step number of the current episode
        i_episode (int): current episode number
        her (HER): hindsight experience replay

    """

    def __init__(self, env, args, hyper_params, models, optims, target_entropy, her):
        """Initialization.

        Args:
            env (gym.Env): openAI Gym environment
            args (argparse.Namespace): arguments including hyperparameters and
                training settings
            hyper_params (dict): hyper-parameters
            models (tuple): models including actor and critic
            optims (tuple): optimizers for actor and critic
            target_entropy (float): target entropy for the inequality constraint
            her (HER): hindsight experience replay

        """
        AbstractAgent.__init__(self, env, args)

        self.actor, self.vf, self.vf_target, self.qf_1, self.qf_2 = models
        self.actor_optimizer, self.vf_optimizer = optims[0:2]
        self.qf_1_optimizer, self.qf_2_optimizer = optims[2:4]
        self.hyper_params = hyper_params
        self.curr_state = np.zeros((1,))
        self.total_step = 0
        self.episode_step = 0
        self.i_episode = 0
        self.her = her

        # automatic entropy tuning
        if self.hyper_params["AUTO_ENTROPY_TUNING"]:
            self.target_entropy = target_entropy
            self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
            self.alpha_optimizer = optim.Adam(
                [self.log_alpha], lr=self.hyper_params["LR_ENTROPY"]
            )

        # load the optimizer and model parameters
        if args.load_from is not None and os.path.exists(args.load_from):
            self.load_params(args.load_from)

        self._initialize()

    def _initialize(self):
        """Initialize non-common things."""
        if not self.args.test:
            # replay memory
            self.memory = ReplayBuffer(
                self.hyper_params["BUFFER_SIZE"], self.hyper_params["BATCH_SIZE"]
            )

        # HER
        if self.hyper_params["USE_HER"]:
            # load demo replay memory
            with open(self.args.demo_path, "rb") as f:
                demo = pickle.load(f)

            if self.hyper_params["DESIRED_STATES_FROM_DEMO"]:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi = list()
            self.desired_state = np.zeros((1,))
            demo = self.her.generate_demo_transitions(demo)

            if not self.args.test:
                # Replay buffers
                self.memory = ReplayBuffer(
                    self.hyper_params["BUFFER_SIZE"], self.hyper_params["BATCH_SIZE"]
                )

    def _preprocess_state(self, state):
        """Preprocess state so that actor selects an action."""
        if self.hyper_params["USE_HER"]:
            self.desired_state = self.her.get_desired_state()
            state = np.concatenate((state, self.desired_state), axis=-1)
        state = torch.FloatTensor(state).to(device)
        return state

    def _add_transition_to_memory(self, transition):
        """Add 1 step and n step transitions to memory."""
        if self.hyper_params["USE_HER"]:
            self.transitions_epi.append(transition)
            done = transition[-1] or self.episode_step == self.args.max_episode_steps
            if done:
                # insert generated transitions if the episode is done
                transitions = self.her.generate_transitions(
                    self.transitions_epi,
                    self.desired_state,
                    self.hyper_params["SUCCESS_SCORE"],
                )
                self.memory.extend(transitions)
                self.transitions_epi = list()
        else:
            self.memory.add(*transition)

    def select_action(self, state):
        """Select an action from the input space."""
        self.curr_state = state
        state = self._preprocess_state(state)

        # if initial random action should be conducted
        if (
            self.total_step < self.hyper_params["INITIAL_RANDOM_ACTION"]
            and not self.args.test
        ):
            return self.env.action_space.sample()

        if self.args.test:
            _, _, _, selected_action, _ = self.actor(state)
        else:
            selected_action, _, _, _, _ = self.actor(state)

        return selected_action.detach().cpu().numpy()

    def step(self, action):
        """Take an action and return the response of the env."""
        self.total_step += 1
        self.episode_step += 1

        next_state, reward, done, _ = self.env.step(action)

        if not self.args.test:
            # if the last state is not a terminal state, store done as false
            done_bool = (
                False if self.episode_step == self.args.max_episode_steps else done
            )
            transition = (self.curr_state, action, reward, next_state, done_bool)
            self._add_transition_to_memory(transition)

        return next_state, reward, done

    def update_model(self, experiences):
        """Train the model after each episode."""
        states, actions, rewards, next_states, dones = experiences
        new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)

        # train alpha
        if self.hyper_params["AUTO_ENTROPY_TUNING"]:
            alpha_loss = (
                -self.log_alpha * (log_prob + self.target_entropy).detach()
            ).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.zeros(1)
            alpha = self.hyper_params["W_ENTROPY"]

        # Q function loss
        masks = 1 - dones
        q_1_pred = self.qf_1(states, actions)
        q_2_pred = self.qf_2(states, actions)
        v_target = self.vf_target(next_states)
        q_target = rewards + self.hyper_params["GAMMA"] * v_target * masks
        qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())
        qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())

        # V function loss
        v_pred = self.vf(states)
        q_pred = torch.min(
            self.qf_1(states, new_actions), self.qf_2(states, new_actions)
        )
        v_target = q_pred - alpha * log_prob
        vf_loss = F.mse_loss(v_pred, v_target.detach())

        # train Q functions
        self.qf_1_optimizer.zero_grad()
        qf_1_loss.backward()
        self.qf_1_optimizer.step()

        self.qf_2_optimizer.zero_grad()
        qf_2_loss.backward()
        self.qf_2_optimizer.step()

        # train V function
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
            # actor loss
            advantage = q_pred - v_pred.detach()
            actor_loss = (alpha * log_prob - advantage).mean()

            # regularization
            mean_reg = self.hyper_params["W_MEAN_REG"] * mu.pow(2).mean()
            std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
            pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                pre_tanh_value.pow(2).sum(dim=-1).mean()
            )
            actor_reg = mean_reg + std_reg + pre_activation_reg

            # actor loss + regularization
            actor_loss += actor_reg

            # train actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # update target networks
            common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])
        else:
            actor_loss = torch.zeros(1)

        return (
            actor_loss.data,
            qf_1_loss.data,
            qf_2_loss.data,
            vf_loss.data,
            alpha_loss.data,
        )

    def load_params(self, path):
        """Load model and optimizer parameters."""
        if not os.path.exists(path):
            print("[ERROR] the input path does not exist. ->", path)
            return

        params = torch.load(path)
        self.actor.load_state_dict(params["actor"])
        self.qf_1.load_state_dict(params["qf_1"])
        self.qf_2.load_state_dict(params["qf_2"])
        self.vf.load_state_dict(params["vf"])
        self.vf_target.load_state_dict(params["vf_target"])
        self.actor_optimizer.load_state_dict(params["actor_optim"])
        self.qf_1_optimizer.load_state_dict(params["qf_1_optim"])
        self.qf_2_optimizer.load_state_dict(params["qf_2_optim"])
        self.vf_optimizer.load_state_dict(params["vf_optim"])

        if self.hyper_params["AUTO_ENTROPY_TUNING"]:
            self.alpha_optimizer.load_state_dict(params["alpha_optim"])

        print("[INFO] loaded the model and optimizer from", path)

    def save_params(self, n_episode):
        """Save model and optimizer parameters."""
        params = {
            "actor": self.actor.state_dict(),
            "qf_1": self.qf_1.state_dict(),
            "qf_2": self.qf_2.state_dict(),
            "vf": self.vf.state_dict(),
            "vf_target": self.vf_target.state_dict(),
            "actor_optim": self.actor_optimizer.state_dict(),
            "qf_1_optim": self.qf_1_optimizer.state_dict(),
            "qf_2_optim": self.qf_2_optimizer.state_dict(),
            "vf_optim": self.vf_optimizer.state_dict(),
        }

        if self.hyper_params["AUTO_ENTROPY_TUNING"]:
            params["alpha_optim"] = self.alpha_optimizer.state_dict()

        AbstractAgent.save_params(self, params, n_episode)

    def write_log(self, i, loss, score=0.0, delayed_update=1):
        """Write log about loss and score."""
        total_loss = loss.sum()

        print(
            "[INFO] episode %d, episode_step %d, total step %d, total score: %d\n"
            "total loss: %.3f actor_loss: %.3f qf_1_loss: %.3f qf_2_loss: %.3f "
            "vf_loss: %.3f alpha_loss: %.3f\n"
            % (
                i,
                self.episode_step,
                self.total_step,
                score,
                total_loss,
                loss[0] * delayed_update,  # actor loss
                loss[1],  # qf_1 loss
                loss[2],  # qf_2 loss
                loss[3],  # vf loss
                loss[4],  # alpha loss
            )
        )

        if self.args.log:
            wandb.log(
                {
                    "score": score,
                    "total loss": total_loss,
                    "actor loss": loss[0] * delayed_update,
                    "qf_1 loss": loss[1],
                    "qf_2 loss": loss[2],
                    "vf loss": loss[3],
                    "alpha loss": loss[4],
                }
            )

    def train(self):
        """Train the agent."""
        # logger
        if self.args.log:
            wandb.init()
            wandb.config.update(self.hyper_params)
            wandb.config.update(vars(self.args))
            wandb.watch([self.actor, self.vf, self.qf_1, self.qf_2], log="parameters")

        for self.i_episode in range(1, self.args.episode_num + 1):
            state = self.env.reset()
            done = False
            score = 0
            self.episode_step = 0
            loss_episode = list()

            while not done:
                if self.args.render and self.i_episode >= self.args.render_after:
                    self.env.render()

                action = self.select_action(state)
                next_state, reward, done = self.step(action)

                state = next_state
                score += reward

                # training
                if len(self.memory) >= self.hyper_params["BATCH_SIZE"]:
                    experiences = self.memory.sample()
                    loss = self.update_model(experiences)
                    loss_episode.append(loss)  # for logging

            # logging
            if loss_episode:
                avg_loss = np.vstack(loss_episode).mean(axis=0)
                self.write_log(
                    self.i_episode,
                    avg_loss,
                    score,
                    self.hyper_params["DELAYED_UPDATE"],
                )

            if self.i_episode % self.args.save_period == 0:
                self.save_params(self.i_episode)

        # termination
        self.env.close()
class BCDDPGAgent(DDPGAgent):
    """BC with DDPG agent interacting with environment.

    Attributes:
        her (HER): hindsight experience replay
        transitions_epi (list): transitions per episode (for HER)
        desired_state (np.ndarray): desired state of current episode
        memory (ReplayBuffer): replay memory
        demo_memory (ReplayBuffer): replay memory for demo
        lambda1 (float): proportion of policy loss
        lambda2 (float): proportion of BC loss

    """

    def __init__(
        self,
        env: gym.Env,
        args: argparse.Namespace,
        hyper_params: dict,
        models: tuple,
        optims: tuple,
        noise: OUNoise,
        her: HER,
    ):
        """Initialization.

        Args:
            her (HER): hindsight experience replay

        """
        self.her = her
        DDPGAgent.__init__(self, env, args, hyper_params, models, optims, noise)

    # pylint: disable=attribute-defined-outside-init
    def _initialize(self):
        """Initialize non-common things."""
        # load demo replay memory
        with open(self.args.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params["USE_HER"]:
            if self.hyper_params["DESIRED_STATES_FROM_DEMO"]:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1,))
            demo = self.her.generate_demo_transitions(demo)

        if not self.args.test:
            # Replay buffers
            demo_batch_size = self.hyper_params["DEMO_BATCH_SIZE"]
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)
            self.memory = ReplayBuffer(
                self.hyper_params["BUFFER_SIZE"], self.hyper_params["BATCH_SIZE"]
            )

            # set hyper parameters
            self.lambda1 = self.hyper_params["LAMBDA1"]
            self.lambda2 = self.hyper_params["LAMBDA2"] / demo_batch_size

    def _preprocess_state(self, state: np.ndarray) -> torch.Tensor:
        """Preprocess state so that actor selects an action."""
        if self.hyper_params["USE_HER"]:
            self.desired_state = self.her.get_desired_state()
            state = np.concatenate((state, self.desired_state), axis=-1)
        state = torch.FloatTensor(state).to(device)
        return state

    def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):
        """Add 1 step and n step transitions to memory."""
        if self.hyper_params["USE_HER"]:
            self.transitions_epi.append(transition)
            done = transition[-1] or self.episode_step == self.args.max_episode_steps
            if done:
                # insert generated transitions if the episode is done
                transitions = self.her.generate_transitions(
                    self.transitions_epi,
                    self.desired_state,
                    self.hyper_params["SUCCESS_SCORE"],
                )
                self.memory.extend(transitions)
                self.transitions_epi.clear()
        else:
            self.memory.add(transition)

    def update_model(self) -> Tuple[torch.Tensor, ...]:
        """Train the model after each episode."""
        experiences = self.memory.sample()
        demos = self.demo_memory.sample()

        exp_states, exp_actions, exp_rewards, exp_next_states, exp_dones = experiences
        demo_states, demo_actions, demo_rewards, demo_next_states, demo_dones = demos

        states = torch.cat((exp_states, demo_states), dim=0)
        actions = torch.cat((exp_actions, demo_actions), dim=0)
        rewards = torch.cat((exp_rewards, demo_rewards), dim=0)
        next_states = torch.cat((exp_next_states, demo_next_states), dim=0)
        dones = torch.cat((exp_dones, demo_dones), dim=0)

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        masks = 1 - dones
        next_actions = self.actor_target(next_states)
        next_values = self.critic_target(
            torch.cat((next_states, next_actions), dim=-1)
        )
        curr_returns = rewards + (self.hyper_params["GAMMA"] * next_values * masks)
        curr_returns = curr_returns.to(device)

        # critic loss
        values = self.critic(torch.cat((states, actions), dim=-1))
        critic_loss = F.mse_loss(values, curr_returns)

        # train critic
        gradient_clip_cr = self.hyper_params["GRADIENT_CLIP_CR"]
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
        self.critic_optimizer.step()

        # policy loss
        actions = self.actor(states)
        policy_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()

        # bc loss
        pred_actions = self.actor(demo_states)
        qf_mask = torch.gt(
            self.critic(torch.cat((demo_states, demo_actions), dim=-1)),
            self.critic(torch.cat((demo_states, pred_actions), dim=-1)),
        ).to(device)
        qf_mask = qf_mask.float()
        n_qf_mask = int(qf_mask.sum().item())

        if n_qf_mask == 0:
            bc_loss = torch.zeros(1, device=device)
        else:
            bc_loss = (
                torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
            ).pow(2).sum() / n_qf_mask

        # train actor: pg loss + BC loss
        actor_loss = self.lambda1 * policy_loss + self.lambda2 * bc_loss
        gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optimizer.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.actor, self.actor_target, tau)
        common_utils.soft_update(self.critic, self.critic_target, tau)

        return actor_loss.item(), critic_loss.item(), n_qf_mask

    def write_log(self, i: int, loss: np.ndarray, score: int, avg_time_cost):
        """Write log about loss and score."""
        total_loss = loss.sum()

        print(
            "[INFO] episode %d, episode step: %d, total step: %d, total score: %d\n"
            "total loss: %f actor_loss: %.3f critic_loss: %.3f, n_qf_mask: %d "
            "(spent %.6f sec/step)\n"
            % (
                i,
                self.episode_step,
                self.total_step,
                score,
                total_loss,
                loss[0],  # actor loss
                loss[1],  # critic loss
                loss[2],  # n_qf_mask
                avg_time_cost,
            )
        )

        if self.args.log:
            wandb.log(
                {
                    "score": score,
                    "total loss": total_loss,
                    "actor loss": loss[0],
                    "critic loss": loss[1],
                    "time per each step": avg_time_cost,
                }
            )
class Agent(SACAgent):
    """BC with SAC agent interacting with environment.

    Attributes:
        HER (AbstractHER): hindsight experience replay
        transitions_epi (list): transitions per episode (for HER)
        desired_state (np.ndarray): desired state of current episode
        memory (ReplayBuffer): replay memory
        demo_memory (ReplayBuffer): replay memory for demo
        lambda1 (float): proportion of policy loss
        lambda2 (float): proportion of BC loss

    """

    def __init__(
        self,
        env: gym.Env,
        args: argparse.Namespace,
        hyper_params: dict,
        models: tuple,
        optims: tuple,
        target_entropy: float,
        HER: AbstractHER,
    ):
        """Initialization.

        Args:
            HER (AbstractHER): hindsight experience replay

        """
        self.HER = HER
        SACAgent.__init__(self, env, args, hyper_params, models, optims, target_entropy)

    # pylint: disable=attribute-defined-outside-init
    def _initialize(self):
        """Initialize non-common things."""
        # load demo replay memory
        with open(self.args.demo_path, "rb") as f:
            demo = list(pickle.load(f))

        # HER
        if self.hyper_params["USE_HER"]:
            self.her = self.HER()
            if self.hyper_params["DESIRED_STATES_FROM_DEMO"]:
                self.her.fetch_desired_states_from_demo(demo)

            self.transitions_epi: list = list()
            self.desired_state = np.zeros((1,))
            demo = self.her.generate_demo_transitions(demo)

        if not self.args.test:
            # Replay buffers
            demo_batch_size = self.hyper_params["DEMO_BATCH_SIZE"]
            self.demo_memory = ReplayBuffer(len(demo), demo_batch_size)
            self.demo_memory.extend(demo)
            self.memory = ReplayBuffer(
                self.hyper_params["BUFFER_SIZE"], self.hyper_params["BATCH_SIZE"]
            )

            # set hyper parameters
            self.lambda1 = self.hyper_params["LAMBDA1"]
            self.lambda2 = self.hyper_params["LAMBDA2"] / demo_batch_size

    def _preprocess_state(self, state: np.ndarray) -> torch.Tensor:
        """Preprocess state so that actor selects an action."""
        if self.hyper_params["USE_HER"]:
            self.desired_state = self.her.get_desired_state()
            state = np.concatenate((state, self.desired_state), axis=-1)
        state = torch.FloatTensor(state).to(device)
        return state

    def _add_transition_to_memory(self, transition: Tuple[np.ndarray, ...]):
        """Add 1 step and n step transitions to memory."""
        if self.hyper_params["USE_HER"]:
            self.transitions_epi.append(transition)
            done = transition[-1] or self.episode_step == self.args.max_episode_steps
            if done:
                # insert generated transitions if the episode is done
                transitions = self.her.generate_transitions(
                    self.transitions_epi,
                    self.desired_state,
                    self.hyper_params["SUCCESS_SCORE"],
                )
                self.memory.extend(transitions)
                self.transitions_epi.clear()
        else:
            self.memory.add(*transition)

    def update_model(self) -> Tuple[torch.Tensor, ...]:
        """Train the model after each episode."""
        experiences = self.memory.sample()
        demos = self.demo_memory.sample()

        states, actions, rewards, next_states, dones = experiences
        demo_states, demo_actions, _, _, _ = demos
        new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)
        pred_actions, _, _, _, _ = self.actor(demo_states)

        # train alpha
        if self.hyper_params["AUTO_ENTROPY_TUNING"]:
            alpha_loss = (
                -self.log_alpha * (log_prob + self.target_entropy).detach()
            ).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.zeros(1)
            alpha = self.hyper_params["W_ENTROPY"]

        # Q function loss
        masks = 1 - dones
        q_1_pred = self.qf_1(states, actions)
        q_2_pred = self.qf_2(states, actions)
        v_target = self.vf_target(next_states)
        q_target = rewards + self.hyper_params["GAMMA"] * v_target * masks
        qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())
        qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())

        # V function loss
        v_pred = self.vf(states)
        q_pred = torch.min(
            self.qf_1(states, new_actions), self.qf_2(states, new_actions)
        )
        v_target = q_pred - alpha * log_prob
        vf_loss = F.mse_loss(v_pred, v_target.detach())

        # train Q functions
        self.qf_1_optimizer.zero_grad()
        qf_1_loss.backward()
        self.qf_1_optimizer.step()

        self.qf_2_optimizer.zero_grad()
        qf_2_loss.backward()
        self.qf_2_optimizer.step()

        # train V function
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
            # bc loss
            qf_mask = torch.gt(
                self.qf_1(demo_states, demo_actions),
                self.qf_1(demo_states, pred_actions),
            ).to(device)
            qf_mask = qf_mask.float()
            n_qf_mask = int(qf_mask.sum().item())

            if n_qf_mask == 0:
                bc_loss = torch.zeros(1, device=device)
            else:
                bc_loss = (
                    torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
                ).pow(2).sum() / n_qf_mask

            # actor loss
            advantage = q_pred - v_pred.detach()
            actor_loss = (alpha * log_prob - advantage).mean()
            actor_loss = self.lambda1 * actor_loss + self.lambda2 * bc_loss

            # regularization
            if not self.is_discrete:  # if the action is continuous
                mean_reg = self.hyper_params["W_MEAN_REG"] * mu.pow(2).mean()
                std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
                pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                    pre_tanh_value.pow(2).sum(dim=-1).mean()
                )
                actor_reg = mean_reg + std_reg + pre_activation_reg

                # actor loss + regularization
                actor_loss += actor_reg

            # train actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # update target networks
            common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])
        else:
            actor_loss = torch.zeros(1)
            n_qf_mask = 0

        return (
            actor_loss.data,
            qf_1_loss.data,
            qf_2_loss.data,
            vf_loss.data,
            alpha_loss.data,
            n_qf_mask,
        )

    def write_log(
        self, i: int, loss: np.ndarray, score: float = 0.0, delayed_update: int = 1
    ):
        """Write log about loss and score."""
        total_loss = loss.sum()

        print(
            "[INFO] episode %d, episode_step %d, total step %d, total score: %d\n"
            "total loss: %.3f actor_loss: %.3f qf_1_loss: %.3f qf_2_loss: %.3f "
            "vf_loss: %.3f alpha_loss: %.3f n_qf_mask: %d\n"
            % (
                i,
                self.episode_step,
                self.total_step,
                score,
                total_loss,
                loss[0] * delayed_update,  # actor loss
                loss[1],  # qf_1 loss
                loss[2],  # qf_2 loss
                loss[3],  # vf loss
                loss[4],  # alpha loss
                loss[5],  # n_qf_mask
            )
        )

        if self.args.log:
            wandb.log(
                {
                    "score": score,
                    "total loss": total_loss,
                    "actor loss": loss[0] * delayed_update,
                    "qf_1 loss": loss[1],
                    "qf_2 loss": loss[2],
                    "vf loss": loss[3],
                    "alpha loss": loss[4],
                }
            )
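
# --------------------------------------------------------------------------
# Hedged sketch (not from the original module): a self-contained numeric check
# of the two bootstrap targets computed in update_model() above, with dummy
# tensors standing in for network outputs:
#   q_target = r + gamma * V_target(s') * (1 - done)
#   v_target = min(Q1(s, a~pi), Q2(s, a~pi)) - alpha * log_prob
import torch


def sketch_sac_targets() -> None:
    """Evaluate the SAC Q-target and V-target on toy values."""
    gamma, alpha = 0.99, 0.2
    rewards = torch.tensor([[1.0], [0.0]])
    dones = torch.tensor([[0.0], [1.0]])
    v_next = torch.tensor([[2.0], [3.0]])
    q_1 = torch.tensor([[1.5], [0.8]])
    q_2 = torch.tensor([[1.2], [0.9]])
    log_prob = torch.tensor([[-1.0], [-0.5]])

    masks = 1 - dones
    q_target = rewards + gamma * v_next * masks        # -> [[2.98], [0.00]]
    v_target = torch.min(q_1, q_2) - alpha * log_prob  # -> [[1.40], [0.90]]
    print(q_target, v_target)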