Example #1
0
    def __init__(self, args, logger, get_iter_idx):
        """
        Build the VAE: encoder, decoders, a dedicated rollout buffer and the optimiser.

        :param args: experiment configuration namespace
        :param logger: logger object, stored on the instance
        :param get_iter_idx: callable returning the current iteration index, stored on the instance
        """
        self.args = args
        self.logger = logger
        self.get_iter_idx = get_iter_idx

        # task information is only needed when the task itself is decoded
        if self.args.decode_task:
            self.task_dim = get_task_dim(self.args)
            self.num_tasks = get_num_tasks(self.args)
        else:
            self.task_dim = None
            self.num_tasks = None

        self.encoder = self.initialise_encoder()

        # decoders that are switched off come back as None
        (self.state_decoder,
         self.reward_decoder,
         self.task_decoder) = self.initialise_decoder()

        # the VAE is trained from its own buffer, separate from the on-policy RL rollouts
        self.rollout_storage = RolloutStorageVAE(
            num_processes=self.args.num_processes,
            max_trajectory_len=self.args.max_trajectory_len,
            zero_pad=True,
            max_num_rollouts=self.args.size_vae_buffer,
            state_dim=self.args.state_dim,
            action_dim=self.args.action_dim,
            vae_buffer_add_thresh=self.args.vae_buffer_add_thresh,
            task_dim=self.task_dim,
        )

        # one optimiser over the encoder plus every decoder that is in use
        trainable = list(self.encoder.parameters())
        if not self.args.disable_decoder:
            candidates = (
                (self.args.decode_reward, self.reward_decoder),
                (self.args.decode_state, self.state_decoder),
                (self.args.decode_task, self.task_decoder),
            )
            for enabled, decoder in candidates:
                if enabled:
                    trainable.extend(decoder.parameters())
        self.optimiser_vae = torch.optim.Adam(trainable, lr=self.args.lr_vae)
