コード例 #1
0
ファイル: agent.py プロジェクト: GittiHab/dreamer-pytorch
    def _compute_loss_world(self, state, data):
        """Compute the world-model loss terms for one batch.

        Args:
            state: 7-tuple ``(beliefs, prior_states, prior_means,
                prior_std_devs, posterior_states, posterior_means,
                posterior_std_devs)``. ``prior_states`` is not used here.
            data: 3-tuple ``(observations, rewards, nonterminals)``.

        Returns:
            tuple: ``(observation_loss, reward_loss * reward_scale, kl_loss,
            pcont_loss * pcont_scale)``. The last entry is the integer ``0``
            when ``self.args.pcont`` is falsy.
        """
        (beliefs, _prior_states, prior_means, prior_std_devs,
         posterior_states, posterior_means, posterior_std_devs) = state
        observations, rewards, nonterminals = data
        model_inputs = (beliefs, posterior_states)

        # Squared error is summed over dim 2 (symbolic) or dims 2-4, then
        # averaged over the two leading dims.
        feature_dims = 2 if self.args.symbolic else (2, 3, 4)
        predicted_observations = bottle(self.observation_model, model_inputs)
        observation_loss = F.mse_loss(
            predicted_observations, observations, reduction='none')
        observation_loss = observation_loss.sum(dim=feature_dims).mean(
            dim=(0, 1))

        predicted_rewards = bottle(self.reward_model, model_inputs)
        reward_loss = F.mse_loss(
            predicted_rewards, rewards,
            reduction='none').mean(dim=(0, 1))  # TODO: 5

        # Transition loss: KL(posterior || prior), floored at self.free_nats.
        posterior = Independent(Normal(posterior_means, posterior_std_devs), 1)
        prior = Independent(Normal(prior_means, prior_std_devs), 1)
        kl_loss = torch.max(kl_divergence(posterior, prior),
                            self.free_nats).mean(dim=(0, 1))

        scaled_reward_loss = self.args.reward_scale * reward_loss
        if not self.args.pcont:
            return observation_loss, scaled_reward_loss, kl_loss, 0

        pcont_loss = F.binary_cross_entropy(
            bottle(self.pcont_model, model_inputs), nonterminals)
        return (observation_loss, scaled_reward_loss, kl_loss,
                self.args.pcont_scale * pcont_loss)
コード例 #2
0
ファイル: actor.py プロジェクト: mjacar/deep-rl-algorithms
 def choose_action(self, state):
     """Sample an action for ``state`` and return it with its log-probability.

     Args:
         state: NumPy array; it is converted to a float tensor, given a
             leading dimension and moved to CUDA before calling
             ``self.policy``.

     Returns:
         tuple: ``(action, log_prob)`` where ``action`` is a NumPy array and
         ``log_prob`` is a Python scalar.
     """
     observation = torch.from_numpy(state).float().unsqueeze(0)
     mean, logstd = self.policy(observation.cuda())
     std = torch.exp(logstd)
     policy_dist = Independent(Normal(mean.squeeze(), std), 1)
     sampled = policy_dist.sample()
     sampled_log_prob = policy_dist.log_prob(sampled)
     return sampled.squeeze().cpu().numpy(), sampled_log_prob.item()
コード例 #3
0
 def get_logp(mean, std, action, expand=None):
     """Return the log-probability of ``action`` under a diagonal Gaussian.

     Args:
         mean: Location of the Normal distribution.
         std: Scale of the Normal distribution.
         action: Value whose log-probability is evaluated.
         expand: Optional batch shape; when given, the distribution is
             expanded to it before evaluation.

     Returns:
         The result of ``log_prob(action)``, with the last dimension of the
         Normal treated as the event dimension.
     """
     dist = Independent(Normal(mean, std), reinterpreted_batch_ndims=1)
     if expand is not None:
         dist = dist.expand(expand)
     return dist.log_prob(action)
