def update_decoder(
    self,
    batch: ReplayBufferSample,
    task_info: TaskInfo,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Dict[str, Any],
) -> None:
    """Update the decoder component.

    Args:
        batch (ReplayBufferSample): batch from the replay buffer.
        task_info (TaskInfo): task_info object.
        logger ([Logger]): logger object.
        step (int): step for tracking the training of the agent.
        kwargs_to_compute_gradient (Dict[str, Any]):
    """
    obs = batch.env_obs
    target_obs = batch.env_obs
    mtobs = MTObs(env_obs=obs, task_obs=None, task_info=task_info)
    h = self.critic.encode(mtobs=mtobs)

    if target_obs.dim() == 4:
        # preprocess images to be in [-0.5, 0.5] range
        target_obs = agent_utils.preprocess_obs(target_obs)
    rec_obs = self.decoder(h)
    if self.loss_reduction == "mean":
        rec_loss = F.mse_loss(target_obs, rec_obs).mean()
    elif self.loss_reduction == "none":
        rec_loss = F.mse_loss(target_obs, rec_obs, reduction="none")
        rec_loss = rec_loss.view(rec_loss.shape[0], -1).mean(dim=1, keepdim=True)

    # add L2 penalty on latent representation
    # see https://arxiv.org/pdf/1903.12436.pdf
    if self.loss_reduction == "mean":
        latent_loss = (0.5 * h.pow(2).sum(1)).mean()
    elif self.loss_reduction == "none":
        latent_loss = 0.5 * h.pow(2).sum(1, keepdim=True)

    loss = rec_loss + self.decoder_latent_lambda * latent_loss

    self.encoder_optimizer.zero_grad()
    self.decoder_optimizer.zero_grad()
    component_names = ["encoder", "decoder"]
    parameters: List[ParameterType] = []
    for name in component_names:
        self._optimizers[name].zero_grad()
        parameters += self.get_parameters(name)
    if task_info.compute_grad:
        component_names.append("task_encoder")
        kwargs_to_compute_gradient["retain_graph"] = True
        parameters += self.get_parameters("task_encoder")

    self._compute_gradient(
        loss,
        parameters=parameters,
        step=step,
        component_names=component_names,
        **kwargs_to_compute_gradient,
    )
    self.encoder_optimizer.step()
    self.decoder_optimizer.step()

    if self.loss_reduction == "mean":
        loss_to_log = loss
    elif self.loss_reduction == "none":
        loss_to_log = loss.mean()
    logger.log("train/ae_loss", loss_to_log, step)
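# A minimal sketch of what `agent_utils.preprocess_obs` is assumed to do above,
# based only on the accompanying comment: map pixel observations from [0, 255]
# into the [-0.5, 0.5] range expected by the reconstruction loss. The actual
# helper may additionally apply bit-depth reduction and dithering (as in the
# SAC-AE reference implementation), so treat this as an illustration.
def _preprocess_obs_sketch(obs: torch.Tensor) -> torch.Tensor:
    """Rescale image observations from [0, 255] to [-0.5, 0.5]."""
    assert obs.dtype == torch.float32
    return obs / 255.0 - 0.5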
def __init__(self, config: ConfigType, experiment_id: str = "0"):
    """Experiment Class to manage the lifecycle of a model.

    Args:
        config (ConfigType): config of the experiment.
        experiment_id (str, optional): id to identify the experiment. Defaults to "0".
    """
    self.id = experiment_id
    self.config = config
    self.device = torch.device(self.config.setup.device)
    self.get_env_metadata = get_env_metadata
    self.envs, self.env_metadata = self.build_envs()

    key = "ordered_task_list"
    if key in self.env_metadata and self.env_metadata[key]:
        ordered_task_dict = {
            task: index for index, task in enumerate(self.env_metadata[key])
        }
    else:
        ordered_task_dict = {}

    key = "envs_to_exclude_during_training"
    if key in self.config.experiment and self.config.experiment[key]:
        self.envs_to_exclude_during_training = {
            ordered_task_dict[task] for task in self.config.experiment[key]
        }
        print(
            f"Excluding the following environments: {self.envs_to_exclude_during_training}"
        )
    else:
        self.envs_to_exclude_during_training = set()

    self.action_space = self.env_metadata["action_space"]
    assert self.action_space.low.min() >= -1
    assert self.action_space.high.max() <= 1

    self.env_obs_space = self.env_metadata["env_obs_space"]
    env_obs_shape = self.env_obs_space.shape
    action_shape = self.action_space.shape
    self.config = prepare_config(config=self.config, env_metadata=self.env_metadata)
    self.agent = hydra.utils.instantiate(
        self.config.agent.builder,
        env_obs_shape=env_obs_shape,
        action_shape=action_shape,
        action_range=[
            float(self.action_space.low.min()),
            float(self.action_space.high.max()),
        ],
        device=self.device,
    )

    self.video_dir = utils.make_dir(
        os.path.join(self.config.setup.save_dir, "video")
    )
    self.model_dir = utils.make_dir(
        os.path.join(self.config.setup.save_dir, "model")
    )
    self.buffer_dir = utils.make_dir(
        os.path.join(self.config.setup.save_dir, "buffer")
    )

    self.video = video.VideoRecorder(
        self.video_dir if self.config.experiment.save_video else None
    )

    self.replay_buffer = hydra.utils.instantiate(
        self.config.replay_buffer,
        device=self.device,
        env_obs_shape=env_obs_shape,
        task_obs_shape=(1,),
        action_shape=action_shape,
    )

    self.start_step = 0

    should_resume_experiment = self.config.experiment.should_resume

    if should_resume_experiment:
        self.start_step = self.agent.load_latest_step(model_dir=self.model_dir)
        self.replay_buffer.load(save_dir=self.buffer_dir)

    self.logger = Logger(
        self.config.setup.save_dir,
        config=self.config,
        retain_logs=should_resume_experiment,
    )

    self.max_episode_steps = self.env_metadata[
        "max_episode_steps"
    ]  # maximum number of steps the agent can take in one episode.

    self.startup_logs()
def update_decoder(  # type: ignore[override]
    self,
    batch: ReplayBufferSample,
    task_info: TaskInfo,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Dict[str, Any],
):
    """Update the decoder component.

    The decoder reconstructs the next observation from the latent state
    predicted by the transition model.

    Args:
        batch (ReplayBufferSample): batch from the replay buffer.
        task_info (TaskInfo): task_info object.
        logger ([Logger]): logger object.
        step (int): step for tracking the training of the agent.
        kwargs_to_compute_gradient (Dict[str, Any]):
    """
    obs = batch.env_obs
    action = batch.action
    target_obs = batch.next_env_obs  # the decoder reconstructs the next observation
    mtobs = MTObs(env_obs=obs, task_obs=None, task_info=task_info)
    h = self.critic.encode(mtobs=mtobs)
    next_h = self.transition_model.sample_prediction(
        torch.cat([h, action], dim=1)
    )

    if target_obs.dim() == 4:
        # preprocess images to be in [-0.5, 0.5] range
        target_obs = agent_utils.preprocess_obs(target_obs)
    rec_obs = self.decoder(next_h)  # type: ignore[misc]
    if self.loss_reduction == "mean":
        rec_loss = F.mse_loss(target_obs, rec_obs).mean()
    elif self.loss_reduction == "none":
        rec_loss = F.mse_loss(target_obs, rec_obs, reduction="none")
        rec_loss = rec_loss.view(rec_loss.shape[0], -1).mean(dim=1, keepdim=True)

    # add L2 penalty on latent representation
    # see https://arxiv.org/pdf/1903.12436.pdf
    if self.loss_reduction == "mean":
        latent_loss = (0.5 * h.pow(2).sum(1)).mean()
    elif self.loss_reduction == "none":
        latent_loss = 0.5 * h.pow(2).sum(1, keepdim=True)

    loss = rec_loss + self.decoder_latent_lambda * latent_loss

    component_names = ["transition_model", "decoder"]
    if not self.is_encoder_identity:
        component_names.append("encoder")
    parameters: List[ParameterType] = []
    for name in component_names:
        self._optimizers[name].zero_grad()
        parameters += self.get_parameters(name=name)
    component_names_to_pass = deepcopy(component_names)
    if task_info.compute_grad:
        component_names_to_pass.append("task_encoder")
        kwargs_to_compute_gradient["retain_graph"] = True
        parameters += self.get_parameters("task_encoder")

    self._compute_gradient(
        loss,
        parameters=parameters,
        step=step,
        component_names=component_names_to_pass,
        **kwargs_to_compute_gradient,
    )
    for name in component_names:
        self._optimizers[name].step()

    if self.loss_reduction == "mean":
        loss_to_log = loss
    elif self.loss_reduction == "none":
        loss_to_log = loss.mean()
    logger.log("train/ae_loss", loss_to_log, step)
def update_transition_reward_model(
    self,
    batch: ReplayBufferSample,
    task_info: TaskInfo,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Dict[str, Any],
):
    """Update the transition model and reward decoder components.

    Args:
        batch (ReplayBufferSample): batch from the replay buffer.
        task_info (TaskInfo): task_info object.
        logger ([Logger]): logger object.
        step (int): step for tracking the training of the agent.
        kwargs_to_compute_gradient (Dict[str, Any]):
    """
    obs = batch.env_obs
    action = batch.action
    next_obs = batch.next_env_obs
    reward = batch.reward

    mtobs = MTObs(env_obs=obs, task_obs=None, task_info=task_info)
    h = self.critic.encode(mtobs=mtobs)
    pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(
        torch.cat([h, action], dim=1)
    )
    if pred_next_latent_sigma is None:
        pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)

    mtobs = MTObs(env_obs=next_obs, task_obs=None, task_info=task_info)
    next_h = self.critic.encode(mtobs=mtobs)
    diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
    if self.loss_reduction == "mean":
        loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))
        loss_to_log = loss
    elif self.loss_reduction == "none":
        loss = (0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma)).mean(
            dim=1, keepdim=True
        )
        loss_to_log = loss.mean()
    logger.log("train/ae_transition_loss", loss_to_log, step)

    pred_next_latent = self.transition_model.sample_prediction(
        torch.cat([h, action], dim=1)
    )
    pred_next_reward = self.reward_decoder(pred_next_latent)
    reward_loss = F.mse_loss(pred_next_reward, reward, reduction=self.loss_reduction)
    total_loss = loss + reward_loss

    parameters: List[ParameterType] = []
    component_names = ["transition_model", "reward_decoder"]
    if not self.is_encoder_identity:
        component_names.append("encoder")
    for name in component_names:
        self._optimizers[name].zero_grad()
        parameters += self.get_parameters(name=name)
    component_names_to_pass = deepcopy(component_names)
    if task_info.compute_grad:
        component_names_to_pass.append("task_encoder")
        kwargs_to_compute_gradient["retain_graph"] = True
        parameters += self.get_parameters("task_encoder")

    self._compute_gradient(
        loss=total_loss,  # note: the order of these arguments is important
        parameters=parameters,
        step=step,
        component_names=component_names_to_pass,
        **kwargs_to_compute_gradient,
    )
    for name in component_names:
        self._optimizers[name].step()
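# A minimal sketch of the probabilistic transition-model interface assumed by the
# two methods above: `forward` returns a (mu, sigma) pair for the predicted next
# latent state, and `sample_prediction` draws a reparameterized sample
# mu + sigma * eps. This illustrates the assumed contract only; the codebase's
# actual `transition_model` may differ (e.g. predict log-sigma and clamp it, or
# return sigma=None for a deterministic model, as the None-check above suggests).
class _TransitionModelSketch(torch.nn.Module):
    def __init__(self, latent_dim: int, action_dim: int, hidden_dim: int = 512):
        super().__init__()
        self.trunk = torch.nn.Sequential(
            torch.nn.Linear(latent_dim + action_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self.mu_head = torch.nn.Linear(hidden_dim, latent_dim)
        self.sigma_head = torch.nn.Linear(hidden_dim, latent_dim)

    def forward(self, latent_and_action: torch.Tensor):
        hidden = self.trunk(latent_and_action)
        mu = self.mu_head(hidden)
        # softplus keeps the predicted standard deviation strictly positive
        sigma = F.softplus(self.sigma_head(hidden)) + 1e-4
        return mu, sigma

    def sample_prediction(self, latent_and_action: torch.Tensor) -> torch.Tensor:
        mu, sigma = self(latent_and_action)
        return mu + sigma * torch.randn_like(mu)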
def update_actor_and_alpha(
    self,
    batch: ReplayBufferSample,
    task_info: TaskInfo,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Dict[str, Any],
) -> None:
    """Update the actor and alpha component.

    Args:
        batch (ReplayBufferSample): batch from the replay buffer.
        task_info (TaskInfo): task_info object.
        logger ([Logger]): logger object.
        step (int): step for tracking the training of the agent.
        kwargs_to_compute_gradient (Dict[str, Any]):
    """
    # detach encoder, so we don't update it with the actor loss
    suffix = f"_agent_index_{self.index}"
    mtobs = MTObs(env_obs=batch.env_obs, task_obs=None, task_info=task_info)
    mu, pi, log_pi, log_std = self.agent.actor(mtobs=mtobs, detach_encoder=True)
    actor_Q1, actor_Q2 = self.agent.critic(
        mtobs=mtobs, action=pi, detach_encoder=True
    )

    actor_Q = torch.min(actor_Q1, actor_Q2)
    if self.agent.loss_reduction == "mean":
        actor_loss = (
            self.agent.get_alpha(batch.task_obs).detach() * log_pi - actor_Q
        ).mean()
        logger.log(f"train/actor_loss{suffix}", actor_loss, step)
    elif self.agent.loss_reduction == "none":
        actor_loss = (
            self.agent.get_alpha(batch.task_obs).detach() * log_pi - actor_Q
        )
        logger.log(f"train/actor_loss{suffix}", actor_loss.mean(), step)

    logger.log(
        f"train/actor_target_entropy{suffix}", self.agent.target_entropy, step
    )

    entropy = 0.5 * log_std.shape[1] * (
        1.0 + np.log(2 * np.pi)
    ) + log_std.sum(dim=-1)
    logger.log(f"train/actor_entropy{suffix}", entropy.mean(), step)

    mtobs = MTObs(env_obs=batch.env_obs, task_obs=None, task_info=NoneTaskInfo)
    distral_mu, _, distral_log_pi, distral_log_std = self.distilled_agent.actor(
        mtobs=mtobs, detach_encoder=False
    )
    distilled_agent_loss = gaussian_kld(
        mean1=distral_mu,
        logvar1=2 * distral_log_std,
        mean2=mu.detach(),
        logvar2=2 * log_std.detach(),
    )
    batch_size = distilled_agent_loss.shape[0]
    distilled_agent_loss = torch.sum(distilled_agent_loss) / batch_size
    logger.log(
        f"train/actor_distilled_agent_loss{suffix}",
        distilled_agent_loss.mean(),
        step,
    )
    distilled_agent_loss = distilled_agent_loss * self.distral_alpha

    # optimize the actor
    component_names = ["actor"]
    parameters: List[ParameterType] = []
    for name in component_names:
        self.agent._optimizers[name].zero_grad()
        parameters += self.agent.get_parameters(name)
    if task_info.compute_grad:
        component_names.append("task_encoder")
        kwargs_to_compute_gradient["retain_graph"] = True
        parameters += self.agent.get_parameters("task_encoder")

    self.agent._compute_gradient(
        loss=actor_loss,
        parameters=parameters,
        step=step,
        component_names=component_names,
        **kwargs_to_compute_gradient,
    )
    self.agent.actor_optimizer.step()

    self.agent.log_alpha_optimizer.zero_grad()
    if self.agent.loss_reduction == "mean":
        alpha_loss = (
            self.agent.get_alpha(batch.task_obs)
            * (-log_pi - self.agent.target_entropy).detach()
        ).mean()
        logger.log(f"train/alpha_loss{suffix}", alpha_loss, step)
    elif self.agent.loss_reduction == "none":
        alpha_loss = (
            self.agent.get_alpha(batch.task_obs)
            * (-log_pi - self.agent.target_entropy).detach()
        )
        logger.log(f"train/alpha_loss{suffix}", alpha_loss.mean(), step)
    # logger.log("train/alpha_value", self.get_alpha(batch.task_obs), step)
    self.agent._compute_gradient(
        loss=alpha_loss,
        parameters=self.agent.get_parameters(name="log_alpha"),
        step=step,
        component_names=["log_alpha"],
        **kwargs_to_compute_gradient,
    )
    self.agent.log_alpha_optimizer.step()

    self.distilled_agent._optimizers["actor"].zero_grad()
    distilled_agent_loss.backward()
    self.distilled_agent._optimizers["actor"].step()
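# A minimal sketch of the `gaussian_kld` helper used above, assuming it computes the
# elementwise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
# between diagonal Gaussians and leaves the reduction to the caller (the caller above
# sums the result and divides by the batch size). The actual helper may reduce
# differently; this is an illustration of the assumed formula.
def _gaussian_kld_sketch(
    mean1: torch.Tensor,
    logvar1: torch.Tensor,
    mean2: torch.Tensor,
    logvar2: torch.Tensor,
) -> torch.Tensor:
    return 0.5 * (
        logvar2
        - logvar1
        + (torch.exp(logvar1) + (mean1 - mean2).pow(2)) / torch.exp(logvar2)
        - 1.0
    )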
def update(
    self,
    replay_buffer: ReplayBuffer,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Optional[Dict[str, Any]] = None,
    buffer_index_to_sample: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Update the agent.

    Args:
        replay_buffer (ReplayBuffer): replay buffer to sample the data.
        logger (Logger): logger for logging.
        step (int): step for tracking the training progress.
        kwargs_to_compute_gradient (Optional[Dict[str, Any]], optional):
            Defaults to None.
        buffer_index_to_sample (Optional[np.ndarray], optional): if specified,
            sample these indices from the replay buffer instead of sampling new
            indices. If set to `None`, sample randomly from the replay buffer.
            Defaults to None.

    Returns:
        np.ndarray: indices sampled (from the replay buffer) to train the model.
            If `buffer_index_to_sample` is not None, it is returned as-is.
    """
    if kwargs_to_compute_gradient is None:
        kwargs_to_compute_gradient = {}

    if buffer_index_to_sample is None:
        batch = replay_buffer.sample()
    else:
        batch = replay_buffer.sample(buffer_index_to_sample)

    logger.log("train/batch_reward", batch.reward.mean(), step)
    if self.should_use_task_encoder:
        self.task_encoder_optimizer.zero_grad()
        task_encoding = self.get_task_encoding(
            env_index=batch.task_obs.squeeze(1),
            disable_grad=False,
            modes=["train"],
        )
    else:
        task_encoding = None  # type: ignore[assignment]

    task_info = self.get_task_info(
        task_encoding=task_encoding,
        component_name="critic",
        env_index=batch.task_obs,
    )
    self.update_critic(
        batch=batch,
        task_info=task_info,
        logger=logger,
        step=step,
        kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
    )
    if step % self.actor_update_freq == 0:
        task_info = self.get_task_info(
            task_encoding=task_encoding,
            component_name="actor",
            env_index=batch.task_obs,
        )
        self.update_actor_and_alpha(
            batch=batch,
            task_info=task_info,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
        )
    if step % self.critic_target_update_freq == 0:
        agent_utils.soft_update_params(
            self.critic.Q1, self.critic_target.Q1, self.critic_tau
        )
        agent_utils.soft_update_params(
            self.critic.Q2, self.critic_target.Q2, self.critic_tau
        )
        agent_utils.soft_update_params(
            self.critic.encoder, self.critic_target.encoder, self.encoder_tau
        )

    if (
        "transition_model" in self._components
        and "reward_decoder" in self._components
    ):
        # TODO: some of the logic here is a bit ad-hoc and should be revisited.
        task_info = self.get_task_info(
            task_encoding=task_encoding,
            component_name="transition_reward",
            env_index=batch.task_obs,
        )
        self.update_transition_reward_model(
            batch=batch,
            task_info=task_info,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
        )
    if (
        "decoder" in self._components  # should_update_decoder
        and self.decoder is not None  # type: ignore[attr-defined]
        and step % self.decoder_update_freq == 0  # type: ignore[attr-defined]
    ):
        task_info = self.get_task_info(
            task_encoding=task_encoding,
            component_name="decoder",
            env_index=batch.task_obs,
        )
        self.update_decoder(
            batch=batch,
            task_info=task_info,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
        )

    if self.should_use_task_encoder:
        task_info = self.get_task_info(
            task_encoding=task_encoding,
            component_name="task_encoder",
            env_index=batch.task_obs,
        )
        self.update_task_encoder(
            batch=batch,
            task_info=task_info,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
        )

    return batch.buffer_index
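# A minimal sketch of `agent_utils.soft_update_params` as used in the target-network
# updates above: Polyak averaging of the target parameters towards the online
# parameters with rate `tau`, the standard SAC/SAC-AE recipe. The codebase's helper
# is assumed to behave the same way.
def _soft_update_params_sketch(
    net: torch.nn.Module, target_net: torch.nn.Module, tau: float
) -> None:
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)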
def update_actor_and_alpha(
    self,
    batch: ReplayBufferSample,
    task_info: TaskInfo,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Dict[str, Any],
) -> None:
    """Update the actor and alpha component.

    Args:
        batch (ReplayBufferSample): batch from the replay buffer.
        task_info (TaskInfo): task_info object.
        logger ([Logger]): logger object.
        step (int): step for tracking the training of the agent.
        kwargs_to_compute_gradient (Dict[str, Any]):
    """
    # detach encoder, so we don't update it with the actor loss
    mtobs = MTObs(
        env_obs=batch.env_obs,
        task_obs=None,
        task_info=task_info,
    )
    _, pi, log_pi, log_std = self.actor(mtobs=mtobs, detach_encoder=True)
    actor_Q1, actor_Q2 = self.critic(mtobs=mtobs, action=pi, detach_encoder=True)

    actor_Q = torch.min(actor_Q1, actor_Q2)
    if self.loss_reduction == "mean":
        actor_loss = (
            self.get_alpha(batch.task_obs).detach() * log_pi - actor_Q
        ).mean()
        logger.log("train/actor_loss", actor_loss, step)
    elif self.loss_reduction == "none":
        actor_loss = self.get_alpha(batch.task_obs).detach() * log_pi - actor_Q
        logger.log("train/actor_loss", actor_loss.mean(), step)

    logger.log("train/actor_target_entropy", self.target_entropy, step)

    entropy = 0.5 * log_std.shape[1] * (
        1.0 + np.log(2 * np.pi)
    ) + log_std.sum(dim=-1)
    logger.log("train/actor_entropy", entropy.mean(), step)

    # optimize the actor
    component_names = ["actor"]
    parameters: List[ParameterType] = []
    for name in component_names:
        self._optimizers[name].zero_grad()
        parameters += self.get_parameters(name)
    if task_info.compute_grad:
        component_names.append("task_encoder")
        kwargs_to_compute_gradient["retain_graph"] = True
        parameters += self.get_parameters("task_encoder")

    self._compute_gradient(
        loss=actor_loss,
        parameters=parameters,
        step=step,
        component_names=component_names,
        **kwargs_to_compute_gradient,
    )
    self.actor_optimizer.step()

    self.log_alpha_optimizer.zero_grad()
    if self.loss_reduction == "mean":
        alpha_loss = (
            self.get_alpha(batch.task_obs)
            * (-log_pi - self.target_entropy).detach()
        ).mean()
        logger.log("train/alpha_loss", alpha_loss, step)
    elif self.loss_reduction == "none":
        alpha_loss = (
            self.get_alpha(batch.task_obs)
            * (-log_pi - self.target_entropy).detach()
        )
        logger.log("train/alpha_loss", alpha_loss.mean(), step)
    # logger.log("train/alpha_value", self.get_alpha(batch.task_obs), step)
    self._compute_gradient(
        loss=alpha_loss,
        parameters=self.get_parameters(name="log_alpha"),
        step=step,
        component_names=["log_alpha"],
        **kwargs_to_compute_gradient,
    )
    self.log_alpha_optimizer.step()
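# A minimal sketch of the `get_alpha` / `log_alpha` pattern assumed above: the
# entropy temperature is stored as a learnable log-parameter (here one per task,
# indexed by the task id in `task_obs`) and exponentiated on access, which keeps
# alpha positive while `log_alpha` is optimized with the alpha loss. The codebase's
# actual helper may instead share a single alpha across tasks.
def _get_alpha_sketch(
    log_alpha: torch.nn.Parameter, task_obs: torch.Tensor
) -> torch.Tensor:
    # log_alpha: shape (num_tasks,); task_obs: shape (batch_size, 1) of task indices
    return log_alpha[task_obs.squeeze(1).long()].exp().unsqueeze(1)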
def update_critic(
    self,
    batch: ReplayBufferSample,
    task_info: TaskInfo,
    logger: Logger,
    step: int,
    kwargs_to_compute_gradient: Dict[str, Any],
) -> None:
    """Update the critic component.

    Args:
        batch (ReplayBufferSample): batch from the replay buffer.
        task_info (TaskInfo): task_info object.
        logger ([Logger]): logger object.
        step (int): step for tracking the training of the agent.
        kwargs_to_compute_gradient (Dict[str, Any]):
    """
    with torch.no_grad():
        target_V = self._get_target_V(batch=batch, task_info=task_info)
        target_Q = batch.reward + (batch.not_done * self.discount * target_V)

    # get current Q estimates
    mtobs = MTObs(env_obs=batch.env_obs, task_obs=None, task_info=task_info)
    current_Q1, current_Q2 = self.critic(
        mtobs=mtobs,
        action=batch.action,
        detach_encoder=False,
    )
    critic_loss = F.mse_loss(
        current_Q1, target_Q, reduction=self.loss_reduction
    ) + F.mse_loss(current_Q2, target_Q, reduction=self.loss_reduction)

    loss_to_log = critic_loss
    if self.loss_reduction == "none":
        loss_to_log = loss_to_log.mean()
    logger.log("train/critic_loss", loss_to_log, step)

    if loss_to_log > 1e8:
        raise RuntimeError(
            f"critic_loss = {loss_to_log} is too high. Stopping training."
        )

    component_names = ["critic"]
    parameters: List[ParameterType] = []
    for name in component_names:
        self._optimizers[name].zero_grad()
        parameters += self.get_parameters(name)
    if task_info.compute_grad:
        component_names.append("task_encoder")
        kwargs_to_compute_gradient["retain_graph"] = True
        parameters += self.get_parameters("task_encoder")

    self._compute_gradient(
        loss=critic_loss,
        parameters=parameters,
        step=step,
        component_names=component_names,
        **kwargs_to_compute_gradient,
    )

    # Optimize the critic
    self.critic_optimizer.step()
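# A minimal sketch of the `_get_target_V` helper assumed by `update_critic` above:
# the standard SAC soft state-value target, i.e. the minimum of the two target
# critics evaluated at an action sampled from the current policy at the next
# observation, minus the entropy term alpha * log_pi. Attribute and method names
# follow the usage in the surrounding code; treat this as an illustration of the
# assumed computation rather than the codebase's actual implementation.
def _get_target_V_sketch(self, batch, task_info):
    mtobs = MTObs(env_obs=batch.next_env_obs, task_obs=None, task_info=task_info)
    # sample an action from the current policy at the next observation
    _, policy_action, log_pi, _ = self.actor(mtobs=mtobs, detach_encoder=False)
    target_Q1, target_Q2 = self.critic_target(mtobs=mtobs, action=policy_action)
    return (
        torch.min(target_Q1, target_Q2)
        - self.get_alpha(batch.task_obs).detach() * log_pi
    )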