Example #2
0
class VaribadVAE:
    """
    VAE of variBAD:
    - has an encoder and decoder,
    - can compute the ELBO loss
    - can update the VAE part of the model
    """

    def __init__(self, args, logger, get_iter_idx):
        """
        :param args: experiment configuration namespace
        :param logger: object with an ``add(name, value, iter_idx)`` method (used in ``log``)
        :param get_iter_idx: zero-argument callable returning the current iteration index
        """

        self.args = args
        self.logger = logger
        self.get_iter_idx = get_iter_idx
        self.task_dim = self.get_task_dim()

        # initialise the encoder
        self.encoder = self.initialise_encoder()

        # initialise the decoders (returns None for unused decoders)
        self.state_decoder, self.reward_decoder, self.task_decoder = self.initialise_decoder()

        # initialise rollout storage for the VAE update
        self.rollout_storage = RolloutStorageVAE(num_processes=self.args.num_processes,
                                                 max_trajectory_len=self.args.max_trajectory_len,
                                                 zero_pad=True,
                                                 max_num_rollouts=self.args.size_vae_buffer,
                                                 obs_dim=self.args.obs_dim,
                                                 action_dim=self.args.action_dim,
                                                 vae_buffer_add_thresh=self.args.vae_buffer_add_thresh,
                                                 task_dim=self.task_dim,
                                                 )

        # initialise optimiser for the encoder and decoders
        decoder_params = []
        if not self.args.disable_decoder:
            if self.args.decode_reward:
                decoder_params.extend(self.reward_decoder.parameters())
            if self.args.decode_state:
                decoder_params.extend(self.state_decoder.parameters())
            if self.args.decode_task:
                decoder_params.extend(self.task_decoder.parameters())

        self.optimiser_vae = torch.optim.Adam([*self.encoder.parameters(), *decoder_params], lr=self.args.lr_vae)

    def compute_task_reconstruction_loss(self, dec_embedding, dec_task, return_predictions=False):
        """
        Compute the task reconstruction loss (no reduction along the batch dimension).

        :param dec_embedding: latent embeddings fed to the task decoder
        :param dec_task: ground-truth task (passed through ``env.task_to_id`` for 'task_id')
        :param return_predictions: if True, also return the task decoder output
        :return: the loss, plus the predictions if ``return_predictions``
        :raises NotImplementedError: for an unknown ``task_pred_type``
        """
        task_pred = self.task_decoder(dec_embedding)

        if self.args.task_pred_type == 'task_id':
            env = gym.make(self.args.env_name)
            dec_task = env.task_to_id(dec_task)
            dec_task = dec_task.expand(task_pred.shape[:-1]).view(-1)
            # loss for the data we fed into encoder
            task_pred_shape = task_pred.shape
            loss_task = F.cross_entropy(task_pred.view(-1, task_pred.shape[-1]), dec_task, reduction='none').reshape(
                task_pred_shape[:-1])
        elif self.args.task_pred_type == 'task_description':
            # average over the task-description features only
            loss_task = (task_pred - dec_task).pow(2).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_task, task_pred
        else:
            return loss_task

    def initialise_decoder(self):
        """
        Initialise the decoders selected by ``decode_reward`` / ``decode_state`` / ``decode_task``.

        :return: (state_decoder, reward_decoder, task_decoder); unused decoders are None
        """

        latent_dim = self.args.latent_dim
        # deterministic latents pass (mean, logvar) to the decoders, so the input is twice as wide
        if self.args.disable_stochasticity_in_latent:
            latent_dim *= 2

        if self.args.decode_reward:
            # initialise reward decoder for VAE
            reward_decoder = RewardDecoder(
                layers=self.args.reward_decoder_layers,
                latent_dim=latent_dim,
                state_dim=self.args.obs_dim,
                state_embed_dim=self.args.state_embedding_size,
                action_dim=self.args.action_dim,
                action_embed_dim=self.args.action_embedding_size,
                num_states=self.args.num_states,
                multi_head=self.args.multihead_for_reward,
                pred_type=self.args.rew_pred_type,
                input_prev_state=self.args.input_prev_state,
                input_action=self.args.input_action,
            ).to(device)
        else:
            reward_decoder = None

        if self.args.decode_state:
            # initialise state decoder for VAE
            state_decoder = StateTransitionDecoder(
                latent_dim=latent_dim,
                layers=self.args.state_decoder_layers,
                action_dim=self.args.action_dim,
                action_embed_dim=self.args.action_embedding_size,
                state_dim=self.args.obs_dim,
                state_embed_dim=self.args.state_embedding_size,
                pred_type=self.args.state_pred_type,
            ).to(device)
        else:
            state_decoder = None

        if self.args.decode_task:
            # same task-size logic as used for the rollout storage (raises for unknown types)
            task_dim = self.get_task_dim()
            task_decoder = TaskDecoder(
                latent_dim=latent_dim,
                layers=self.args.task_decoder_layers,
                task_dim=task_dim,
                pred_type=self.args.task_pred_type,
            ).to(device)
        else:
            task_decoder = None

        return state_decoder, reward_decoder, task_decoder

    def initialise_encoder(self):
        """
        Initialise the RNN encoder and move it to ``device``.

        :return: the encoder module
        """

        encoder = RNNEncoder(
            layers_before_gru=self.args.layers_before_aggregator,
            hidden_size=self.args.aggregator_hidden_size,
            layers_after_gru=self.args.layers_after_aggregator,
            latent_dim=self.args.latent_dim,
            action_dim=self.args.action_dim,
            action_embed_dim=self.args.action_embedding_size,
            state_dim=self.args.obs_dim,
            state_embed_dim=self.args.state_embedding_size,
            reward_size=1,
            reward_embed_size=self.args.reward_embedding_size,
        ).to(device)

        return encoder

    def compute_state_reconstruction_loss(self, dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                                          return_predictions=False):
        """
        Compute the state reconstruction loss.
        (No reduction along the batch/trajectory dimensions; sum/avg has to be done outside.)

        The loss is averaged over the last (feature) dimension only.

        :return: the loss, plus the raw decoder output if ``return_predictions``
        :raises NotImplementedError: for an unknown ``state_pred_type``
        """
        state_pred = self.state_decoder(dec_embedding, dec_prev_obs, dec_actions)

        if self.args.state_pred_type == 'deterministic':
            loss_state = (state_pred - dec_next_obs).pow(2).mean(dim=-1)
        elif self.args.state_pred_type == 'gaussian':
            # the decoder outputs [mean, log-variance] along the last dimension
            half = state_pred.shape[-1] // 2
            state_pred_mean = state_pred[..., :half]
            state_pred_std = torch.exp(0.5 * state_pred[..., half:])
            m = torch.distributions.normal.Normal(state_pred_mean, state_pred_std)
            loss_state = -m.log_prob(dec_next_obs).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_state, state_pred
        else:
            return loss_state

    def compute_rew_reconstruction_loss(self, dec_embedding,
                                        dec_prev_obs, dec_next_obs,
                                        dec_actions, dec_rewards,
                                        return_predictions=False):
        """
        Compute the reward reconstruction loss
        (no reduction of loss is done here; sum/avg has to be done outside).

        The loss is averaged over the last (feature) dimension only, so the caller can
        sum it along the decoded trajectory.

        :return: the loss, plus the reward predictions if ``return_predictions``
        :raises NotImplementedError: for unsupported ``rew_pred_type`` / multi-head combinations
        """

        if self.args.multihead_for_reward:
            if self.args.rew_pred_type == 'bernoulli' or self.args.rew_pred_type == 'categorical':
                # one output head per state: pick the head for the observed next state
                p_rew = self.reward_decoder(dec_embedding, None)
                env = gym.make(self.args.env_name)
                indices = env.task_to_id(dec_next_obs).to(p_rew.device)
                if indices.dim() < p_rew.dim():
                    indices = indices.unsqueeze(-1)
                rew_pred = p_rew.gather(dim=-1, index=indices)
                rew_target = (dec_rewards == 1).float()
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
            else:
                # multi-head 'deterministic' (and any other type) is not implemented
                raise NotImplementedError
        else:
            if self.args.rew_pred_type == 'bernoulli':
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs)
                loss_rew = F.binary_cross_entropy(rew_pred, (dec_rewards == 1).float(),
                                                  reduction='none').mean(dim=-1)
            elif self.args.rew_pred_type == 'deterministic':
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions)
                loss_rew = (rew_pred - dec_rewards).pow(2).mean(dim=-1)
            elif self.args.rew_pred_type == 'gaussian':
                # the decoder outputs [mean, log-variance] along the last dimension
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions)
                half = rew_pred.shape[-1] // 2
                rew_pred_mean = rew_pred[..., :half]
                rew_pred_std = torch.exp(0.5 * rew_pred[..., half:])
                m = torch.distributions.normal.Normal(rew_pred_mean, rew_pred_std)
                loss_rew = -m.log_prob(dec_rewards).mean(dim=-1)
            else:
                raise NotImplementedError

        if return_predictions:
            return loss_rew, rew_pred
        else:
            return loss_rew

    def compute_kl_loss(self, latent_mean, latent_logvar, len_encoder):
        """
        KL term for each ELBO_t: either to a standard Gaussian prior, or (sequentially) from
        each posterior to the previous one, with a standard Gaussian before the first.

        :param latent_mean: posterior means, one row per ELBO term
        :param latent_logvar: posterior log-variances, same shape as ``latent_mean``
        :param len_encoder: optional indices selecting which KL terms to return
        :return: 1-D tensor of KL divergences (indexed by ``len_encoder`` if given)
        """
        if self.args.kl_to_gauss_prior:
            kl_divergences = (- 0.5 * (1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp()).sum(dim=1))
        else:
            gauss_dim = latent_mean.shape[-1]
            # add the gaussian prior (zeros on the same device/dtype as the posteriors)
            all_means = torch.cat((latent_mean.new_zeros(1, latent_mean.shape[1]), latent_mean))
            all_logvars = torch.cat((latent_logvar.new_zeros(1, latent_logvar.shape[1]), latent_logvar))
            # https://arxiv.org/pdf/1811.09975.pdf
            # KL(N(mu,E)||N(m,S)) = 0.5 * (log(|S|/|E|) - K + tr(S^-1 E) + (m-mu)^T S^-1 (m-mu)))
            mu = all_means[1:]
            m = all_means[:-1]
            logE = all_logvars[1:]
            logS = all_logvars[:-1]
            kl_divergences = 0.5 * (torch.sum(logS, dim=1) - torch.sum(logE, dim=1) - gauss_dim + torch.sum(
                1 / torch.exp(logS) * torch.exp(logE), dim=1) + ((m - mu) / torch.exp(logS) * (m - mu)).sum(dim=1))

        if self.args.learn_prior:
            # drop the KL of the first term; build the mask on the KL's own device/dtype
            mask = torch.ones_like(kl_divergences)
            mask[0] = 0
            kl_divergences = kl_divergences * mask

        # returns, for each ELBO_t term, one KL (so H+1 kl's)
        if len_encoder is not None:
            return kl_divergences[len_encoder]
        else:
            return kl_divergences

    @staticmethod
    def _sum_past_reconstruction(loss, len_encoder_traj):
        """
        Turn a flat, concatenated per-step loss (``decode_only_past`` mode) into one summed
        term per ELBO_t: segment k has length ``len_encoder_traj[k]``; empty segments are skipped.
        """
        curr_idx = 0
        past_reconstr_sum = []
        for dec_until in len_encoder_traj:
            if dec_until != 0:
                past_reconstr_sum.append(loss[curr_idx:curr_idx + dec_until].sum())
            curr_idx += dec_until
        return torch.stack(past_reconstr_sum)

    def compute_vae_loss(self, update=False):
        """
        Compute the VAE loss (negative ELBO) on a mini-batch from ``self.rollout_storage``,
        log it, and optionally take one optimiser step.

        :param update: if True, backpropagate and step ``self.optimiser_vae``
        :return: the scalar ELBO loss, or 0 if the storage is not ready for an update or
                 there is nothing to train (decoder disabled and deterministic latents)
        """

        if not self.rollout_storage.ready_for_update():
            return 0

        if self.args.disable_decoder and self.args.disable_stochasticity_in_latent:
            return 0

        # get a mini-batch
        vae_prev_obs, vae_next_obs, vae_actions, vae_rewards, vae_tasks, \
        len_encoder, trajectory_lens = self.rollout_storage.get_batch(num_rollouts=self.args.vae_batch_num_trajs,
                                                                      num_enc_len=self.args.vae_batch_num_enc_lens)
        # vae_prev_obs will be of size: max trajectory len x num trajectories x dimension of observations
        # len_encoder will be of size:  number of trajectories x data_per_rollout

        # pass through encoder (outputs will be: (max_traj_len+1) x number of rollouts x latent_dim -- includes the prior!)
        _, latent_mean, latent_logvar, _ = self.encoder(actions=vae_actions,
                                                        states=vae_next_obs,
                                                        rewards=vae_rewards,
                                                        hidden_state=None,
                                                        return_prior=True)

        rew_reconstruction_loss = []
        state_reconstruction_loss = []
        task_reconstruction_loss = []
        kl_loss = []

        # if the size of what we decode is always the same, we can speed up creating the batches
        # (this does not depend on the trajectory, so compute it once)
        use_fast_batching = len(np.unique(trajectory_lens)) == 1 and not self.args.decode_only_past

        num_tasks = len(trajectory_lens)
        # for each task we have in our batch...
        for idx_traj in range(num_tasks):

            # get the embedding values (size: traj_length+1 * latent_dim; the +1 is for the prior)
            curr_means = latent_mean[:trajectory_lens[idx_traj] + 1, idx_traj, :]
            curr_logvars = latent_logvar[:trajectory_lens[idx_traj] + 1, idx_traj, :]
            # decoder input per ELBO term: one sample, or (mean, logvar) concatenated when the
            # latent is deterministic (the decoders were built with 2 * latent_dim inputs then)
            if self.args.disable_stochasticity_in_latent:
                curr_latents = torch.cat((curr_means, curr_logvars), dim=-1)
            else:
                curr_latents = self.encoder._sample_gaussian(curr_means, curr_logvars)

            # select data from current rollout (result is traj_length * obs_dim)
            curr_prev_obs = vae_prev_obs[:, idx_traj, :]
            curr_next_obs = vae_next_obs[:, idx_traj, :]
            curr_actions = vae_actions[:, idx_traj, :]
            curr_rewards = vae_rewards[:, idx_traj, :]

            if use_fast_batching:

                num_latents = curr_latents.shape[0]  # includes the prior
                num_decodes = curr_prev_obs.shape[0]

                # expand the latent to match the (x, y) pairs of the decoder
                dec_embedding = curr_latents.unsqueeze(0).expand((num_decodes, *curr_latents.shape)).transpose(1, 0)
                dec_embedding_task = curr_latents

                # expand the (x, y) pair of the encoder
                dec_prev_obs = curr_prev_obs.unsqueeze(0).expand((num_latents, *curr_prev_obs.shape))
                dec_next_obs = curr_next_obs.unsqueeze(0).expand((num_latents, *curr_next_obs.shape))
                dec_actions = curr_actions.unsqueeze(0).expand((num_latents, *curr_actions.shape))
                dec_rewards = curr_rewards.unsqueeze(0).expand((num_latents, *curr_rewards.shape))

            # otherwise, we unfortunately have to loop!
            # loop through the lengths we are feeding into the encoder for that trajectory (starting with prior)
            # (these are the different ELBO_t terms)
            else:
                dec_embedding = []
                dec_embedding_task = []
                dec_prev_obs, dec_next_obs, dec_actions, dec_rewards = [], [], [], []

                for idx_timestep in len_encoder[idx_traj]:

                    # get the index until which we want to decode
                    # (i.e. either until curr timestep or entire trajectory including future)
                    if self.args.decode_only_past:
                        dec_until = idx_timestep
                    else:
                        dec_until = trajectory_lens[idx_traj]

                    if dec_until != 0:
                        # (1) ... get the latent after feeding in idx_timestep steps of data
                        # (same indexing as the KL term) & expand (to number of outputs)
                        dec_embedding.append(curr_latents[idx_timestep].expand(dec_until, -1))
                        dec_embedding_task.append(curr_latents[idx_timestep])
                        # (2) ... get the predictions for the trajectory until the timestep we're interested in
                        dec_prev_obs.append(curr_prev_obs[:dec_until])
                        dec_next_obs.append(curr_next_obs[:dec_until])
                        dec_actions.append(curr_actions[:dec_until])
                        dec_rewards.append(curr_rewards[:dec_until])

                # stack all of the things we decode! the dimensions of these will be:
                # number of elbo terms (current timesteps from which we want to decode (H+1)
                # x
                # number of terms in elbo (reconstr. of traj.) (H)
                # x
                # dimension (of latent space or obs/act/rew)
                #
                # what we want to do is SUM across the length of the predicted trajectory and AVERAGE across the rest
                if self.args.decode_only_past:
                    # segments have different lengths -> concatenate into one flat batch
                    dec_embedding = torch.cat(dec_embedding)
                    dec_prev_obs = torch.cat(dec_prev_obs)
                    dec_next_obs = torch.cat(dec_next_obs)
                    dec_actions = torch.cat(dec_actions)
                    dec_rewards = torch.cat(dec_rewards)
                else:
                    dec_embedding = torch.stack(dec_embedding)
                    dec_prev_obs = torch.stack(dec_prev_obs)
                    dec_next_obs = torch.stack(dec_next_obs)
                    dec_actions = torch.stack(dec_actions)
                    dec_rewards = torch.stack(dec_rewards)
                # one task prediction per ELBO term in both modes
                dec_embedding_task = torch.stack(dec_embedding_task)

            if self.args.decode_reward:
                # compute reconstruction loss for this trajectory
                # (for each timestep that was encoded, decode everything and sum it up)
                rrc = self.compute_rew_reconstruction_loss(dec_embedding,
                                                           dec_prev_obs,
                                                           dec_next_obs,
                                                           dec_actions,
                                                           dec_rewards
                                                           )
                # sum along the trajectory which we decoded (sum in ELBO_t)
                if self.args.decode_only_past:
                    rrc = self._sum_past_reconstruction(rrc, len_encoder[idx_traj])
                else:
                    rrc = rrc.sum(dim=1)
                rew_reconstruction_loss.append(rrc)
            if self.args.decode_state:
                src = self.compute_state_reconstruction_loss(dec_embedding, dec_prev_obs, dec_next_obs, dec_actions)
                # sum along the trajectory which we decoded (sum in ELBO_t)
                if self.args.decode_only_past:
                    src = self._sum_past_reconstruction(src, len_encoder[idx_traj])
                else:
                    src = src.sum(dim=1)
                state_reconstruction_loss.append(src)
            if self.args.decode_task:
                trc = self.compute_task_reconstruction_loss(dec_embedding_task, vae_tasks[idx_traj])
                task_reconstruction_loss.append(trc)
            if not self.args.disable_stochasticity_in_latent:
                # compute the KL term for each ELBO term of the current trajectory
                kl = self.compute_kl_loss(curr_means, curr_logvars, len_encoder[idx_traj])
                kl_loss.append(kl)

        # sum the ELBO_t terms per task
        if self.args.decode_reward:
            rew_reconstruction_loss = torch.stack(rew_reconstruction_loss)
            rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=1)
        else:
            rew_reconstruction_loss = 0

        if self.args.decode_state:
            state_reconstruction_loss = torch.stack(state_reconstruction_loss)
            state_reconstruction_loss = state_reconstruction_loss.sum(dim=1)
        else:
            state_reconstruction_loss = 0

        if self.args.decode_task:
            task_reconstruction_loss = torch.stack(task_reconstruction_loss)
            task_reconstruction_loss = task_reconstruction_loss.sum(dim=1)
        else:
            task_reconstruction_loss = 0

        if not self.args.disable_stochasticity_in_latent:
            kl_loss = torch.stack(kl_loss)
            kl_loss = kl_loss.sum(dim=1)
        else:
            kl_loss = 0

        # VAE loss = KL loss + reward reconstruction + state transition reconstruction
        # take average (this is the expectation over p(M))
        loss = (self.args.rew_loss_coeff * rew_reconstruction_loss +
                self.args.state_loss_coeff * state_reconstruction_loss +
                self.args.task_loss_coeff * task_reconstruction_loss +
                self.args.kl_weight * kl_loss).mean()

        # make sure we can compute gradients
        if not self.args.disable_stochasticity_in_latent:
            assert kl_loss.requires_grad
        if self.args.decode_reward:
            assert rew_reconstruction_loss.requires_grad
        if self.args.decode_state:
            assert state_reconstruction_loss.requires_grad
        if self.args.decode_task:
            assert task_reconstruction_loss.requires_grad

        # overall loss
        elbo_loss = loss.mean()

        if update:
            self.optimiser_vae.zero_grad()
            elbo_loss.backward()
            self.optimiser_vae.step()

        self.log(elbo_loss, rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_loss)

        return elbo_loss

    def get_task_dim(self):
        """
        Size of the task-decoder output: ``env.task_dim`` for 'task_description',
        ``env.num_tasks`` for 'task_id', None when the task is not decoded.

        :raises NotImplementedError: for an unknown ``task_pred_type``
        """
        if not self.args.decode_task:
            task_dim = None
        else:
            env = gym.make(self.args.env_name)
            if self.args.task_pred_type == 'task_description':
                task_dim = env.task_dim
            elif self.args.task_pred_type == 'task_id':
                task_dim = env.num_tasks
            else:
                raise NotImplementedError
        return task_dim

    def log(self, elbo_loss, rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_loss):
        """Send the (mean) VAE loss terms to the logger every ``log_interval`` iterations."""

        curr_iter_idx = self.get_iter_idx()
        if curr_iter_idx % self.args.log_interval == 0:

            if self.args.decode_reward:
                self.logger.add('vae_losses/reward_reconstr_err', rew_reconstruction_loss.mean(), curr_iter_idx)
            if self.args.decode_state:
                self.logger.add('vae_losses/state_reconstr_err', state_reconstruction_loss.mean(), curr_iter_idx)
            if self.args.decode_task:
                self.logger.add('vae_losses/task_reconstr_err', task_reconstruction_loss.mean(), curr_iter_idx)

            if not self.args.disable_stochasticity_in_latent:
                self.logger.add('vae_losses/kl', kl_loss.mean(), curr_iter_idx)
            self.logger.add('vae_losses/sum', elbo_loss, curr_iter_idx)