コード例 #4
0
ファイル: mlp_policy.py プロジェクト: Quanticnova/td-reg
    def get_log_prob_entropy(self, x, actions):
        """Evaluate ``actions`` under the policy's diagonal Gaussian.

        Args:
            x: Input passed to ``self.forward``.
            actions: Actions whose log-probability is computed.

        Returns:
            tuple: ``(log_prob, entropy)``. ``log_prob`` has an extra
            dimension inserted at position 1; ``entropy`` is the first
            element of the distribution's entropy.
        """
        # self.forward returns (mean, log_std, std); the log_std is unused.
        action_mean, _action_log_std, action_std = self.forward(x)  # [batch_size, a_dim]

        policy_dist = Independent(Normal(loc=action_mean, scale=action_std), 1)
        log_prob = policy_dist.log_prob(actions).unsqueeze(dim=1)
        return log_prob, policy_dist.entropy()[0]
コード例 #5
0
 def forward(self, state):
     """Map ``state`` to a diagonal Gaussian distribution.

     The parent's output is split in half along its last dimension: the
     first half gives the means (shifted by ``self._center``) and the second
     half the log-variances (converted to a standard deviation and scaled by
     ``self._scale``).

     Returns:
         Independent: Normal distribution with the last dimension treated as
         the event dimension.
     """
     raw = super().forward(state)
     half = raw.shape[-1] // 2
     loc = raw[..., :half] + self._center
     log_variance = raw[..., half:]
     # std = exp(0.5 * log(var))
     scale = (0.5 * log_variance).exp() * self._scale
     return Independent(Normal(loc, scale), 1)
コード例 #6
0
    def __init__(self, base_distribution, transforms, validate_args=None):
        """Build a distribution by applying ``transforms`` to a base one.

        Args:
            base_distribution: Distribution to transform. It may be expanded
                and/or wrapped in ``Independent`` so that its shape matches
                the domain of the composed transform.
            transforms: A single ``Transform`` or a list of ``Transform``
                objects.
            validate_args: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``transforms`` is neither a ``Transform`` nor a
                list of ``Transform`` objects, or if the base distribution
                has fewer dimensions than the transform's domain requires.
        """
        # Normalize `transforms` to a list stored on the instance.
        if isinstance(transforms, Transform):
            self.transforms = [transforms, ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError("transforms must be a Transform or a list of Transforms")
            self.transforms = transforms
        else:
            raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        domain_event_dim = transform.domain.event_dim
        if len(base_shape) < domain_event_dim:
            raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
                             .format(domain_event_dim, base_shape))
        # Round-trip the shape through the transform; if the result differs
        # from the base shape, the base distribution is expanded to match.
        shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
            base_distribution = base_distribution.expand(base_batch_shape)
        # When the transform's domain has more event dims than the base
        # distribution, reinterpret the missing number of batch dims as
        # event dims.
        reinterpreted_batch_ndims = domain_event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
        self.base_dist = base_distribution

        # Compute shapes.
        # The trailing `event_dim` entries of `shape` form the event shape;
        # everything before them is the batch shape.
        event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0)
        assert len(shape) >= event_dim
        cut = len(shape) - event_dim
        batch_shape = shape[:cut]
        event_shape = shape[cut:]
        super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)
コード例 #7
0
    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input to the module.

        Returns:
            torch.distributions.Distribution: ``self._norm_dist_class(mean,
                std)``, wrapped in ``Independent(..., 1)`` unless it is a
                ``TanhNormal``.
        """
        mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

        min_std_param = self._min_std_param
        max_std_param = self._max_std_param
        # Test against None instead of relying on truthiness: a bound whose
        # stored value is 0 is falsy and would silently skip the clamp.
        if min_std_param is not None or max_std_param is not None:
            log_std_uncentered = log_std_uncentered.clamp(
                min=(None if min_std_param is None else min_std_param.item()),
                max=(None if max_std_param is None else max_std_param.item()))

        if self._std_parameterization == 'exp':
            std = log_std_uncentered.exp()
        else:
            std = log_std_uncentered.exp().exp().add(1.).log()
        dist = self._norm_dist_class(mean, std)
        # This control flow is needed because if a TanhNormal distribution is
        # wrapped by torch.distributions.Independent, then custom functions
        # of the TanhNormal distribution (such as its pre-tanh sampling
        # helpers) are not accessible.
        if not isinstance(dist, TanhNormal):
            # Makes it so that a sample from the distribution is treated as a
            # single sample and not dist.batch_shape samples.
            dist = Independent(dist, 1)

        return dist
コード例 #8
0
    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input to the module.

        Returns:
            torch.distributions.independent.Independent: Normal distribution
                built from the predicted mean and standard deviation, with
                its last dimension treated as the event dimension.

        """
        mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

        min_std_param = self._min_std_param
        max_std_param = self._max_std_param
        # Test against None instead of relying on truthiness: a bound whose
        # stored value is 0 is falsy and would silently skip the clamp.
        if min_std_param is not None or max_std_param is not None:
            log_std_uncentered = log_std_uncentered.clamp(
                min=self._to_scalar_if_not_none(min_std_param),
                max=self._to_scalar_if_not_none(max_std_param))

        if self._std_parameterization == 'exp':
            std = log_std_uncentered.exp()
        else:
            std = log_std_uncentered.exp().exp().add(1.).log()

        dist = Independent(Normal(mean, std), 1)

        return dist
