Example #1
def _transform_to_interval(constraint):
    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform([
        transforms.SigmoidTransform(),
        transforms.AffineTransform(loc, scale)
    ])
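A minimal usage sketch (assuming the function above is in scope together with `from torch.distributions import transforms`; the interval bounds below are arbitrary): the composed transform squashes any real input into (lower_bound, upper_bound).

import torch
from torch.distributions import constraints

# Hypothetical interval constraint with scalar bounds (2, 5).
constraint = constraints.interval(2.0, 5.0)
t = _transform_to_interval(constraint)

x = torch.linspace(-4.0, 4.0, 9)   # unconstrained inputs
y = t(x)                           # 2 + 3 * sigmoid(x), always inside (2, 5)
assert ((y > 2.0) & (y < 5.0)).all()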
Example #2
def _transform_to_interval(constraint):
    # Handle the special case of the unit interval.
    lower_is_0 = isinstance(constraint.lower_bound, numbers.Number) and constraint.lower_bound == 0
    upper_is_1 = isinstance(constraint.upper_bound, numbers.Number) and constraint.upper_bound == 1
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform([transforms.SigmoidTransform(),
                                        transforms.AffineTransform(loc, scale)])
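A quick check of the special-case branch (same assumptions as above, plus `import numbers` for the function itself): for the unit interval the affine step is skipped and a bare SigmoidTransform is returned.

from torch.distributions import constraints, transforms

t_unit = _transform_to_interval(constraints.unit_interval)       # bounds are exactly 0 and 1
t_wide = _transform_to_interval(constraints.interval(0.0, 2.0))

assert isinstance(t_unit, transforms.SigmoidTransform)
assert isinstance(t_wide, transforms.ComposeTransform)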
Example #3
def mcmc_transform(
    prior: Distribution,
    num_prior_samples_for_zscoring: int = 1000,
    enable_transform: bool = True,
    device: str = "cpu",
    **kwargs,
) -> TorchTransform:
    """
    Builds a transform that is applied to parameters during MCMC.

    The resulting transform is defined such that the forward mapping maps from
    constrained to unconstrained space.

    It does two things:
    1) When the prior support is bounded, it transforms the parameters into unbounded
        space.
    2) It z-scores the parameters such that MCMC is performed in a z-scored space.

    Args:
        prior: The prior distribution.
        num_prior_samples_for_zscoring: The number of samples drawn from the prior
            to infer the `mean` and `stddev` of the prior used for z-scoring. Unused
            if the prior has bounded support or if it has `mean` and `stddev`
            attributes.
        enable_transform: Whether or not to use a transformation during MCMC.
        device: Device on which the z-scoring `mean` and `stddev` are placed.

    Returns: A transformation whose `forward()` maps from constrained (or
        non-z-scored) to unconstrained (or z-scored) space.
    """

    if enable_transform:
        # Some distributions have a support argument but it raises a
        # NotImplementedError. We catch this case here.
        try:
            _ = prior.support
            has_support = True
        except (NotImplementedError, AttributeError):
            # NotImplementedError -> Distribution that inherits from torch dist but
            # does not implement support.
            # AttributeError -> Custom distribution that has no support attribute.
            warnings.warn(
                """The passed prior has no support property, transform will be
                constructed from mean and std. If the passed prior is supposed to be
                bounded consider implementing the prior.support property.""")
            has_support = False

        # If the distribution has a `support`, check if the support is bounded.
        # If it is not bounded, we want to z-score the space. This is not done
        # by `biject_to()`, so we have to deal with this case separately.
        if has_support:
            if hasattr(prior.support, "base_constraint"):
                constraint = prior.support.base_constraint  # type: ignore
            else:
                constraint = prior.support
            if isinstance(constraint, constraints._Real):
                support_is_bounded = False
            else:
                support_is_bounded = True
        else:
            support_is_bounded = False

        # Prior with bounded support, e.g., uniform priors.
        if has_support and support_is_bounded:
            transform = biject_to(prior.support)
        # For all other cases build affine transform with mean and std.
        else:
            if hasattr(prior, "mean") and hasattr(prior, "stddev"):
                prior_mean = prior.mean.to(device)
                prior_std = prior.stddev.to(device)
            else:
                theta = prior.sample(
                    torch.Size((num_prior_samples_for_zscoring, )))
                prior_mean = theta.mean(dim=0).to(device)
                prior_std = theta.std(dim=0).to(device)

            transform = torch_tf.AffineTransform(loc=prior_mean,
                                                 scale=prior_std)
    else:
        transform = torch_tf.identity_transform

    # Pytorch `transforms` do not sum the determinant over the parameters. However, if
    # the `transform` explicitly is an `IndependentTransform`, it does. Since our
    # `BoxUniform` is a `Independent` distribution, it will also automatically get a
    # `IndependentTransform` wrapper in `biject_to`. Our solution here is to wrap all
    # transforms as `IndependentTransform`.
    if not isinstance(transform, torch_tf.IndependentTransform):
        transform = torch_tf.IndependentTransform(transform,
                                                  reinterpreted_batch_ndims=1)

    check_transform(prior, transform)  # type: ignore

    return transform.inv  # type: ignore
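A short sketch of the `biject_to` convention this function builds on (plain `torch.distributions` only; helpers like `check_transform` and `TorchTransform` are internals of the surrounding library): `biject_to(support)` maps unconstrained to constrained values, so its `.inv`, like the value returned above, maps constrained parameters into the unconstrained space in which MCMC runs.

import torch
from torch.distributions import Independent, Uniform, biject_to

prior = Independent(Uniform(torch.zeros(3), 2.0 * torch.ones(3)), 1)

to_support = biject_to(prior.support)   # forward: unconstrained -> constrained
to_unconstrained = to_support.inv       # forward: constrained -> unconstrained

theta = prior.sample((5,))              # parameters inside the support (0, 2)
z = to_unconstrained(theta)             # unbounded representation for MCMC
assert torch.allclose(to_support(z), theta, atol=1e-5)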
Example #4
def _transform_to_less_than(constraint):
    return transforms.ComposeTransform([
        transforms.ExpTransform(),
        transforms.AffineTransform(constraint.upper_bound, -1)
    ])
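ComposeTransform applies its parts in list order, so ExpTransform runs first and the affine step second: the mapping is y = upper_bound - exp(x), which always lands strictly below the bound. A small check (assuming the function above and `constraints.less_than`):

import torch
from torch.distributions import constraints

t = _transform_to_less_than(constraints.less_than(3.0))
x = torch.tensor([-2.0, 0.0, 2.0])
y = t(x)                                   # 3 - exp(x)
assert torch.allclose(y, 3.0 - torch.exp(x))
assert (y < 3.0).all()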
Example #5
def _transform_to_greater_than(constraint):
    return transforms.ComposeTransform([
        transforms.ExpTransform(),
        transforms.AffineTransform(constraint.lower_bound, 1)
    ])
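The mirror image of the previous case: the mapping is y = lower_bound + exp(x), always strictly above the bound (same assumptions, using `constraints.greater_than`):

import torch
from torch.distributions import constraints

t = _transform_to_greater_than(constraints.greater_than(-1.0))
x = torch.tensor([-3.0, 0.0, 3.0])
assert torch.allclose(t(x), -1.0 + torch.exp(x))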
Example #6
    def condition(self, context):
        loc, log_scale = self.context_nn(context)
        scale = torch.exp(log_scale)

        ac = transforms.AffineTransform(loc, scale, event_dim=self.event_dim)
        return ac
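A self-contained sketch of the pattern this method relies on; the `ContextNet` below is a made-up stand-in for `self.context_nn`. A network predicts `loc` and `log_scale` from a context, and exponentiating `log_scale` guarantees a strictly positive scale for the `AffineTransform`.

import torch
import torch.nn as nn
from torch.distributions import transforms

class ContextNet(nn.Module):
    """Hypothetical context network returning (loc, log_scale)."""
    def __init__(self, context_dim, event_dim):
        super().__init__()
        self.loc = nn.Linear(context_dim, event_dim)
        self.log_scale = nn.Linear(context_dim, event_dim)

    def forward(self, context):
        return self.loc(context), self.log_scale(context)

net = ContextNet(context_dim=4, event_dim=2)
loc, log_scale = net(torch.randn(8, 4))
ac = transforms.AffineTransform(loc, torch.exp(log_scale), event_dim=1)
noise = torch.randn(8, 2)
sample = ac(noise)   # loc + exp(log_scale) * noise, one affine map per context row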
Example #7
    def forward(self, obs, deterministic=False, with_logprob=True):
        """Perform forward pass through the network.

        Args:
            obs (torch.Tensor): The tensor of observations.

            deterministic (bool, optional): Whether we want to use a deterministic
                policy (used at test time). When true the mean action of the stochastic
                policy is returned. If false the action is sampled from the stochastic
                policy. Defaults to False.

            with_logprob (bool, optional): Whether we want to return the log probability
                of an action. Defaults to True.

        Returns:
            torch.Tensor, torch.Tensor, torch.Tensor: The (squashed) actions given by
            the policy, the mean of the policy, and the log probabilities of each of
            these actions.
        """

        # Create base distribution
        base_distribution = mn.MultivariateNormal(torch.zeros(self.a_dim),
                                                  torch.eye(self.a_dim))
        epsilon = base_distribution.sample((obs.shape[0], ))

        # Calculate required variables
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_sigma = self.log_sigma(net_out)
        log_sigma = torch.clamp(log_sigma, self._log_std_min,
                                self._log_std_max)
        sigma = torch.exp(log_sigma)

        # Create bijection
        squash_bijector = transforms.TanhTransform()
        affine_bijector = transforms.ComposeTransform([
            transforms.AffineTransform(mu, sigma),
        ])

        # Calculate raw action, i.e. mu + sigma * epsilon
        raw_action = affine_bijector(epsilon)

        # Determine the axis to sum over (0 for a single observation, 1 for a batch)
        sum_axis = 0 if obs.dim() == 1 else 1

        # Pre-squash distribution and sample
        # TODO: Check if this has the right size. LAC samples from base distribution
        # Sample size is memory buffer size!
        pi_distribution = Normal(mu, sigma)
        raw_action = (
            pi_distribution.rsample()
            # DEBUG: The tensorflow implementation samples
        )  # Sample while using the reparameterization trick

        # Compute log probability in squashed gaussian
        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh
            # squashing. NOTE: The correction formula is a little bit magic. To get an
            # understanding of where it comes from, check out the original SAC paper
            # (arXiv 1801.01290) and look in appendix C. This is a more
            # numerically-stable equivalent to Eq 21. Try deriving it yourself as a
            # (very difficult) exercise. :)
            logp_pi = pi_distribution.log_prob(raw_action).sum(axis=-1)
            logp_pi -= (2 * (np.log(2) - raw_action -
                             F.softplus(-2 * raw_action))).sum(axis=sum_axis)
        else:
            logp_pi = None

        # Calculate scaled action and return the action and its log probability
        clipped_a = torch.tanh(
            raw_action)  # Squash gaussian to be between -1 and 1

        # Get clipped mu
        # FIXME: Is this okay LAC also squashes this output?!
        # clipped_mu = torch.tanh(mu) # LAC version
        clipped_mu = mu

        # Return action and log likelihood
        # DEBUG: LAC expects a distribution; here we already return the log
        # probabilities instead.
        return clipped_a, clipped_mu, logp_pi
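The "magic" correction in the snippet above is a numerically stable rewrite of the tanh change-of-variables term; the identity log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)) can be checked directly:

import math
import torch
import torch.nn.functional as F

u = torch.linspace(-5.0, 5.0, 11, dtype=torch.float64)
naive = torch.log(1.0 - torch.tanh(u) ** 2)
stable = 2.0 * (math.log(2.0) - u - F.softplus(-2.0 * u))
assert torch.allclose(naive, stable, atol=1e-6)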
Example #8
def _transform_to_less_than(constraint):
    loc = constraint.upper_bound
    scale = loc.new([-1]).expand_as(loc)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])
Example #9
def _transform_to_less_than(constraint):
    loc, scale = broadcast_all(constraint.upper_bound, -1)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])
Example #10
def _transform_to_greater_than(constraint):
    loc, scale = broadcast_all(constraint.lower_bound, 1)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])
Example #11
def TanhTransform():
    return transforms.ComposeTransform([
        transforms.AffineTransform(loc=0., scale=2.),
        transforms.SigmoidTransform(),
        transforms.AffineTransform(loc=-1., scale=2.)
    ])
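Since tanh(x) = 2 * sigmoid(2x) - 1, the composition above reproduces tanh exactly; newer PyTorch releases also ship a native `transforms.TanhTransform`, as used in the policy example above. A quick numerical check (assuming the definition above with `torch.distributions.transforms` imported):

import torch

t = TanhTransform()
x = torch.linspace(-3.0, 3.0, 7)
assert torch.allclose(t(x), torch.tanh(x), atol=1e-6)
assert torch.allclose(t.inv(torch.tanh(x)), x, atol=1e-4)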