Example #3
0
class VaribadVAE:
    """
    VAE of VariBAD:
    - has an encoder and decoder
    - can compute the ELBO loss
    - can update the VAE (encoder+decoder)
    """
    def __init__(self, args, logger, get_iter_idx):
        """
        Build the encoder, the decoders, the VAE rollout buffer and the optimiser.

        :param args: experiment configuration namespace
        :param logger: logger object, stored on the instance
        :param get_iter_idx: callable returning the current iteration index, stored on the instance
        """
        self.args, self.logger, self.get_iter_idx = args, logger, get_iter_idx
        self.task_dim = get_task_dim(self.args)
        self.num_tasks = get_num_tasks(self.args)

        self.encoder = self.initialise_encoder()

        # decoders that are switched off come back as None
        (self.state_decoder,
         self.reward_decoder,
         self.task_decoder) = self.initialise_decoder()

        # the VAE is trained from its own buffer, separate from the on-policy RL rollouts
        storage_kwargs = dict(
            num_processes=self.args.num_processes,
            max_trajectory_len=self.args.max_trajectory_len,
            zero_pad=True,
            max_num_rollouts=self.args.size_vae_buffer,
            state_dim=self.args.state_dim,
            action_dim=self.args.action_dim,
            vae_buffer_add_thresh=self.args.vae_buffer_add_thresh,
            task_dim=self.task_dim,
        )
        self.rollout_storage = RolloutStorageVAE(**storage_kwargs)

        # one optimiser over the encoder plus every decoder that is in use
        params = list(self.encoder.parameters())
        if not self.args.disable_decoder:
            for use_decoder, decoder in ((self.args.decode_reward, self.reward_decoder),
                                         (self.args.decode_state, self.state_decoder),
                                         (self.args.decode_task, self.task_decoder)):
                if use_decoder:
                    params.extend(decoder.parameters())
        self.optimiser_vae = torch.optim.Adam(params, lr=self.args.lr_vae)

    def initialise_encoder(self):
        """Create the RNN encoder from the settings in ``self.args`` and move it to ``device``."""
        cfg = self.args
        return RNNEncoder(
            layers_before_gru=cfg.encoder_layers_before_gru,
            hidden_size=cfg.encoder_gru_hidden_size,
            layers_after_gru=cfg.encoder_layers_after_gru,
            latent_dim=cfg.latent_dim,
            action_dim=cfg.action_dim,
            action_embed_dim=cfg.action_embedding_size,
            state_dim=cfg.state_dim,
            state_embed_dim=cfg.state_embedding_size,
            reward_size=1,
            reward_embed_size=cfg.reward_embedding_size,
        ).to(device)

    def initialise_decoder(self):
        """ Initialises and returns the (state/reward/task) decoder as specified in self.args """

        if self.args.disable_decoder:
            return None, None, None

        latent_dim = self.args.latent_dim
        # double latent dimension (input size to decoder) if we use a deterministic latents (for easier comparison)
        if self.args.disable_stochasticity_in_latent:
            latent_dim *= 2

        # initialise state decoder for VAE
        if self.args.decode_state:
            state_decoder = StateTransitionDecoder(
                layers=self.args.state_decoder_layers,
                latent_dim=latent_dim,
                action_dim=self.args.action_dim,
                action_embed_dim=self.args.action_embedding_size,
                state_dim=self.args.state_dim,
                state_embed_dim=self.args.state_embedding_size,
                pred_type=self.args.state_pred_type,
            ).to(device)
        else:
            state_decoder = None

        # initialise reward decoder for VAE
        if self.args.decode_reward:
            reward_decoder = RewardDecoder(
                layers=self.args.reward_decoder_layers,
                latent_dim=latent_dim,
                state_dim=self.args.state_dim,
                state_embed_dim=self.args.state_embedding_size,
                action_dim=self.args.action_dim,
                action_embed_dim=self.args.action_embedding_size,
                num_states=self.args.num_states,
                multi_head=self.args.multihead_for_reward,
                pred_type=self.args.rew_pred_type,
                input_prev_state=self.args.input_prev_state,
                input_action=self.args.input_action,
            ).to(device)
        else:
            reward_decoder = None

        # initialise task decoder for VAE
        if self.args.decode_task:
            task_decoder = TaskDecoder(
                latent_dim=latent_dim,
                layers=self.args.task_decoder_layers,
                task_dim=self.task_dim,
                num_tasks=self.num_tasks,
                pred_type=self.args.task_pred_type,
            ).to(device)
        else:
            task_decoder = None

        return state_decoder, reward_decoder, task_decoder

    def compute_state_reconstruction_loss(self,
                                          latent,
                                          prev_obs,
                                          next_obs,
                                          action,
                                          return_predictions=False):
        """ Compute state reconstruction loss.
        (No reduction of loss along batch dimension is done here; sum/avg has to be done outside) """

        state_pred = self.state_decoder(latent, prev_obs, action)

        if self.args.state_pred_type == 'deterministic':
            loss_state = (state_pred - next_obs).pow(2).mean(dim=-1)
        elif self.args.state_pred_type == 'gaussian':  # TODO: untested!
            state_pred_mean = state_pred[:, :state_pred.shape[1] // 2]
            state_pred_std = torch.exp(
                0.5 * state_pred[:, state_pred.shape[1] // 2:])
            m = torch.distributions.normal.Normal(state_pred_mean,
                                                  state_pred_std)
            loss_state = -m.log_prob(next_obs).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_state, state_pred
        else:
            return loss_state

    def compute_rew_reconstruction_loss(self,
                                        latent,
                                        prev_obs,
                                        next_obs,
                                        action,
                                        reward,
                                        return_predictions=False):
        """ Compute reward reconstruction loss.
        (No reduction of loss along batch dimension is done here; sum/avg has to be done outside)

        With ``args.multihead_for_reward`` the decoder sees only the latent and
        outputs one value per head; the head for ``next_obs`` is selected via the
        environment's ``task_to_id``. Otherwise the decoder is conditioned on
        (next_obs, prev_obs, action).

        Args:
            latent: latent samples fed to the reward decoder.
            prev_obs, next_obs, action: transition inputs.
            reward: observed rewards; for categorical/bernoulli prediction the
                target is ``reward == 1``.
            return_predictions: if True, also return the (post-activation) predictions.

        Returns:
            Per-element loss with the last dimension averaged out, and
            optionally the predictions.

        Raises:
            NotImplementedError: for an unsupported ``args.rew_pred_type``.
        """

        if self.args.multihead_for_reward:

            rew_pred = self.reward_decoder(latent, None)
            if self.args.rew_pred_type == 'categorical':
                rew_pred = F.softmax(rew_pred, dim=-1)
            elif self.args.rew_pred_type == 'bernoulli':
                rew_pred = torch.sigmoid(rew_pred)

            # The env is only needed for its task_to_id mapping: build it once
            # instead of calling gym.make on every loss evaluation.
            # NOTE(review): assumes task_to_id does not depend on per-episode env
            # state (the env is never reset or stepped here).
            env = getattr(self, '_task_id_env', None)
            if env is None:
                env = gym.make(self.args.env_name)
                self._task_id_env = env
            # gather() needs the index on the same device as its input
            state_indices = env.task_to_id(next_obs).to(rew_pred.device)
            if state_indices.dim() < rew_pred.dim():
                state_indices = state_indices.unsqueeze(-1)
            rew_pred = rew_pred.gather(dim=-1, index=state_indices)
            rew_target = (reward == 1).float()
            if self.args.rew_pred_type == 'deterministic':  # TODO: untested!
                loss_rew = (rew_pred - reward).pow(2).mean(dim=-1)
            elif self.args.rew_pred_type in ['categorical', 'bernoulli']:
                loss_rew = F.binary_cross_entropy(
                    rew_pred, rew_target, reduction='none').mean(dim=-1)
            else:
                raise NotImplementedError
        else:
            rew_pred = self.reward_decoder(latent, next_obs, prev_obs,
                                           action.float())
            # targets come from the observed reward, not from the prediction
            rew_target = (reward == 1).float()
            if self.args.rew_pred_type == 'bernoulli':
                # decoder output is a logit (as in the multi-head branch above);
                # binary_cross_entropy expects probabilities in [0, 1]
                rew_pred = torch.sigmoid(rew_pred)
                loss_rew = F.binary_cross_entropy(
                    rew_pred, rew_target, reduction='none').mean(dim=-1)
            elif self.args.rew_pred_type == 'deterministic':
                loss_rew = (rew_pred - reward).pow(2).mean(dim=-1)
            else:
                raise NotImplementedError

        if return_predictions:
            return loss_rew, rew_pred
        else:
            return loss_rew

    def compute_task_reconstruction_loss(self,
                                         latent,
                                         task,
                                         return_predictions=False):
        """ Compute task reconstruction loss.
        (No reduction of loss along batch dimension is done here; sum/avg has to be done outside)

        Args:
            latent: latent samples fed to the task decoder.
            task: task information; for ``task_pred_type == 'task_id'`` it is
                mapped to class ids via the environment's ``task_to_id``.
            return_predictions: if True, also return the decoder output.

        Returns:
            Per-element loss (cross-entropy for task ids, squared error averaged
            over the last dim for task descriptions), and optionally the predictions.

        Raises:
            NotImplementedError: for an unknown ``args.task_pred_type``.
        """

        task_pred = self.task_decoder(latent)

        if self.args.task_pred_type == 'task_id':
            # The env is only needed for its task_to_id mapping: build it once
            # instead of calling gym.make on every loss evaluation.
            # NOTE(review): assumes task_to_id does not depend on per-episode env
            # state (the env is never reset or stepped here).
            env = getattr(self, '_task_id_env', None)
            if env is None:
                env = gym.make(self.args.env_name)
                self._task_id_env = env
            # cross_entropy needs the target on the same device as the logits
            task_target = env.task_to_id(task).to(task_pred.device)
            # expand along first axis (number of ELBO terms)
            task_target = task_target.expand(task_pred.shape[:-1]).reshape(-1)
            loss_task = F.cross_entropy(
                task_pred.view(-1, task_pred.shape[-1]),
                task_target,
                reduction='none').view(task_pred.shape[:-1])
        elif self.args.task_pred_type == 'task_description':
            loss_task = (task_pred - task).pow(2).mean(dim=-1)
        else:
            raise NotImplementedError

        if return_predictions:
            return loss_task, task_pred
        else:
            return loss_task

    def compute_kl_loss(self, latent_mean, latent_logvar, elbo_indices):
        """ KL term of the ELBO, with the latent dimension summed out.

        With ``args.kl_to_gauss_prior`` every posterior is compared against a
        standard normal. Otherwise the posterior at step t is compared against
        the posterior at step t-1; a standard normal is prepended so that the
        first posterior is compared against it.

        Args:
            latent_mean, latent_logvar: posterior parameters; dim 0 runs over the
                ELBO terms and the last dim is the latent dimension.
            elbo_indices: optional index applied to dim 0 only (so every selected
                ELBO term keeps the KLs of all remaining columns); None keeps all.

        Returns:
            One KL value per ELBO term (and per remaining leading index),
            optionally subsampled by ``elbo_indices``.
        """
        if not self.args.kl_to_gauss_prior:
            latent_dim = latent_mean.shape[-1]
            # prepend a standard-normal "previous posterior" for the first step
            prior_mean = torch.zeros(1, *latent_mean.shape[1:]).to(device)
            prior_logvar = torch.zeros(1, *latent_logvar.shape[1:]).to(device)
            means = torch.cat((prior_mean, latent_mean))
            logvars = torch.cat((prior_logvar, latent_logvar))
            # https://arxiv.org/pdf/1811.09975.pdf
            # KL(N(mu,E)||N(m,S)) = 0.5 * (log(|S|/|E|) - K + tr(S^-1 E) + (m-mu)^T S^-1 (m-mu)))
            post_mean, post_logvar = means[1:], logvars[1:]
            prev_mean, prev_logvar = means[:-1], logvars[:-1]
            kl_divergences = 0.5 * (
                torch.sum(prev_logvar, dim=-1) - torch.sum(post_logvar, dim=-1) -
                latent_dim +
                torch.sum(1 / torch.exp(prev_logvar) * torch.exp(post_logvar), dim=-1) +
                ((prev_mean - post_mean) / torch.exp(prev_logvar) *
                 (prev_mean - post_mean)).sum(dim=-1))
        else:
            kl_divergences = (-0.5 * (1 + latent_logvar - latent_mean.pow(2) -
                                      latent_logvar.exp()).sum(dim=-1))

        # one KL per ELBO_t term (H+1 of them), optionally subsampled
        return kl_divergences if elbo_indices is None else kl_divergences[elbo_indices]

    def sum_reconstruction_terms(self, losses, idx_traj, len_encoder,
                                 trajectory_lens):
        """ Sum the reconstruction errors along the episode horizon.

        If all trajectories share one length and the full trajectory is decoded
        for every embedding, ``losses`` is a matrix and is summed along dim 1.
        Otherwise ``losses`` is read as consecutive segments, one per entry of
        ``len_encoder[idx_traj]``: each segment covers either the past only
        (length = that entry) or the whole trajectory; non-empty segments are
        summed and the sums are stacked (one sum per ELBO_t term).
        """
        same_length = len(np.unique(trajectory_lens)) == 1
        if same_length and not self.args.decode_only_past:
            return losses.sum(dim=1)

        segment_sums = []
        offset = 0
        for enc_len in len_encoder[idx_traj]:
            n_decoded = enc_len if self.args.decode_only_past else trajectory_lens[idx_traj]
            stop = offset + n_decoded
            if stop != offset:
                segment_sums.append(losses[offset:stop].sum())
            offset = stop
        return torch.stack(segment_sums)

    def compute_loss(self, latent_mean, latent_logvar, vae_prev_obs,
                     vae_next_obs, vae_actions, vae_rewards, vae_tasks,
                     trajectory_lens):
        """
        Computes the VAE loss for the given data.
        Batches everything together and therefore needs all trajectories to be of the same length,
        unless both the ELBO terms and the reconstruction terms are subsampled.
        (Important because we need to separate ELBOs and decoding terms so can't collapse those dimensions)

        Args:
            latent_mean, latent_logvar: encoder outputs with the prior as entry 0
                of dim 0, i.e. (max_traj_len + 1) x num trajectories x latent_dim.
            vae_prev_obs, vae_next_obs, vae_actions, vae_rewards: transitions,
                max_traj_len x num trajectories x dim.
            vae_tasks: task information, only used by the task decoder loss.
            trajectory_lens: length of every trajectory in the batch.

        Returns:
            Tuple (rew_reconstruction_loss, state_reconstruction_loss,
            task_reconstruction_loss, kl_loss); a disabled term is 0.
        """

        num_unique_trajectory_lens = len(np.unique(trajectory_lens))
        assert (num_unique_trajectory_lens
                == 1) or (self.args.vae_subsample_elbos
                          and self.args.vae_subsample_decodes)
        assert not self.args.decode_only_past

        # cut down the batch to the longest trajectory length
        # this way we can preserve the structure
        # but we will waste some computation on zero-padded trajectories that are shorter than max_traj_len
        max_traj_len = np.max(trajectory_lens)
        latent_mean = latent_mean[:max_traj_len + 1]
        latent_logvar = latent_logvar[:max_traj_len + 1]
        vae_prev_obs = vae_prev_obs[:max_traj_len]
        vae_next_obs = vae_next_obs[:max_traj_len]
        vae_actions = vae_actions[:max_traj_len]
        vae_rewards = vae_rewards[:max_traj_len]

        # take one sample for each ELBO term
        if not self.args.disable_stochasticity_in_latent:
            latent_samples = self.encoder._sample_gaussian(
                latent_mean, latent_logvar)
        else:
            latent_samples = torch.cat((latent_mean, latent_logvar), dim=-1)

        num_elbos = latent_samples.shape[0]
        num_decodes = vae_prev_obs.shape[0]
        batchsize = latent_samples.shape[1]  # number of trajectories

        # subsample elbo terms
        #   shape before: num_elbos * batchsize * dim
        #   shape after: vae_subsample_elbos * batchsize * dim
        # Flat entry k of elbo_indices is paired with trajectory k % batchsize (task_indices).
        if self.args.vae_subsample_elbos is not None:
            # randomly choose which elbo's to subsample
            if num_unique_trajectory_lens == 1:
                elbo_indices = torch.LongTensor(
                    self.args.vae_subsample_elbos * batchsize).random_(
                        0, num_elbos)  # select diff elbos for each task
            else:
                # if we have different trajectory lengths, subsample elbo indices separately
                # up to their maximum possible encoding length;
                # only allow duplicates if the sample size would be larger than the number of samples.
                # Per-trajectory samples are stacked as columns (vae_subsample_elbos x batchsize)
                # and flattened row-major, so entry k belongs to trajectory k % batchsize.
                elbo_indices = torch.as_tensor(np.stack([
                    np.random.choice(range(0, t + 1),
                                     self.args.vae_subsample_elbos,
                                     replace=self.args.vae_subsample_elbos >
                                     (t + 1)) for t in trajectory_lens
                ], axis=1).reshape(-1), dtype=torch.long)
                # duplicates appear as soon as the SHORTEST trajectory has too few candidates
                if np.min(trajectory_lens) + 1 < self.args.vae_subsample_elbos:
                    warnings.warn(
                        'The required number of ELBOs is larger than the shortest trajectory, '
                        'so there will be duplicates in your batch. '
                        'To avoid this use --split_batches_by_elbo or --split_batches_by_task.'
                    )
            task_indices = torch.arange(batchsize).repeat(
                self.args.vae_subsample_elbos)  # for selection mask
            latent_samples = latent_samples[elbo_indices,
                                            task_indices, :].reshape(
                                                (self.args.vae_subsample_elbos,
                                                 batchsize, -1))
            num_elbos = latent_samples.shape[0]
        else:
            elbo_indices = None

        # expand the state/rew/action inputs to the decoder (to match size of latents)
        # shape will be: [num elbos] x [len trajectory (reconstruction loss)] x [num tasks in batch] x [dimension]
        dec_prev_obs = vae_prev_obs.unsqueeze(0).expand(
            (num_elbos, *vae_prev_obs.shape))
        dec_next_obs = vae_next_obs.unsqueeze(0).expand(
            (num_elbos, *vae_next_obs.shape))
        dec_actions = vae_actions.unsqueeze(0).expand(
            (num_elbos, *vae_actions.shape))
        dec_rewards = vae_rewards.unsqueeze(0).expand(
            (num_elbos, *vae_rewards.shape))

        # subsample reconstruction terms
        if self.args.vae_subsample_decodes is not None:
            # shape before: vae_subsample_elbos * num_decodes * batchsize * dim
            # shape after: vae_subsample_elbos * vae_subsample_decodes * batchsize * dim
            # (Note that this will always have duplicates given how we set up the code)
            # Flat entry k is paired with trajectory k % batchsize (indices2).
            indices0 = torch.arange(num_elbos).repeat(
                self.args.vae_subsample_decodes * batchsize)
            if num_unique_trajectory_lens == 1:
                indices1 = torch.LongTensor(num_elbos *
                                            self.args.vae_subsample_decodes *
                                            batchsize).random_(0, num_decodes)
            else:
                # sample each trajectory's timesteps below its own length and lay them
                # out as columns so that entry k belongs to trajectory k % batchsize
                indices1 = torch.as_tensor(np.stack([
                    np.random.choice(range(0, t),
                                     num_elbos *
                                     self.args.vae_subsample_decodes,
                                     replace=True) for t in trajectory_lens
                ], axis=1).reshape(-1), dtype=torch.long)
            indices2 = torch.arange(batchsize).repeat(
                num_elbos * self.args.vae_subsample_decodes)
            dec_prev_obs = dec_prev_obs[indices0, indices1,
                                        indices2, :].reshape(
                                            (num_elbos,
                                             self.args.vae_subsample_decodes,
                                             batchsize, -1))
            dec_next_obs = dec_next_obs[indices0, indices1,
                                        indices2, :].reshape(
                                            (num_elbos,
                                             self.args.vae_subsample_decodes,
                                             batchsize, -1))
            dec_actions = dec_actions[indices0, indices1, indices2, :].reshape(
                (num_elbos, self.args.vae_subsample_decodes, batchsize, -1))
            dec_rewards = dec_rewards[indices0, indices1, indices2, :].reshape(
                (num_elbos, self.args.vae_subsample_decodes, batchsize, -1))
            num_decodes = dec_prev_obs.shape[1]

        # expand the latent (to match the number of state/rew/action inputs to the decoder)
        # shape will be: [num elbos] x [len trajectory (reconstruction loss)] x [num tasks in batch] x [dimension]
        dec_embedding = latent_samples.unsqueeze(0).expand(
            (num_decodes, *latent_samples.shape)).transpose(1, 0)

        if self.args.decode_reward:
            # compute reconstruction loss for this trajectory (for each timestep that was encoded, decode everything and sum it up)
            # shape: [num_elbo_terms] x [num_reconstruction_terms] x [num_trajectories]
            rew_reconstruction_loss = self.compute_rew_reconstruction_loss(
                dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                dec_rewards)
            # avg/sum across individual ELBO terms
            if self.args.vae_avg_elbo_terms:
                rew_reconstruction_loss = rew_reconstruction_loss.mean(dim=0)
            else:
                rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=0)
            # avg/sum across individual reconstruction terms
            if self.args.vae_avg_reconstruction_terms:
                rew_reconstruction_loss = rew_reconstruction_loss.mean(dim=0)
            else:
                rew_reconstruction_loss = rew_reconstruction_loss.sum(dim=0)
            # average across tasks
            rew_reconstruction_loss = rew_reconstruction_loss.mean()
        else:
            rew_reconstruction_loss = 0

        if self.args.decode_state:
            state_reconstruction_loss = self.compute_state_reconstruction_loss(
                dec_embedding, dec_prev_obs, dec_next_obs, dec_actions)
            # avg/sum across individual ELBO terms
            if self.args.vae_avg_elbo_terms:
                state_reconstruction_loss = state_reconstruction_loss.mean(
                    dim=0)
            else:
                state_reconstruction_loss = state_reconstruction_loss.sum(
                    dim=0)
            # avg/sum across individual reconstruction terms
            if self.args.vae_avg_reconstruction_terms:
                state_reconstruction_loss = state_reconstruction_loss.mean(
                    dim=0)
            else:
                state_reconstruction_loss = state_reconstruction_loss.sum(
                    dim=0)
            # average across tasks
            state_reconstruction_loss = state_reconstruction_loss.mean()
        else:
            state_reconstruction_loss = 0

        if self.args.decode_task:
            task_reconstruction_loss = self.compute_task_reconstruction_loss(
                latent_samples, vae_tasks)
            # avg/sum across individual ELBO terms
            if self.args.vae_avg_elbo_terms:
                task_reconstruction_loss = task_reconstruction_loss.mean(dim=0)
            else:
                task_reconstruction_loss = task_reconstruction_loss.sum(dim=0)
            # NOTE(review): this sums (not averages) over trajectories, unlike the
            # reward/state terms above; kept as-is since changing it rescales the
            # effective task_loss_coeff.
            task_reconstruction_loss = task_reconstruction_loss.sum(
                dim=0).mean()
        else:
            task_reconstruction_loss = 0

        if not self.args.disable_stochasticity_in_latent:
            # compute the KL term for each ELBO term of the current trajectory
            # NOTE(review): compute_kl_loss indexes only dim 0 with elbo_indices,
            # so each sampled ELBO index contributes the KLs of all trajectories.
            kl_loss = self.compute_kl_loss(latent_mean, latent_logvar,
                                           elbo_indices)
            # avg/sum the elbos
            if self.args.vae_avg_elbo_terms:
                kl_loss = kl_loss.mean(dim=0)
            else:
                kl_loss = kl_loss.sum(dim=0)
            # NOTE(review): sums over trajectories (the trailing .mean() is on a
            # scalar); kept as-is since changing it rescales the effective kl_weight.
            kl_loss = kl_loss.sum(dim=0).mean()
        else:
            kl_loss = 0

        return rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_loss

    def compute_loss_split_batches_by_elbo(self, latent_mean, latent_logvar,
                                           vae_prev_obs, vae_next_obs,
                                           vae_actions, vae_rewards, vae_tasks,
                                           trajectory_lens):
        """
        Loop over the ELBO_t terms to compute losses per t.
        Saves some memory if batch sizes are very large, or if we decode only the past.
        All trajectories in the batch must have the same length.

        Args:
            latent_mean, latent_logvar: encoder outputs; dim 0 runs over the ELBO
                terms (including the prior), dim 1 over trajectories.
            vae_prev_obs, vae_next_obs, vae_actions, vae_rewards: transitions,
                horizon x num trajectories x dim.
            vae_tasks: task information, only used by the task decoder loss.
            trajectory_lens: length of every trajectory (all equal).

        Returns:
            Tuple (rew_reconstruction_loss, state_reconstruction_loss,
            task_reconstruction_loss, kl_loss); a disabled term is 0.
        """

        rew_reconstruction_loss = []
        state_reconstruction_loss = []
        task_reconstruction_loss = []

        assert len(np.unique(trajectory_lens)) == 1
        n_horizon = np.unique(trajectory_lens)[0]
        n_elbos = latent_mean.shape[0]  # includes the prior

        # for each elbo term (including one for the prior)...
        for idx_elbo in range(n_elbos):

            # posterior parameters of this ELBO term, one row per trajectory
            curr_means = latent_mean[idx_elbo]
            curr_logvars = latent_logvar[idx_elbo]

            # take one sample for each task
            if not self.args.disable_stochasticity_in_latent:
                curr_samples = self.encoder._sample_gaussian(
                    curr_means, curr_logvars)
            else:
                # deterministic latent: this ELBO term's parameters concatenated
                # along the feature dim (same convention as compute_loss)
                curr_samples = torch.cat((curr_means, curr_logvars), dim=-1)

            # if the size of what we decode is always the same, we can speed up creating the batches
            if not self.args.decode_only_past:

                # expand the latent to match the (x, y) pairs of the decoder
                dec_embedding = curr_samples.unsqueeze(0).expand(
                    (n_horizon, *curr_samples.shape))
                dec_embedding_task = curr_samples

                dec_prev_obs = vae_prev_obs
                dec_next_obs = vae_next_obs
                dec_actions = vae_actions
                dec_rewards = vae_rewards

            # otherwise only the past transitions (up to the current ELBO_t) are decoded
            else:
                dec_from = 0
                dec_until = idx_elbo

                # nothing has been observed yet for the prior term
                if dec_from == dec_until:
                    continue

                # (1) ... expand the latent sample to the number of decoded transitions
                dec_embedding = curr_samples.unsqueeze(0).expand(
                    dec_until - dec_from, *curr_samples.shape)
                dec_embedding_task = curr_samples
                # (2) ... get the inputs for the trajectory until the timestep we're interested in
                dec_prev_obs = vae_prev_obs[dec_from:dec_until]
                dec_next_obs = vae_next_obs[dec_from:dec_until]
                dec_actions = vae_actions[dec_from:dec_until]
                dec_rewards = vae_rewards[dec_from:dec_until]

            if self.args.decode_reward:
                # compute reconstruction loss for this trajectory (for each timestep that was encoded, decode everything and sum it up)
                rrc = self.compute_rew_reconstruction_loss(
                    dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                    dec_rewards)
                # sum up the reconstruction terms; average over tasks
                rrc = rrc.sum(dim=0).mean()
                rew_reconstruction_loss.append(rrc)

            if self.args.decode_state:
                src = self.compute_state_reconstruction_loss(
                    dec_embedding, dec_prev_obs, dec_next_obs, dec_actions)
                # sum up the reconstruction terms; average over tasks
                src = src.sum(dim=0).mean()
                state_reconstruction_loss.append(src)

            if self.args.decode_task:
                trc = self.compute_task_reconstruction_loss(
                    dec_embedding_task, vae_tasks)
                # average across tasks
                trc = trc.mean()
                task_reconstruction_loss.append(trc)

        # sum the ELBO_t terms
        if self.args.decode_reward:
            rew_reconstruction_loss = torch.stack(rew_reconstruction_loss)
            rew_reconstruction_loss = rew_reconstruction_loss.sum()
        else:
            rew_reconstruction_loss = 0

        if self.args.decode_state:
            state_reconstruction_loss = torch.stack(state_reconstruction_loss)
            state_reconstruction_loss = state_reconstruction_loss.sum()
        else:
            state_reconstruction_loss = 0

        if self.args.decode_task:
            task_reconstruction_loss = torch.stack(task_reconstruction_loss)
            task_reconstruction_loss = task_reconstruction_loss.sum()
        else:
            task_reconstruction_loss = 0

        if not self.args.disable_stochasticity_in_latent:
            # compute the KL term for each ELBO term of the current trajectory
            kl_loss = self.compute_kl_loss(latent_mean, latent_logvar, None)
            # sum the elbos, average across tasks
            kl_loss = kl_loss.sum(dim=0).mean()
        else:
            kl_loss = 0

        return rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_loss

    def compute_vae_loss(self, update=False):
        """
        Returns the VAE loss for one mini-batch from the VAE rollout storage.

        Args:
            update: if True, back-propagate the loss and take one step with
                ``self.optimiser_vae``.

        Returns:
            The ELBO loss tensor, or 0 if the rollout storage is not ready for an
            update or if both the decoder and the latent stochasticity are
            disabled (nothing to optimise).

        Raises:
            NotImplementedError: if ``args.split_batches_by_task`` is set.
        """

        if not self.rollout_storage.ready_for_update():
            return 0

        if self.args.disable_decoder and self.args.disable_stochasticity_in_latent:
            return 0

        # get a mini-batch
        vae_prev_obs, vae_next_obs, vae_actions, vae_rewards, vae_tasks, \
            trajectory_lens = self.rollout_storage.get_batch(batchsize=self.args.vae_batch_num_trajs)
        # vae_prev_obs will be of size: max trajectory len x num trajectories x dimension of observations

        # pass through encoder (outputs will be: (max_traj_len+1) x number of rollouts x latent_dim -- includes the prior!)
        _, latent_mean, latent_logvar, _ = self.encoder(
            actions=vae_actions,
            states=vae_next_obs,
            rewards=vae_rewards,
            hidden_state=None,
            return_prior=True,
            detach_every=self.args.tbptt_stepsize if hasattr(
                self.args, 'tbptt_stepsize') else None,
        )

        if self.args.split_batches_by_task:
            # the per-task split of the loss is not implemented
            raise NotImplementedError(
                'split_batches_by_task is not supported; '
                'use split_batches_by_elbo or the default batched loss.')
        elif self.args.split_batches_by_elbo:
            losses = self.compute_loss_split_batches_by_elbo(
                latent_mean, latent_logvar, vae_prev_obs, vae_next_obs,
                vae_actions, vae_rewards, vae_tasks, trajectory_lens)
        else:
            losses = self.compute_loss(latent_mean, latent_logvar,
                                       vae_prev_obs, vae_next_obs, vae_actions,
                                       vae_rewards, vae_tasks, trajectory_lens)
        rew_reconstruction_loss, state_reconstruction_loss, task_reconstruction_loss, kl_loss = losses

        # VAE loss = KL loss + reward reconstruction + state transition reconstruction
        # take average (this is the expectation over p(M))
        loss = (self.args.rew_loss_coeff * rew_reconstruction_loss +
                self.args.state_loss_coeff * state_reconstruction_loss +
                self.args.task_loss_coeff * task_reconstruction_loss +
                self.args.kl_weight * kl_loss).mean()

        # make sure we can compute gradients
        if not self.args.disable_stochasticity_in_latent:
            assert kl_loss.requires_grad
        if self.args.decode_reward:
            assert rew_reconstruction_loss.requires_grad
        if self.args.decode_state:
            assert state_reconstruction_loss.requires_grad
        if self.args.decode_task:
            assert task_reconstruction_loss.requires_grad

        # overall loss
        elbo_loss = loss.mean()

        if update:
            # NOTE: no gradient clipping is applied to the VAE parameters
            self.optimiser_vae.zero_grad()
            elbo_loss.backward()
            self.optimiser_vae.step()

        self.log(elbo_loss, rew_reconstruction_loss, state_reconstruction_loss,
                 task_reconstruction_loss, kl_loss)

        return elbo_loss

    def log(self, elbo_loss, rew_reconstruction_loss,
            state_reconstruction_loss, task_reconstruction_loss, kl_loss):
        """ Send the VAE loss terms to ``self.logger`` every ``args.log_interval`` iterations.

        Each reconstruction term is logged (as its mean) only if the matching
        ``decode_*`` flag is set, and the KL term only if the latent is
        stochastic; ``elbo_loss`` is always logged on logging iterations.
        """
        iter_idx = self.get_iter_idx()
        if iter_idx % self.args.log_interval != 0:
            return

        # (enabled, tag, loss) in the order they are written to the logger
        terms = (
            (self.args.decode_reward, 'vae_losses/reward_reconstr_err',
             rew_reconstruction_loss),
            (self.args.decode_state, 'vae_losses/state_reconstr_err',
             state_reconstruction_loss),
            (self.args.decode_task, 'vae_losses/task_reconstr_err',
             task_reconstruction_loss),
            (not self.args.disable_stochasticity_in_latent, 'vae_losses/kl',
             kl_loss),
        )
        for enabled, tag, value in terms:
            if enabled:
                self.logger.add(tag, value.mean(), iter_idx)
        self.logger.add('vae_losses/sum', elbo_loss, iter_idx)