Example 1
    def update_decoder(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        obs = batch.env_obs
        target_obs = batch.env_obs
        mtobs = MTObs(env_obs=obs, task_obs=None, task_info=task_info)
        h = self.critic.encode(mtobs=mtobs)

        if target_obs.dim() == 4:
            # preprocess images to be in [-0.5, 0.5] range
            target_obs = agent_utils.preprocess_obs(target_obs)
        rec_obs = self.decoder(h)
        if self.loss_reduction == "mean":
            rec_loss = F.mse_loss(target_obs, rec_obs).mean()
        elif self.loss_reduction == "none":
            rec_loss = F.mse_loss(target_obs, rec_obs, reduction="none")
            rec_loss = rec_loss.view(rec_loss.shape[0], -1).mean(dim=1,
                                                                 keepdim=True)

        # add L2 penalty on latent representation
        # see https://arxiv.org/pdf/1903.12436.pdf

        if self.loss_reduction == "mean":
            latent_loss = (0.5 * h.pow(2).sum(1)).mean()
        elif self.loss_reduction == "none":
            latent_loss = 0.5 * h.pow(2).sum(1, keepdim=True)
        loss = rec_loss + self.decoder_latent_lambda * latent_loss
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        component_names = ["encoder", "decoder"]
        parameters: List[ParameterType] = []
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(loss,
                               parameters=parameters,
                               step=step,
                               component_names=component_names,
                               **kwargs_to_compute_gradient)

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        if self.loss_reduction == "mean":
            loss_to_log = loss
        elif self.loss_reduction == "none":
            loss_to_log = loss.mean()
        logger.log("train/ae_loss", loss_to_log, step)
Example 2
    def __init__(self, config: ConfigType, experiment_id: str = "0"):
        """Experiment Class to manage the lifecycle of a model.

        Args:
            config (ConfigType): config of the experiment.
            experiment_id (str, optional): id of the experiment. Defaults to "0".
        """
        self.id = experiment_id
        self.config = config
        self.device = torch.device(self.config.setup.device)

        self.get_env_metadata = get_env_metadata
        self.envs, self.env_metadata = self.build_envs()

        key = "ordered_task_list"
        if key in self.env_metadata and self.env_metadata[key]:
            ordered_task_dict = {
                task: index
                for index, task in enumerate(self.env_metadata[key])
            }
        else:
            ordered_task_dict = {}

        key = "envs_to_exclude_during_training"
        if key in self.config.experiment and self.config.experiment[key]:
            self.envs_to_exclude_during_training = {
                ordered_task_dict[task]
                for task in self.config.experiment[key]
            }
            print(
                f"Excluding the following environments: {self.envs_to_exclude_during_training}"
            )
        else:
            self.envs_to_exclude_during_training = set()

        self.action_space = self.env_metadata["action_space"]
        assert self.action_space.low.min() >= -1
        assert self.action_space.high.max() <= 1

        self.env_obs_space = self.env_metadata["env_obs_space"]

        env_obs_shape = self.env_obs_space.shape
        action_shape = self.action_space.shape

        self.config = prepare_config(config=self.config,
                                     env_metadata=self.env_metadata)
        self.agent = hydra.utils.instantiate(
            self.config.agent.builder,
            env_obs_shape=env_obs_shape,
            action_shape=action_shape,
            action_range=[
                float(self.action_space.low.min()),
                float(self.action_space.high.max()),
            ],
            device=self.device,
        )

        self.video_dir = utils.make_dir(
            os.path.join(self.config.setup.save_dir, "video"))
        self.model_dir = utils.make_dir(
            os.path.join(self.config.setup.save_dir, "model"))
        self.buffer_dir = utils.make_dir(
            os.path.join(self.config.setup.save_dir, "buffer"))

        self.video = video.VideoRecorder(
            self.video_dir if self.config.experiment.save_video else None)

        self.replay_buffer = hydra.utils.instantiate(
            self.config.replay_buffer,
            device=self.device,
            env_obs_shape=env_obs_shape,
            task_obs_shape=(1, ),
            action_shape=action_shape,
        )

        self.start_step = 0

        should_resume_experiment = self.config.experiment.should_resume

        if should_resume_experiment:
            self.start_step = self.agent.load_latest_step(
                model_dir=self.model_dir)
            self.replay_buffer.load(save_dir=self.buffer_dir)

        self.logger = Logger(
            self.config.setup.save_dir,
            config=self.config,
            retain_logs=should_resume_experiment,
        )
        self.max_episode_steps = self.env_metadata[
            "max_episode_steps"]  # maximum steps that the agent can take in one environment.

        self.startup_logs()
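
Two pieces of this constructor are easy to see in isolation: the task-name-to-index mapping built from "ordered_task_list" and the index set built from "envs_to_exclude_during_training". A small sketch with made-up metadata and config values (the task names are hypothetical):

# Hypothetical stand-ins for self.env_metadata and self.config.experiment.
env_metadata = {"ordered_task_list": ["reach-v1", "push-v1", "pick-place-v1"]}
experiment_config = {"envs_to_exclude_during_training": ["push-v1"]}

key = "ordered_task_list"
if key in env_metadata and env_metadata[key]:
    ordered_task_dict = {task: index for index, task in enumerate(env_metadata[key])}
else:
    ordered_task_dict = {}

key = "envs_to_exclude_during_training"
if key in experiment_config and experiment_config[key]:
    envs_to_exclude_during_training = {
        ordered_task_dict[task] for task in experiment_config[key]
    }
else:
    envs_to_exclude_during_training = set()

print(ordered_task_dict)                # {'reach-v1': 0, 'push-v1': 1, 'pick-place-v1': 2}
print(envs_to_exclude_during_training)  # {1}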
Example 3
    def update_decoder(  # type: ignore[override]
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        obs = batch.env_obs
        action = batch.action
        target_obs = batch.next_env_obs
        #  uses transition model
        mtobs = MTObs(env_obs=obs, task_obs=None, task_info=task_info)
        h = self.critic.encode(mtobs=mtobs)
        next_h = self.transition_model.sample_prediction(
            torch.cat([h, action], dim=1))
        if target_obs.dim() == 4:
            # preprocess images to be in [-0.5, 0.5] range
            target_obs = agent_utils.preprocess_obs(target_obs)
        rec_obs = self.decoder(next_h)  # type: ignore[misc]
        if self.loss_reduction == "mean":
            rec_loss = F.mse_loss(target_obs, rec_obs).mean()
        elif self.loss_reduction == "none":
            rec_loss = F.mse_loss(target_obs, rec_obs, reduction="none")
            rec_loss = rec_loss.view(rec_loss.shape[0], -1).mean(dim=1,
                                                                 keepdim=True)

        # add L2 penalty on latent representation
        # see https://arxiv.org/pdf/1903.12436.pdf

        if self.loss_reduction == "mean":
            latent_loss = (0.5 * h.pow(2).sum(1)).mean()
        elif self.loss_reduction == "none":
            latent_loss = 0.5 * h.pow(2).sum(1, keepdim=True)

        loss = rec_loss + self.decoder_latent_lambda * latent_loss

        component_names = ["transition_model", "decoder"]
        if not self.is_encoder_identity:
            component_names.append("encoder")
        parameters: List[ParameterType] = []
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name=name)
        component_names_to_pass = deepcopy(component_names)
        if task_info.compute_grad:
            component_names_to_pass.append("task_encoder")
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(
            loss,
            parameters=parameters,
            step=step,
            component_names=component_names_to_pass,
            **kwargs_to_compute_gradient,
        )

        for name in component_names:
            self._optimizers[name].step()

        if self.loss_reduction == "mean":
            loss_to_log = loss
        elif self.loss_reduction == "none":
            loss_to_log = loss.mean()
        logger.log("train/ae_loss", loss_to_log, step)
Example 4
    def update_transition_reward_model(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        obs = batch.env_obs
        action = batch.action
        next_obs = batch.next_env_obs
        reward = batch.reward
        mtobs = MTObs(env_obs=obs, task_obs=None, task_info=task_info)
        h = self.critic.encode(mtobs=mtobs)
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(
            torch.cat([h, action], dim=1))
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)
        mtobs = MTObs(env_obs=next_obs, task_obs=None, task_info=task_info)

        next_h = self.critic.encode(mtobs=mtobs)
        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        if self.loss_reduction == "mean":
            loss = torch.mean(0.5 * diff.pow(2) +
                              torch.log(pred_next_latent_sigma))
            loss_to_log = loss
        elif self.loss_reduction == "none":
            loss = (0.5 * diff.pow(2) +
                    torch.log(pred_next_latent_sigma)).mean(dim=1,
                                                            keepdim=True)
            loss_to_log = loss.mean()
        logger.log("train/ae_transition_loss", loss_to_log, step)

        pred_next_latent = self.transition_model.sample_prediction(
            torch.cat([h, action], dim=1))
        pred_next_reward = self.reward_decoder(pred_next_latent)
        reward_loss = F.mse_loss(pred_next_reward,
                                 reward,
                                 reduction=self.loss_reduction)
        total_loss = loss + reward_loss

        parameters: List[ParameterType] = []
        component_names = [
            "transition_model",
            "reward_decoder",
        ]
        if not self.is_encoder_identity:
            component_names.append("encoder")
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name=name)
        component_names_to_pass = deepcopy(component_names)
        if task_info.compute_grad:
            component_names_to_pass.append("task_encoder")
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(
            loss=total_loss,
            # note: this ordering is important
            parameters=parameters,
            step=step,
            component_names=component_names_to_pass,
            **kwargs_to_compute_gradient,
        )
        for name in component_names:
            self._optimizers[name].step()
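
The transition-model term is a Gaussian negative log-likelihood up to an additive constant: with predicted mean mu and standard deviation sigma, the per-element loss is 0.5 * ((next_h - mu) / sigma)^2 + log(sigma); when the model predicts no sigma, it falls back to ones, which reduces the term to a plain squared error. A minimal sketch of the "mean" branch with dummy tensors:

import torch

batch_size, latent_dim = 4, 50
pred_next_latent_mu = torch.randn(batch_size, latent_dim, requires_grad=True)
pred_next_latent_sigma = torch.rand(batch_size, latent_dim) + 0.1  # keep sigma > 0
next_h = torch.randn(batch_size, latent_dim)  # target latent, treated as a constant

# Gaussian negative log-likelihood, up to an additive constant.
diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))
loss.backward()  # gradients reach only the predicted mean in this sketch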
Example 5
    def update_actor_and_alpha(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the actor and alpha component.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger (Logger): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): keyword arguments for
                computing the gradient.

        """
        # detach encoder, so we don't update it with the actor loss
        suffix = f"_agent_index_{self.index}"
        mtobs = MTObs(env_obs=batch.env_obs,
                      task_obs=None,
                      task_info=task_info)
        mu, pi, log_pi, log_std = self.agent.actor(mtobs=mtobs,
                                                   detach_encoder=True)
        actor_Q1, actor_Q2 = self.agent.critic(mtobs=mtobs,
                                               action=pi,
                                               detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        if self.agent.loss_reduction == "mean":
            actor_loss = (
                self.agent.get_alpha(batch.task_obs).detach() * log_pi -
                actor_Q).mean()
            logger.log(f"train/actor_loss{suffix}", actor_loss, step)

        elif self.agent.loss_reduction == "none":
            actor_loss = (
                self.agent.get_alpha(batch.task_obs).detach() * log_pi -
                actor_Q)
            logger.log(f"train/actor_loss{suffix}", actor_loss.mean(), step)

        logger.log(f"train/actor_target_entropy{suffix}",
                   self.agent.target_entropy, step)

        entropy = 0.5 * log_std.shape[1] * (
            1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)

        logger.log(f"train/actor_entropy{suffix}", entropy.mean(), step)

        mtobs = MTObs(env_obs=batch.env_obs,
                      task_obs=None,
                      task_info=NoneTaskInfo)
        distral_mu, _, distral_log_pi, distral_log_std = self.distilled_agent.actor(
            mtobs=mtobs, detach_encoder=False)
        distilled_agent_loss = gaussian_kld(
            mean1=distral_mu,
            logvar1=2 * distral_log_std,
            mean2=mu.detach(),
            logvar2=2 * log_std.detach(),
        )
        batch_size = distilled_agent_loss.shape[0]
        distilled_agent_loss = torch.sum(distilled_agent_loss) / batch_size
        logger.log(
            f"train/actor_distilled_agent_loss{suffix}",
            distilled_agent_loss.mean(),
            step,
        )
        distilled_agent_loss = distilled_agent_loss * self.distral_alpha

        # optimize the actor
        component_names = ["actor"]
        parameters: List[ParameterType] = []
        for name in component_names:
            self.agent._optimizers[name].zero_grad()
            parameters += self.agent.get_parameters(name)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.agent.get_parameters("task_encoder")

        self.agent._compute_gradient(
            loss=actor_loss,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **kwargs_to_compute_gradient,
        )
        self.agent.actor_optimizer.step()
        self.agent.log_alpha_optimizer.zero_grad()
        if self.agent.loss_reduction == "mean":
            alpha_loss = (
                self.agent.get_alpha(batch.task_obs) *
                (-log_pi - self.agent.target_entropy).detach()).mean()
            logger.log(f"train/alpha_loss{suffix}", alpha_loss, step)
        elif self.agent.loss_reduction == "none":
            alpha_loss = (self.agent.get_alpha(batch.task_obs) *
                          (-log_pi - self.agent.target_entropy).detach())
            logger.log(f"train/alpha_loss{suffix}", alpha_loss.mean(), step)
        # logger.log("train/alpha_value", self.get_alpha(batch.task_obs), step)
        self.agent._compute_gradient(
            loss=alpha_loss,
            parameters=self.agent.get_parameters(name="log_alpha"),
            step=step,
            component_names=["log_alpha"],
            **kwargs_to_compute_gradient,
        )
        self.agent.log_alpha_optimizer.step()
        self.distilled_agent._optimizers["actor"].zero_grad()
        distilled_agent_loss.backward()
        self.distilled_agent._optimizers["actor"].step()
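
The distillation term compares the distilled actor's Gaussian (mean and log-std) against the per-task actor's Gaussian, passing log-variances (2 * log-std) to gaussian_kld. The sketch below uses the standard closed-form KL divergence between diagonal Gaussians; the repository's gaussian_kld helper is assumed to compute something equivalent, possibly with different reduction or constant conventions.

import torch

def gaussian_kld_sketch(mean1, logvar1, mean2, logvar2):
    """Elementwise KL(N(mean1, var1) || N(mean2, var2)) for diagonal Gaussians.

    Standard closed form; the repository's gaussian_kld may differ in
    reduction or constant conventions.
    """
    return 0.5 * (
        logvar2 - logvar1
        + (logvar1.exp() + (mean1 - mean2).pow(2)) / logvar2.exp()
        - 1.0
    )

batch_size, action_dim = 8, 4
distral_mu = torch.randn(batch_size, action_dim, requires_grad=True)
distral_log_std = torch.zeros(batch_size, action_dim, requires_grad=True)
mu = torch.randn(batch_size, action_dim)
log_std = torch.zeros(batch_size, action_dim)

# log-variance = 2 * log-std, matching the call in update_actor_and_alpha.
kld = gaussian_kld_sketch(distral_mu, 2 * distral_log_std, mu.detach(), 2 * log_std.detach())
distilled_agent_loss = kld.sum() / batch_size  # sum over dims, average over the batch
distilled_agent_loss.backward()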
Example 6
    def update(
        self,
        replay_buffer: ReplayBuffer,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Optional[Dict[str, Any]] = None,
        buffer_index_to_sample: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Update the agent.

        Args:
            replay_buffer (ReplayBuffer): replay buffer to sample the data.
            logger (Logger): logger for logging.
            step (int): step for tracking the training progress.
            kwargs_to_compute_gradient (Optional[Dict[str, Any]], optional): keyword
                arguments for computing the gradient. Defaults to None.
            buffer_index_to_sample (Optional[np.ndarray], optional): if this parameter
                is specified, use these indices instead of sampling from the replay
                buffer. If this is set to `None`, sample from the replay buffer.
                Defaults to None.

        Returns:
            np.ndarray: indices sampled from the replay buffer to train the model. If
                buffer_index_to_sample is not None, it is returned unchanged.

        """

        if kwargs_to_compute_gradient is None:
            kwargs_to_compute_gradient = {}

        if buffer_index_to_sample is None:
            batch = replay_buffer.sample()
        else:
            batch = replay_buffer.sample(buffer_index_to_sample)

        logger.log("train/batch_reward", batch.reward.mean(), step)
        if self.should_use_task_encoder:
            self.task_encoder_optimizer.zero_grad()
            task_encoding = self.get_task_encoding(
                env_index=batch.task_obs.squeeze(1),
                disable_grad=False,
                modes=["train"],
            )
        else:
            task_encoding = None  # type: ignore[assignment]

        task_info = self.get_task_info(
            task_encoding=task_encoding,
            component_name="critic",
            env_index=batch.task_obs,
        )
        self.update_critic(
            batch=batch,
            task_info=task_info,
            logger=logger,
            step=step,
            kwargs_to_compute_gradient=deepcopy(kwargs_to_compute_gradient),
        )
        if step % self.actor_update_freq == 0:
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="actor",
                env_index=batch.task_obs,
            )
            self.update_actor_and_alpha(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(
                    kwargs_to_compute_gradient),
            )
        if step % self.critic_target_update_freq == 0:
            agent_utils.soft_update_params(self.critic.Q1,
                                           self.critic_target.Q1,
                                           self.critic_tau)
            agent_utils.soft_update_params(self.critic.Q2,
                                           self.critic_target.Q2,
                                           self.critic_tau)
            agent_utils.soft_update_params(self.critic.encoder,
                                           self.critic_target.encoder,
                                           self.encoder_tau)

        if ("transition_model" in self._components
                and "reward_decoder" in self._components):
            # TODO: some of this logic needs to be revisited.
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="transition_reward",
                env_index=batch.task_obs,
            )
            self.update_transition_reward_model(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(
                    kwargs_to_compute_gradient),
            )
        if ("decoder" in self._components  # should_update_decoder
                and self.decoder is not None  # type: ignore[attr-defined]
                and step % self.decoder_update_freq ==
                0  # type: ignore[attr-defined]
            ):
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="decoder",
                env_index=batch.task_obs,
            )
            self.update_decoder(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(
                    kwargs_to_compute_gradient),
            )

        if self.should_use_task_encoder:
            task_info = self.get_task_info(
                task_encoding=task_encoding,
                component_name="task_encoder",
                env_index=batch.task_obs,
            )
            self.update_task_encoder(
                batch=batch,
                task_info=task_info,
                logger=logger,
                step=step,
                kwargs_to_compute_gradient=deepcopy(
                    kwargs_to_compute_gradient),
            )

        return batch.buffer_index
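
The target critic is tracked with a soft (Polyak) update: each target parameter moves a fraction tau toward the corresponding online parameter. agent_utils.soft_update_params is assumed to implement the usual form sketched below; the tau value is illustrative.

import torch
import torch.nn as nn

def soft_update_params(net: nn.Module, target_net: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

critic_q1 = nn.Linear(10, 1)
target_q1 = nn.Linear(10, 1)
soft_update_params(critic_q1, target_q1, tau=0.005)  # tau value is illustrative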
Example 7
    def update_actor_and_alpha(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the actor and alpha component.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger (Logger): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): keyword arguments for
                computing the gradient.

        """

        # detach encoder, so we don't update it with the actor loss
        mtobs = MTObs(
            env_obs=batch.env_obs,
            task_obs=None,
            task_info=task_info,
        )
        _, pi, log_pi, log_std = self.actor(mtobs=mtobs, detach_encoder=True)
        actor_Q1, actor_Q2 = self.critic(mtobs=mtobs,
                                         action=pi,
                                         detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        if self.loss_reduction == "mean":
            actor_loss = (self.get_alpha(batch.task_obs).detach() * log_pi -
                          actor_Q).mean()
            logger.log("train/actor_loss", actor_loss, step)

        elif self.loss_reduction == "none":
            actor_loss = self.get_alpha(
                batch.task_obs).detach() * log_pi - actor_Q
            logger.log("train/actor_loss", actor_loss.mean(), step)

        logger.log("train/actor_target_entropy", self.target_entropy, step)

        entropy = 0.5 * log_std.shape[1] * (
            1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)

        logger.log("train/actor_entropy", entropy.mean(), step)

        # optimize the actor
        component_names = ["actor"]
        parameters: List[ParameterType] = []
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(
            loss=actor_loss,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **kwargs_to_compute_gradient,
        )
        self.actor_optimizer.step()

        self.log_alpha_optimizer.zero_grad()
        if self.loss_reduction == "mean":
            alpha_loss = (self.get_alpha(batch.task_obs) *
                          (-log_pi - self.target_entropy).detach()).mean()
            logger.log("train/alpha_loss", alpha_loss, step)
        elif self.loss_reduction == "none":
            alpha_loss = (self.get_alpha(batch.task_obs) *
                          (-log_pi - self.target_entropy).detach())
            logger.log("train/alpha_loss", alpha_loss.mean(), step)
        # logger.log("train/alpha_value", self.get_alpha(batch.task_obs), step)
        self._compute_gradient(
            loss=alpha_loss,
            parameters=self.get_parameters(name="log_alpha"),
            step=step,
            component_names=["log_alpha"],
            **kwargs_to_compute_gradient,
        )
        self.log_alpha_optimizer.step()
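
The temperature update follows the usual SAC recipe: a learnable log_alpha is optimized so that the policy entropy tracks target_entropy, with the log-probabilities detached so only alpha receives gradients. A minimal sketch, assuming get_alpha returns exp(log_alpha) (the target-entropy heuristic below is also an assumption):

import torch

action_dim = 4
target_entropy = -float(action_dim)  # common SAC heuristic; assumed here
log_alpha = torch.zeros(1, requires_grad=True)
log_alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4)

log_pi = torch.randn(8, 1)  # log-probabilities from the actor

# get_alpha is assumed to return exp(log_alpha); only log_alpha receives gradients.
alpha = log_alpha.exp()
alpha_loss = (alpha * (-log_pi - target_entropy).detach()).mean()

log_alpha_optimizer.zero_grad()
alpha_loss.backward()
log_alpha_optimizer.step()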
Example 8
    def update_critic(
        self,
        batch: ReplayBufferSample,
        task_info: TaskInfo,
        logger: Logger,
        step: int,
        kwargs_to_compute_gradient: Dict[str, Any],
    ) -> None:
        """Update the critic component.

        Args:
            batch (ReplayBufferSample): batch from the replay buffer.
            task_info (TaskInfo): task_info object.
            logger (Logger): logger object.
            step (int): step for tracking the training of the agent.
            kwargs_to_compute_gradient (Dict[str, Any]): keyword arguments for
                computing the gradient.

        """
        with torch.no_grad():
            target_V = self._get_target_V(batch=batch, task_info=task_info)
            target_Q = batch.reward + (batch.not_done * self.discount *
                                       target_V)

        # get current Q estimates
        mtobs = MTObs(env_obs=batch.env_obs,
                      task_obs=None,
                      task_info=task_info)
        current_Q1, current_Q2 = self.critic(
            mtobs=mtobs,
            action=batch.action,
            detach_encoder=False,
        )
        critic_loss = F.mse_loss(
            current_Q1, target_Q, reduction=self.loss_reduction) + F.mse_loss(
                current_Q2, target_Q, reduction=self.loss_reduction)

        loss_to_log = critic_loss
        if self.loss_reduction == "none":
            loss_to_log = loss_to_log.mean()
        logger.log("train/critic_loss", loss_to_log, step)

        if loss_to_log > 1e8:
            raise RuntimeError(
                f"critic_loss = {loss_to_log} is too high. Stopping training.")

        component_names = ["critic"]
        parameters: List[ParameterType] = []
        for name in component_names:
            self._optimizers[name].zero_grad()
            parameters += self.get_parameters(name)
        if task_info.compute_grad:
            component_names.append("task_encoder")
            kwargs_to_compute_gradient["retain_graph"] = True
            parameters += self.get_parameters("task_encoder")

        self._compute_gradient(
            loss=critic_loss,
            parameters=parameters,
            step=step,
            component_names=component_names,
            **kwargs_to_compute_gradient,
        )

        # Optimize the critic
        self.critic_optimizer.step()
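
The critic regresses both Q heads onto the same Bellman target, built under torch.no_grad() as reward + not_done * discount * target_V (whatever _get_target_V returns, typically the entropy-adjusted minimum of the target critics in SAC-style agents). A minimal sketch of the "mean" branch with dummy tensors:

import torch
import torch.nn.functional as F

batch_size = 8
discount = 0.99  # illustrative value

reward = torch.randn(batch_size, 1)
not_done = torch.ones(batch_size, 1)   # 0 where the episode terminated
target_V = torch.randn(batch_size, 1)  # stand-in for whatever _get_target_V returns

with torch.no_grad():
    target_Q = reward + not_done * discount * target_V

current_Q1 = torch.randn(batch_size, 1, requires_grad=True)
current_Q2 = torch.randn(batch_size, 1, requires_grad=True)

# Twin-critic regression against the same Bellman target.
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
critic_loss.backward()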