コード例 #9
0
 def forward(self, state):
     """Map ``state`` to a diagonal Gaussian distribution.

     The parent's output is split in half along dimension 1. The first half
     goes through ``tanh`` and ``self._squash`` to give the means. The second
     half is multiplied by ``self._scale`` and exponentiated to give the
     standard deviations.

     Returns:
         Independent: Normal distribution with the last dimension treated as
         the event dimension.
     """
     raw = super().forward(state)
     half = raw.shape[1] // 2
     loc = self._squash(torch.tanh(raw[:, :half]))
     scale = (raw[:, half:] * self._scale).exp()
     return Independent(Normal(loc, scale), 1)
コード例 #10
0
    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input to the module.

        Returns:
            torch.distributions.independent.Independent: The distribution
                produced by ``self._norm_dist_class(mean, std)`` with all of
                its batch dimensions reinterpreted as event dimensions.

        """
        mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

        min_std_param = self._min_std_param
        max_std_param = self._max_std_param
        # Test against None instead of relying on truthiness: a bound whose
        # stored value is 0 is falsy and would silently skip the clamp.
        if min_std_param is not None or max_std_param is not None:
            log_std_uncentered = log_std_uncentered.clamp(
                min=self._to_scalar_if_not_none(min_std_param),
                max=self._to_scalar_if_not_none(max_std_param))

        if self._std_parameterization == 'exp':
            std = log_std_uncentered.exp()
        else:
            std = log_std_uncentered.exp().exp().add(1.).log()
        dist = self._norm_dist_class(mean, std)
        # Makes it so that a sample from the distribution is treated as a
        # single sample and not dist.batch_shape samples.
        dist = Independent(dist, len(dist.batch_shape))

        return dist
コード例 #11
0
 def spectral_density(self, smk) -> MixtureSameFamily:
     """Return the mixture of Gaussians that models the spectral density
     of the provided spectral mixture kernel.

     The kernel's mixture means and scales are detached and reshaped to one
     column each; its mixture weights parameterize the categorical mixing
     distribution.
     """
     weights = smk.mixture_weights.detach()
     locs = smk.mixture_means.detach().reshape(-1, 1)
     scales = smk.mixture_scales.detach().reshape(-1, 1)
     components = Independent(Normal(locs, scales), 1)
     return MixtureSameFamily(Categorical(weights), components)
コード例 #12
0
    def _compute_loss_world(self, state, data):
        """Compute the world-model loss terms for one batch.

        Args:
            state: 7-tuple ``(beliefs, prior_states, prior_means,
                prior_std_devs, posterior_states, posterior_means,
                posterior_std_devs)``. ``prior_states`` is not used here.
            data: 3-tuple ``(observations, rewards, nonterminals)``. The
                targets are used as given (no time-step slicing is applied
                here).

        Returns:
            tuple: ``(observation_loss, reward_loss * reward_scale, kl_loss,
            pcont_loss * pcont_scale)``. The last entry is the integer ``0``
            when ``self.args.pcont`` is falsy.
        """
        (beliefs, _prior_states, prior_means, prior_std_devs,
         posterior_states, posterior_means, posterior_std_devs) = state
        observations, rewards, nonterminals = data
        latent = (beliefs, posterior_states)

        # Observation reconstruction: squared error summed over dim 2
        # (symbolic) or dims 2-4, then averaged over the two leading dims.
        reduce_dims = 2 if self.args.symbolic else (2, 3, 4)
        observation_error = F.mse_loss(
            bottle(self.observation_model, latent),
            observations,
            reduction='none')
        observation_loss = observation_error.sum(dim=reduce_dims).mean(
            dim=(0, 1))

        reward_error = F.mse_loss(
            bottle(self.reward_model, latent), rewards, reduction='none')
        reward_loss = reward_error.mean(dim=(0, 1))  # TODO: 5

        # Transition loss: KL(posterior || prior), floored at self.free_nats.
        posterior_dist = Independent(
            Normal(posterior_means, posterior_std_devs), 1)
        prior_dist = Independent(Normal(prior_means, prior_std_devs), 1)
        kl_loss = torch.max(kl_divergence(posterior_dist, prior_dist),
                            self.free_nats).mean(dim=(0, 1))

        weighted_reward_loss = self.args.reward_scale * reward_loss
        if not self.args.pcont:
            return observation_loss, weighted_reward_loss, kl_loss, 0

        # Continuation predictor: binary cross-entropy against nonterminals.
        pcont_loss = F.binary_cross_entropy(
            bottle(self.pcont_model, latent), nonterminals)
        return (observation_loss, weighted_reward_loss, kl_loss,
                self.args.pcont_scale * pcont_loss)
コード例 #13
0
    def latent_priors(self):
        """Sample the latent variables from their fixed priors.

        Returns:
            tuple: ``(z_where, z_what, z_pres)``. ``z_where`` is drawn from a
            Normal whose three components have means ``(3.0, 0.0, 0.0)`` and
            standard deviations ``(0.1, 1.0, 1.0)``, expanded to ``n`` rows.
            ``z_what`` is drawn from a 50-dimensional standard Normal, and
            ``z_pres`` is a tensor of ones of shape ``(n, 1)``.
        """
        scale_mu, scale_sd = 3.0, 0.1
        pos_mu, pos_sd = 0.0, 1.0

        where_loc = nn.Parameter(
            torch.FloatTensor([scale_mu, pos_mu, pos_mu]).expand(n, -1),
            requires_grad=False)
        where_scale = nn.Parameter(
            torch.FloatTensor([scale_sd, pos_sd, pos_sd]).expand(n, -1),
            requires_grad=False)
        what_loc = nn.Parameter(torch.zeros(50))
        what_scale = nn.Parameter(torch.ones(50))

        # Draw z_where before z_what so the random stream is consumed in the
        # same order on every call.
        where_prior = Independent(Normal(where_loc, where_scale), 1)
        z_where = where_prior.sample()
        what_prior = Independent(Normal(what_loc, what_scale), 1)
        z_what = what_prior.sample()

        return z_where, z_what, torch.ones(n, 1)
コード例 #14
0
 def log_prob(self, value_mean, value_precision):
     """Return the joint log-probability of a mean and a precision.

     The result is the sum of the ``DiagonalWishart`` log-probability of
     ``value_precision`` and the Normal log-probability of ``value_mean``,
     where the Normal's scale is
     ``(1 / (belief.unsqueeze(-1) * value_precision)) ** 0.5``.

     Raises:
         ValueError: If any entry of ``value_precision`` is not positive.
     """
     if self._validate_args:
         for sample in (value_mean, value_precision):
             self._validate_sample(sample)
     if (value_precision <= 0).any():
         raise ValueError("desired precision must be greater that 0")

     precision_dist = DiagonalWishart(self.precision_diag, self.df)
     wishart_log_prob = precision_dist.log_prob(value_precision)

     mean_scale = (1 / (self.belief.unsqueeze(-1) * value_precision)).pow(0.5)
     mean_dist = Independent(Normal(self.loc, mean_scale), 1)
     normal_log_prob = mean_dist.log_prob(value_mean)

     return normal_log_prob + wishart_log_prob
コード例 #15
0
    def encode(self, x, z_where_prev, z_what_prev, z_pres_prev, h_prev,
               c_prev):
        """Run one encoding step and sample the latent variables.

        Returns:
            tuple: ``(z_where, z_what, z_pres, h, c, kld_loss)`` where
            ``kld_loss`` is the sum of ``self.latent_loss`` for the
            ``z_where`` and ``z_what`` parameters.
        """
        h, c = compute_hidden_state(self.rnn, x, z_where_prev, z_what_prev,
                                    z_pres_prev, h_prev, c_prev)
        z_pres_proba, z_where_mu, z_where_sd = self.predict(h)

        kld_loss = 0
        kld_loss += self.latent_loss(z_where_mu, z_where_sd)

        # Presence is sampled with probability gated by the previous presence.
        presence_dist = Independent(Bernoulli(z_pres_proba * z_pres_prev), 1)
        z_pres = presence_dist.sample()
        z_where = self._reparameterized_sample(z_where_mu, z_where_sd)

        # Encode the attended region of x to obtain the z_what parameters.
        z_what_mu, z_what_sd = self.obj_encode(attentive_stn_encode(z_where, x))
        kld_loss += self.latent_loss(z_what_mu, z_what_sd)
        z_what = self._reparameterized_sample(z_what_mu, z_what_sd)

        return z_where, z_what, z_pres, h, c, kld_loss
コード例 #16
0
ファイル: gaussian_mlp.py プロジェクト: keiohta/torchrl
    def forward(self, *inputs):
        """Build the action distribution for ``inputs``.

        Args:
            *inputs: Input forwarded to ``self._get_mean_and_log_std``.

        Returns:
            torch.distributions.independent.Independent: The distribution
                produced by ``self._norm_dist_class(mean, std)`` with its
                last batch dimension reinterpreted as an event dimension.
        """
        mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

        min_std_param = self._min_std_param
        max_std_param = self._max_std_param
        # Test against None instead of relying on truthiness: a bound whose
        # stored value is 0 is falsy and would silently skip the clamp.
        if min_std_param is not None or max_std_param is not None:
            log_std_uncentered = log_std_uncentered.clamp(
                min=(None if min_std_param is None else min_std_param.item()),
                max=(None if max_std_param is None else max_std_param.item()))

        if self._std_parameterization == 'exp':
            std = log_std_uncentered.exp()
        else:
            std = log_std_uncentered.exp().exp().add(1.).log()

        dist = self._norm_dist_class(mean, std)

        # Makes it so that a sample from the distribution is treated as a
        # single sample and not dist.batch_shape samples.
        dist = Independent(dist, 1)

        return dist
コード例 #17
0
    def __init__(self,
                 obs_shape,
                 action_dim,
                 hidden_size,
                 inner_hidden,
                 discrete_actions=False,
                 base=None,
                 base_kwargs=None):
        """Create the policy base, value base and action-distribution factory.

        With ``discrete_actions`` the base is an ``AdvantageNetwork`` and
        actions use ``Categorical``. Otherwise the base is a ``HyperOption``
        (built with ``inner_hidden`` and ``base_kwargs``) and actions use a
        diagonal Normal. The ``base`` argument is accepted but not used here.
        """
        super().__init__()
        base_kwargs = {} if base_kwargs is None else base_kwargs

        self.hyper = not discrete_actions
        if discrete_actions:
            self.dist_type = Categorical
            self.base = AdvantageNetwork(obs_shape, hidden_size, action_dim)
        else:
            self.base = HyperOption(obs_shape, hidden_size, action_dim,
                                    inner_hidden, **base_kwargs)
            self.dist_type = lambda mean, sigma: Independent(
                Normal(mean, sigma), 1)
        # Both variants share the same value network; it is built after the
        # base network in either case.
        self.value_base = ValueNetwork(obs_shape, hidden_size)
コード例 #18
0
def MultivariateNormalDiag(loc, scale_diag):
    """Build a multivariate Gaussian with diagonal covariance.

    The last dimension of the underlying Normal is treated as the event
    dimension.

    Raises:
        ValueError: If ``loc`` has no dimensions.
    """
    if loc.dim() >= 1:
        diagonal = Normal(loc, scale_diag)
        return Independent(diagonal, 1)
    raise ValueError("loc must be at least one-dimensional.")
コード例 #19
0
ファイル: world_models.py プロジェクト: mihdalal/rlkit
 def get_dist(self, mean, std, dims=1, normal=True):
     """Return an ``Independent`` distribution over the last ``dims`` dims.

     Args:
         mean: Location of the Normal, or the logits of the Bernoulli.
         std: Scale of the Normal; ignored when ``normal`` is falsy.
         dims: Number of batch dimensions reinterpreted as event dimensions.
         normal: Selects a Normal base when truthy, a Bernoulli otherwise.
     """
     base = Normal(mean, std) if normal else Bernoulli(logits=mean)
     return Independent(base, dims)
コード例 #20
0
 def choose_action(self, state):
     """Sample an action for ``state`` from the policy's diagonal Gaussian.

     Args:
         state: NumPy array; it is converted to a float tensor and given a
             leading dimension before calling ``self.policy``.

     Returns:
         numpy.ndarray: The sampled action with size-1 dimensions removed.
     """
     observation = torch.from_numpy(state).float().unsqueeze(0)
     mean, logstd = self.policy(observation)
     policy_dist = Independent(Normal(mean, logstd.exp()), 1)
     sampled = policy_dist.sample()
     return sampled.squeeze().numpy()
コード例 #21
0
def StandardNormal(d, device=torch.device('cuda:0')):
    """Return a ``d``-dimensional standard Normal on ``device``.

    Args:
        d: Size passed to ``torch.zeros`` / ``torch.ones`` for the
            parameters.
        device: Device on which the parameters are created (defaults to
            ``cuda:0``).

    Returns:
        Independent: Normal(0, 1) with its last dimension treated as the
        event dimension.
    """
    # Allocate directly on the target device instead of creating the tensors
    # on the CPU and copying them over.
    loc = torch.zeros(d, device=device)
    scale = torch.ones(d, device=device)
    return Independent(Normal(loc, scale), 1)
コード例 #22
0
ファイル: tanh_normal.py プロジェクト: ziyiwu9494/garage
class TanhNormal(torch.distributions.Distribution):
    r"""A distribution induced by applying a tanh transformation to a Gaussian random variable.

    Algorithms like SAC and Pearl use this transformed distribution.
    It can be thought of as a distribution of X where
        :math:`Y ~ \mathcal{N}(\mu, \sigma)`
        :math:`X = tanh(Y)`

    Args:
        loc (torch.Tensor): The mean of this distribution.
        scale (torch.Tensor): The stdev of this distribution.

    """ # noqa: 501

    def __init__(self, loc, scale):
        # The underlying Gaussian treats its last dimension as the event
        # dimension.
        self._normal = Independent(Normal(loc, scale), 1)
        super().__init__()

    def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
        """The log likelihood of a sample on the this Tanh Distribution.

        Args:
            value (torch.Tensor): The sample whose loglikelihood is being
                computed.
            pre_tanh_value (torch.Tensor): The value prior to having the tanh
                function applied to it but after it has been sampled from the
                normal distribution.
            epsilon (float): Regularization constant. Making this value larger
                makes the computation more stable but less precise.

        Note:
              when pre_tanh_value is None, an estimate is made of what the
              value is. This leads to a worse estimation of the log_prob.
              If the value being used is collected from `sample` or
              `rsample`, one can instead use
              `rsample_with_pre_tanh_value`, which also returns the
              pre-tanh value.

        Returns:
            torch.Tensor: The log likelihood of value on the distribution.

        """
        # pylint: disable=arguments-differ
        if pre_tanh_value is None:
            # Regularized inverse tanh: 0.5 * log((1 + v) / (1 - v)).
            pre_tanh_value = torch.log(
                (1 + epsilon + value) / (1 + epsilon - value)) / 2
        norm_lp = self._normal.log_prob(pre_tanh_value)
        # Change-of-variables correction: subtract sum(log(1 - tanh(y)^2))
        # over the last dimension.
        ret = (norm_lp - torch.sum(
            torch.log(self._clip_but_pass_gradient((1. - value**2)) + epsilon),
            axis=-1))
        return ret

    def sample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients `do not` pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        z = self._normal.rsample(sample_shape)
        return torch.tanh(z)

    def rsample_with_pre_tanh_value(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Returns the sampled value before the tanh transform is applied and the
        sampled value with the tanh transform applied to it.

        Args:
            sample_shape (list): shape of the return.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Samples from the underlying
                :obj:`torch.distributions.Normal` distribution, prior to being
                transformed with `tanh`.
            torch.Tensor: Samples from this distribution (the first return
                value with `tanh` applied).

        """
        z = self._normal.rsample(sample_shape)
        return z, torch.tanh(z)

    def cdf(self, value):
        """Returns the CDF at the value.

        Returns the cumulative density/mass function evaluated at
        `value` on the underlying normal distribution.

        NOTE(review): the underlying distribution is wrapped in
        `Independent`; confirm that the wrapper implements `cdf` before
        relying on this method.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.cdf(value)

    def icdf(self, value):
        """Returns the icdf function evaluated at `value`.

        Returns the icdf function evaluated at `value` on the underlying
        normal distribution.

        NOTE(review): the underlying distribution is wrapped in
        `Independent`; confirm that the wrapper implements `icdf` before
        relying on this method.

        Args:
            value (torch.Tensor): The element where the icdf is being
                evaluated at.

        Returns:
            torch.Tensor: the result of the icdf being computed.

        """
        return self._normal.icdf(value)

    @classmethod
    def _from_distribution(cls, new_normal):
        """Construct a new TanhNormal distribution from a normal distribution.

        Args:
            new_normal (Independent(Normal)): underlying normal dist for
                the new TanhNormal distribution.

        Returns:
            TanhNormal: A new distribution whose underlying normal dist
                is new_normal.

        """
        # pylint: disable=protected-access
        # The placeholder parameters are discarded right below, but the scale
        # must still be strictly positive: Normal validates its arguments by
        # default in recent PyTorch releases and rejects a scale of 0.
        new = cls(torch.zeros(1), torch.ones(1))
        new._normal = new_normal
        return new

    def expand(self, batch_shape, _instance=None):
        """Returns a new TanhNormal distribution.

        (or populates an existing instance provided by a derived class) with
        batch dimensions expanded to `batch_shape`. This method calls
        :class:`~torch.Tensor.expand` on the distribution's parameters. As
        such, this does not allocate new memory for the expanded distribution
        instance. Additionally, this does not repeat any args checking or
        parameter broadcasting in `__init__.py`, when an instance is first
        created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance(instance): new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            Instance: New distribution instance with batch dimensions expanded
            to `batch_size`.

        """
        new_normal = self._normal.expand(batch_shape, _instance)
        new = self._from_distribution(new_normal)
        return new

    def enumerate_support(self, expand=True):
        """Returns tensor containing all values supported by a discrete dist.

        The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Note:
            Calls the enumerate_support function of the underlying normal
            distribution.

        Returns:
            torch.Tensor: Tensor iterating over dimension 0.

        """
        return self._normal.enumerate_support(expand)

    @property
    def mean(self):
        """torch.Tensor: tanh applied to the mean of the underlying normal."""
        return torch.tanh(self._normal.mean)

    @property
    def variance(self):
        """torch.Tensor: variance of the underlying normal distribution."""
        return self._normal.variance

    def entropy(self):
        """Returns entropy of the underlying normal distribution.

        Returns:
            torch.Tensor: entropy of the underlying normal distribution.

        """
        return self._normal.entropy()

    @staticmethod
    def _clip_but_pass_gradient(x, lower=0., upper=1.):
        """Clipping function that allows for gradients to flow through.

        Args:
            x (torch.Tensor): value to be clipped
            lower (float): lower bound of clipping
            upper (float): upper bound of clipping

        Returns:
            torch.Tensor: x clipped between lower and upper.

        """
        clip_up = (x > upper).float()
        clip_low = (x < lower).float()
        # The correction term is computed without gradient tracking, so the
        # gradient of the result with respect to x stays the identity.
        with torch.no_grad():
            clip = ((upper - x) * clip_up + (lower - x) * clip_low)
        return x + clip

    def __repr__(self):
        """Returns the parameterization of the distribution.

        Returns:
            str: The class name of the distribution.

        """
        return self.__class__.__name__
コード例 #23
0
ファイル: tanh_normal.py プロジェクト: ziyiwu9494/garage
 def __init__(self, loc, scale):
     """Store the underlying diagonal Gaussian and initialize the base class.

     Args:
         loc: Location passed to ``Normal``.
         scale: Scale passed to ``Normal``.
     """
     gaussian = Normal(loc, scale)
     # Treat the last dimension as the event dimension.
     self._normal = Independent(gaussian, 1)
     super().__init__()
コード例 #24
0
                    [
                        torch.tensor(sample[3]).float().unsqueeze(0)
                        for sample in rollout_batch
                    ]
                ).cuda()
                old_values = torch.cat(
                    [torch.tensor(sample[4]).unsqueeze(0) for sample in rollout_batch]
                ).cuda()
                old_log_probs = torch.cat(
                    [torch.tensor(sample[5]).unsqueeze(0) for sample in rollout_batch]
                ).cuda()

                means, logstd = policy(states)
                dist = Independent(
                    Normal(
                        means, torch.exp(logstd.unsqueeze(0).expand(batch_size, -1))
                    ),
                    1,
                )
                log_probs = dist.log_prob(actions.squeeze())

                values = value_fn(states)
                clipped_values = old_values + torch.clamp(
                    values - old_values, -args.clip_range, args.clip_range
                )
                l_vf1 = (values - targets).pow(2)
                l_vf2 = (clipped_values - targets).pow(2)
                value_fn_loss = 0.5 * torch.max(l_vf1, l_vf2).mean()
                value_fn_loss.backward()
                clip_grad_norm_(value_fn.parameters(), args.max_grad_norm)
                value_fn_opt.step()
                value_fn_opt.zero_grad()
コード例 #25
0
 def get_act(mean, std, amount=None):
     """Sample from a diagonal Gaussian with the given parameters.

     Args:
         mean: Location of the Normal distribution.
         std: Scale of the Normal distribution.
         amount: Optional sample shape; when omitted a single draw is
             returned.
     """
     policy_dist = Independent(Normal(mean, std), reinterpreted_batch_ndims=1)
     if amount is None:
         return policy_dist.sample()
     return policy_dist.sample(amount)
コード例 #26
0
def MultivariateNormalDiag(loc, scale_diag):
    """Return a Normal whose last dimension is treated as the event dimension.

    Args:
        loc: Location of the distribution; must have at least one dimension.
        scale_diag: Per-component scale.

    Raises:
        ValueError: If ``loc`` has no dimensions.
    """
    too_few_dims = loc.dim() < 1
    if too_few_dims:
        raise ValueError("loc must be at least one-dimensional.")
    base = Normal(loc, scale_diag)
    return Independent(base, 1)
コード例 #27
0
        observation_loss = F.mse_loss(
            bottle(observation_model, (beliefs, posterior_states)),
            observations[1:],
            reduction='none').sum(dim=2 if args.symbolic else (2, 3, 4)).mean(
                dim=(0, 1))

        reward_loss = F.mse_loss(bottle(reward_model,
                                        (beliefs, posterior_states)),
                                 rewards[1:],
                                 reduction='none').mean(dim=(0, 1))

        # transition loss
        kl_loss = torch.max(
            kl_divergence(
                Independent(Normal(posterior_means, posterior_std_devs), 1),
                Independent(Normal(prior_means, prior_std_devs), 1)),
            free_nats).mean(dim=(0, 1))

        # print("check the reward", bottle(pcont_model, (beliefs, posterior_states)).shape, nonterminals[:-1].shape)
        if args.pcont:
            pcont_pred = torch.distributions.Bernoulli(
                logits=bottle(pcont_model, (beliefs, posterior_states)))
            # print("check pcont", pcont_pred)
            # print("nonterminal", nonterminals[1:])
            pcont_loss = -pcont_pred.log_prob(
                nonterminals[1:]).mean(dim=(0, 1))
            print(pcont_loss)
        # Update model parameters
        world_optimizer.zero_grad()
        (observation_loss + reward_loss + kl_loss +