Example #1
0
def guide_3DA(data):
    """Mean-field Pyro guide over four dihedral-angle latent sites.

    For each site (psi_1, phi_2, psi_2, phi_3) the guide learns three
    variational hyperparameters:
      * ``a_<site>`` / ``b_<site>``: bounds of a Uniform posterior for
        ``mu_<site>``, constrained to (-3.15, 3.15) — just outside
        (-pi, pi) so the initial values satisfy the strict constraints.
      * ``x_<site>``: positive scale of a HalfNormal posterior for
        ``inv_kappa_<site>``.

    Args:
        data: unused by the guide; kept so the signature matches the model.
    """
    # The four sites are structurally identical, so register the params and
    # draw the samples in one loop instead of four copy-pasted stanzas.
    # pyro.param registration order does not matter, and the pyro.sample
    # statements fire in the same order as the original code.
    for site in ('psi_1', 'phi_2', 'psi_2', 'phi_3'):
        a = pyro.param(f'a_{site}', torch.tensor(-np.pi),
                       constraint=constraints.greater_than(-3.15))
        b = pyro.param(f'b_{site}', torch.tensor(np.pi),
                       constraint=constraints.less_than(3.15))
        x = pyro.param(f'x_{site}', torch.tensor(2.),
                       constraint=constraints.positive)
        pyro.sample(f'mu_{site}', dist.Uniform(a, b))
        pyro.sample(f'inv_kappa_{site}', dist.HalfNormal(x))
Example #2
0
def build_support(
        lower_bound: Optional[Tensor] = None,
        upper_bound: Optional[Tensor] = None) -> constraints.Constraint:
    """Return support for a prior distribution, depending on available bounds.

    Args:
        lower_bound: lower bound of the prior support, can be None.
        upper_bound: upper bound of the prior support, can be None.

    Returns:
        support: PyTorch constraint object matching the given bounds.

    Raises:
        ValueError: if both bounds are given but have different numbers
            of elements.
    """

    def _maybe_independent(base: constraints.Constraint,
                           num_dimensions: int) -> constraints.Constraint:
        # For multivariate bounds, treat the last dim as one joint event.
        # `constraints.independent` is the public alias for the private
        # `_IndependentConstraint` the original reached into.
        if num_dimensions > 1:
            return constraints.independent(base, 1)
        return base

    # Support is real if no bounds are passed.
    if lower_bound is None and upper_bound is None:
        # Plain single-line string: the triple-quoted original leaked
        # source indentation/newlines into the emitted warning text.
        warnings.warn(
            "No prior bounds were passed, consider passing lower_bound "
            "and / or upper_bound if your prior has bounded support.")
        return constraints.real

    # Only lower bound is specified.
    if upper_bound is None:
        return _maybe_independent(constraints.greater_than(lower_bound),
                                  lower_bound.numel())

    # Only upper bound is specified.
    if lower_bound is None:
        return _maybe_independent(constraints.less_than(upper_bound),
                                  upper_bound.numel())

    # Both are specified: validate with a real exception instead of an
    # `assert`, which is silently stripped when Python runs with -O.
    if lower_bound.numel() != upper_bound.numel():
        raise ValueError(
            "There must be an equal number of independent bounds.")
    return _maybe_independent(constraints.interval(lower_bound, upper_bound),
                              lower_bound.numel())
Example #3
0
 def __init__(self,
              a,
              b,
              sigma=0.01,
              log_transform=False,
              validate_args=False):
     """Build a smoothed-box prior on the interval(s) ``[a, b]``.

     Args:
         a: lower bound(s); scalar number or tensor.
         b: upper bound(s); must satisfy ``a < b`` element-wise.
         sigma: scale of the Normal "tails" that smooth the box edges
             (default 0.01).
         log_transform: stored flag; presumably sampled/evaluated values
             are log-transformed elsewhere — confirm against the class's
             other methods.
         validate_args: forwarded to the distribution base class.

     Raises:
         ValueError: if ``a < b`` does not hold element-wise.
     """
     TModule.__init__(self)
     # Normalize `a` to a tensor with at least one dimension so the
     # batch/event split below always has a last axis to work with.
     _a = torch.tensor(float(a)) if isinstance(a, Number) else a
     _a = _a.view(-1) if _a.dim() < 1 else _a
     # Broadcast all three arguments to a common shape (presumably
     # torch.distributions.utils.broadcast_all — defined elsewhere).
     _a, _b, _sigma = broadcast_all(_a, b, sigma)
     # Element-wise check that every lower bound is strictly below its
     # upper bound, expressed via the less_than constraint.
     if not torch.all(constraints.less_than(_b).check(_a)):
         raise ValueError("must have that a < b (element-wise)")
     # TODO: Proper argument validation including broadcasting
     # Last axis is the event shape; everything before it is batch shape.
     batch_shape, event_shape = _a.shape[:-1], _a.shape[-1:]
     # need to assign values before registering as buffers to make argument validation work
     self.a, self.b, self.sigma = _a, _b, _sigma
     super(SmoothedBoxPrior, self).__init__(batch_shape,
                                            event_shape,
                                            validate_args=validate_args)
     # now need to delete to be able to register buffer
     # (register_buffer refuses names that already exist as attributes)
     del self.a, self.b, self.sigma
     self.register_buffer("a", _a)
     self.register_buffer("b", _b)
     self.register_buffer("sigma", _sigma)
     # Zero-mean Normal with scale sigma models the mass just outside
     # the box; used by the prior's density elsewhere in the class.
     self.tails = NormalPrior(torch.zeros_like(_a),
                              _sigma,
                              validate_args=validate_args)
     self._log_transform = log_transform