Example #1
File: pNeRF.py Project: andojas/pNeRF
def guide_3DA(data):    
    # Hyperparameters    
    a_psi_1 = pyro.param('a_psi_1', torch.tensor(-np.pi), constraint=constraints.greater_than(-3.15)) 
    b_psi_1 = pyro.param('b_psi_1', torch.tensor(np.pi), constraint=constraints.less_than(3.15))
    x_psi_1 = pyro.param('x_psi_1', torch.tensor(2.), constraint=constraints.positive)
    
    a_phi_2 = pyro.param('a_phi_2', torch.tensor(-np.pi), constraint=constraints.greater_than(-3.15)) 
    b_phi_2 = pyro.param('b_phi_2', torch.tensor(np.pi), constraint=constraints.less_than(3.15))
    x_phi_2 = pyro.param('x_phi_2', torch.tensor(2.), constraint=constraints.positive)
    
    a_psi_2 = pyro.param('a_psi_2', torch.tensor(-np.pi), constraint=constraints.greater_than(-3.15)) 
    b_psi_2 = pyro.param('b_psi_2', torch.tensor(np.pi), constraint=constraints.less_than(3.15))
    x_psi_2 = pyro.param('x_psi_2', torch.tensor(2.), constraint=constraints.positive)
    
    a_phi_3 = pyro.param('a_phi_3', torch.tensor(-np.pi), constraint=constraints.greater_than(-3.15)) 
    b_phi_3 = pyro.param('b_phi_3', torch.tensor(np.pi), constraint=constraints.less_than(3.15))
    x_phi_3 = pyro.param('x_phi_3', torch.tensor(2.), constraint=constraints.positive)
    
    # Sampling mu and kappa
    pyro.sample("mu_psi_1", dist.Uniform(a_psi_1, b_psi_1))
    pyro.sample("inv_kappa_psi_1", dist.HalfNormal(x_psi_1))    
    pyro.sample("mu_phi_2", dist.Uniform(a_phi_2, b_phi_2))
    pyro.sample("inv_kappa_phi_2", dist.HalfNormal(x_phi_2))
    pyro.sample("mu_psi_2", dist.Uniform(a_psi_2, b_psi_2))
    pyro.sample("inv_kappa_psi_2", dist.HalfNormal(x_psi_2))    
    pyro.sample("mu_phi_3", dist.Uniform(a_phi_3, b_phi_3))
    pyro.sample("inv_kappa_phi_3", dist.HalfNormal(x_phi_3))
Example #2
def parametrized_guide(doc_word_data, category_data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param("topic_weights_posterior",
                                         lambda: torch.ones(args.num_topics),
                                         constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    category_weights_posterior = pyro.param(
        "category_weights_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    category_topics_posterior = pyro.param(
        "category_topics_posterior",
        lambda: torch.ones(args.num_categories, args.num_topics),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("categories", args.num_categories):
        pyro.sample("category_weights",
                    dist.Gamma(category_weights_posterior, 1.))
        pyro.sample("category_topics",
                    dist.Dirichlet(category_topics_posterior))

    doc_category_posterior = pyro.param("doc_category_posterior",
                                        lambda: torch.ones(args.num_topics),
                                        constraint=constraints.positive)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        pyro.sample("doc_categories", dist.Categorical(doc_category_posterior))
Example #3
def build_support(
        lower_bound: Optional[Tensor] = None,
        upper_bound: Optional[Tensor] = None) -> constraints.Constraint:
    """Return support for prior distribution, depending on available bounds.

    Args:
        lower_bound: lower bound of the prior support, can be None
        upper_bound: upper bound of the prior support, can be None

    Returns:
        support: Pytorch constraint object.
    """

    # Support is real if no bounds are passed.
    if lower_bound is None and upper_bound is None:
        support = constraints.real
        warnings.warn(
            """No prior bounds were passed, consider passing lower_bound
            and / or upper_bound if your prior has bounded support.""")
    # Only lower bound is specified.
    elif upper_bound is None:
        num_dimensions = lower_bound.numel()  # type: ignore
        if num_dimensions > 1:
            support = constraints._IndependentConstraint(
                constraints.greater_than(lower_bound),
                1,
            )
        else:
            support = constraints.greater_than(lower_bound)
    # Only upper bound is specified.
    elif lower_bound is None:
        num_dimensions = upper_bound.numel()
        if num_dimensions > 1:
            support = constraints._IndependentConstraint(
                constraints.less_than(upper_bound),
                1,
            )
        else:
            support = constraints.less_than(upper_bound)

    # Both are specified.
    else:
        num_dimensions = lower_bound.numel()
        assert (num_dimensions == upper_bound.numel()
                ), "There must be an equal number of independent bounds."
        if num_dimensions > 1:
            support = constraints._IndependentConstraint(
                constraints.interval(lower_bound, upper_bound),
                1,
            )
        else:
            support = constraints.interval(lower_bound, upper_bound)

    return support
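
A minimal usage sketch for build_support, assuming illustrative bounds and expected check() results (not taken from the source project):

import torch
from torch.distributions import constraints

support = build_support(lower_bound=torch.zeros(3))        # vector bound -> independent greater_than(0) over the last dim
print(support.check(torch.ones(3)))                        # expected: tensor(True)
print(support.check(torch.tensor([1.0, -1.0, 1.0])))       # expected: tensor(False)

support = build_support(lower_bound=torch.tensor(0.0),
                        upper_bound=torch.tensor(1.0))      # scalar bounds -> interval(0, 1)
print(support.check(torch.tensor(0.5)))                     # expected: tensor(True)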
Example #4
    def guide(self, data):
        alpha_posterior = pyro.param("alpha_posterior",
                                     lambda: torch.ones(self.n_topics),
                                     constraint=positive)
        beta_posterior = pyro.param(
            "beta_posterior",
            lambda: torch.ones(self.n_topics, self.vocab_size),
            constraint=greater_than(0.5))

        with pyro.plate("topics", self.n_topics):
            alpha = pyro.sample("alpha", dist.Gamma(alpha_posterior, 1.))
            betas = pyro.sample("beta", dist.Dirichlet(beta_posterior))

        theta = None
        z = None

        for d in pyro.plate("doc_loop", len(data)):
            gamma_q = pyro.param(f"gamma_{d}",
                                 torch.ones(self.n_topics),
                                 constraint=positive)
            theta = pyro.sample(f"theta_{d}", dist.Dirichlet(gamma_q))
            nwords = len(data[d])
            for w in pyro.plate(f"word_loop_{d}", nwords):
                phi_q = pyro.param(f"phi{d}_{w}",
                                   torch.ones(self.n_topics),
                                   constraint=positive)
                z = pyro.sample(f"z{d}_{w}", dist.Categorical(phi_q))
        return theta, z, alpha, betas
Example #5
File: gpr.py Project: lewisKit/pyro
    def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6,
                 name="GPR"):
        super(GPRegression, self).__init__(X, y, kernel, mean_function, jitter, name)

        noise = self.X.new_ones(()) if noise is None else noise
        self.noise = Parameter(noise)
        self.set_constraint("noise", constraints.greater_than(self.jitter))
Example #6
def parametrized_guide(predictor, data, num_words_per_doc, args):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    for doc in pyro.plate("documents", args.num_docs, args.batch_size):
        # data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(args.num_words, 1)
        for i in data[doc]:
            counts[i] += 1
        #    .scatter_add(0, data[doc], torch.ones(data[doc].shape)))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics_{}".format(doc), dist.Delta(doc_topics, event_dim=1))
        # added this part so the per-word topic assignments also get a guide sample
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            word_topics = pyro.sample("word_topics_{}".format(doc), dist.Categorical(doc_topics))
Example #7
File: lda.py Project: pyro-ppl/pyro
def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive,
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5),
    )
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.0))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(args.num_words,
                             ind.size(0)).scatter_add(0, data,
                                                      torch.ones(data.shape))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
Example #8
    def __init__(self,
                 X,
                 y,
                 kernel,
                 Xu,
                 noise=None,
                 mean_function=None,
                 approx=None,
                 jitter=1e-6,
                 name="SGPR"):
        super(SparseGPRegression, self).__init__(X, y, kernel, mean_function,
                                                 jitter, name)

        self.Xu = Parameter(Xu)

        noise = self.X.new_ones(()) if noise is None else noise
        self.noise = Parameter(noise)
        self.set_constraint("noise", constraints.greater_than(self.jitter))

        if approx is None:
            self.approx = "VFE"
        elif approx in ["DTC", "FITC", "VFE"]:
            self.approx = approx
        else:
            raise ValueError(
                "The sparse approximation method should be one of "
                "'DTC', 'FITC', 'VFE'.")
Example #9
 def model(loc, cov):
     x = pyro.param("x", torch.randn(2))
     y = pyro.param("y", torch.randn(3, 2))
     z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
     pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
     with pyro.plate("y_plate", 3):
         pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
     with pyro.plate("z_plate", 4):
         pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)
Example #10
 def __init__(self, dof, lambda1, lambda2, nu, validate_args=None):
     self.dof = dof
     self.lambda1 = lambda1
     self.lambda2 = lambda2
     self.nu = nu
     batch_shape, event_shape = lambda1.shape[:-1], lambda1.shape[-1:]
     D = event_shape[0]
     self.arg_constraints['dof'] = constraints.greater_than(2. * D + 1.)
     super().__init__(batch_shape, event_shape, validate_args=validate_args)
Example #11
    def __init__(self,
                 df: Union[torch.Tensor, Number],
                 covariance_matrix: torch.Tensor = None,
                 precision_matrix: torch.Tensor = None,
                 scale_tril: torch.Tensor = None,
                 validate_args=None):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)

        if param.dim() < 2:
            raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.")

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")

        super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )
Example #12
 def model(loc, cov):
     x = pyro.param("x", torch.randn(2))
     y = pyro.param("y", torch.randn(3, 2))
     z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
     pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
     with pyro.iarange("y_iarange", 3):
         pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
     with pyro.iarange("z_iarange", 4):
         pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)
Example #13
 def __init__(self, concentration, scale, validate_args=None):
     self.concentration = torch.as_tensor(concentration)
     self.scale = torch.as_tensor(scale)
     batch_shape = self.concentration.shape
     event_shape = self.scale.shape[-2:]
     self.arg_constraints['concentration'] = constraints.greater_than(
         .5 * (event_shape[-1] - 1))
     super(Wishart, self).__init__(batch_shape,
                                   event_shape,
                                   validate_args=validate_args)
Example #14
    def prior(self):
        """
        Prior definition of the weights used for log regression. This function has to set the
        variables 'self.weight_prior_dist', 'self.weight_mean_init' and 'self.weight_stddev_init'.
        """

        # number of weights
        num_weights = 4 * (self.num_features + 1)
        self._sites = OrderedDict()

        # initial values for mean, scale and prior dist
        init_mean = torch.ones(num_weights).uniform_(1. + self.epsilon, 2.)
        init_scale = torch.ones(num_weights)
        prior = dist.Normal(init_mean, 10 * init_scale, validate_args=True)

        # we have constraints on the weights to be positive
        # this is usually solved by defining constraints on the MLE optimizer
        # however, on MCMC and VI this is not possible
        # instead, we are using a "shifted" LogNormal to obtain only positive samples

        if self.method in ['variational', 'mcmc']:

            # for this purpose, we need to transform the prior mean first and set the
            # distribution to be a LogNormal
            init_mean = torch.log(torch.exp(init_mean) - 1)
            prior = dist.LogNormal(init_mean,
                                   10 * init_scale,
                                   validate_args=True)

        # set properties for "weights": weights must be positive
        self._sites['weights'] = {
            'values': None,
            'constraint': constraints.greater_than(self.epsilon),
            'init': {
                'mean': init_mean,
                'scale': init_scale
            },
            'prior': prior
        }

        # set properties for "bias"
        self._sites['bias'] = {
            'values': None,
            'constraint': constraints.real,
            'init': {
                'mean': torch.ones(1) * self.epsilon,
                'scale': torch.ones(1)
            },
            'prior': dist.Normal(torch.zeros(1), 10 * torch.ones(1),
                                 validate_args=True),
        }
Example #15
def guide(dataset_total_length, x_data, y_data, kl_factor):
    global n_features, n_hidden, n_out
    a1_mean = pyro.param('a1_mean', 0.01 * torch.randn(n_features, n_hidden))
    a1_scale = pyro.param('a1_scale', 0.1 * torch.ones(n_features, n_hidden),
                          constraint=constraints.greater_than(0.01))
    a2_mean = pyro.param('a2_mean', 0.01 * torch.randn(n_hidden + 1, n_out))
    a2_scale = pyro.param('a2_scale', 0.1 * torch.ones(n_hidden + 1, n_out),
                          constraint=constraints.greater_than(0.01))

    with pyro.plate('map', dataset_total_length, subsample=x_data):
    # with pyro.plate('map', size=x_data.shape[0]):
        # sample first hidden layer
        h1 = pyro.sample('h1', bnn.HiddenLayer(x_data, a1_mean, a1_scale,
                                                   non_linearity=nnf.leaky_relu,
                                                   KL_factor=kl_factor))
        # sample second hidden layer
        rate = pyro.sample('rate', bnn.HiddenLayer(h1, a2_mean, a2_scale,
                                                   non_linearity=lambda x: nnf.relu(x)+1e-3,
                                                   KL_factor=kl_factor,
                                                   include_hidden_bias=False))
Example #16
 def __init__(self, df, scale=None, scale_tril=None, validate_args=None):
     self.df = torch.as_tensor(df)
     if (scale is None) + (scale_tril is None) != 1:
         raise ValueError("Exactly one of scale or scale_tril must be specified.")
     if scale is not None:
         self.scale = torch.as_tensor(scale)
     if scale_tril is not None:
         self.scale_tril = torch.as_tensor(scale_tril)
     batch_shape = self.df.shape
     event_shape = self.scale_tril.shape[-2:]
     self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
     super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
Example #17
File: gpr.py Project: zippeurfou/pyro
    def __init__(self,
                 X,
                 y,
                 kernel,
                 noise=None,
                 mean_function=None,
                 jitter=1e-6,
                 name="GPR"):
        super(GPRegression, self).__init__(X, y, kernel, mean_function, jitter,
                                           name)

        noise = self.X.new_ones(()) if noise is None else noise
        self.noise = Parameter(noise)
        self.set_constraint("noise", constraints.greater_than(self.jitter))
Example #18
File: sgpr.py Project: lewisKit/pyro
    def __init__(self, X, y, kernel, Xu, noise=None, mean_function=None, approx=None,
                 jitter=1e-6, name="SGPR"):
        super(SparseGPRegression, self).__init__(X, y, kernel, mean_function, jitter,
                                                 name)

        self.Xu = Parameter(Xu)

        noise = self.X.new_ones(()) if noise is None else noise
        self.noise = Parameter(noise)
        self.set_constraint("noise", constraints.greater_than(self.jitter))

        if approx is None:
            self.approx = "VFE"
        elif approx in ["DTC", "FITC", "VFE"]:
            self.approx = approx
        else:
            raise ValueError("The sparse approximation method should be one of "
                             "'DTC', 'FITC', 'VFE'.")
Example #19
File: sem.py Project: Lakshitha0912/rfi1
    def __init__(self,
                 dag: DirectedAcyclicGraph,
                 noise_std_dict: Dict[str, float] = {},
                 default_noise_std_bounds: Tuple[float, float] = (0.5, 1.0),
                 seed=None):
        """
        Args:
            dag: DAG, defining SEM
            noise_std_dict: Noise std dict for each variable
            default_noise_std_bounds: Bounds of the uniform range from which a default noise std is drawn
            seed: optional random seed
        """
        super(PostNonLinearMultiplicativeHalfNormalSEM, self).__init__(dag)

        if seed is not None:
            np.random.seed(seed)

        for node in self.topological_order:
            self.model[node]['noise_std'] = noise_std_dict[node] if node in noise_std_dict \
                else np.random.uniform(*default_noise_std_bounds)

        self.parents_conditional_support = constraints.greater_than(0.0)
Example #20
    def _parametrized_guide(self, predictor, data, labels, batch_size=None):
        args = self.args
        # Use a conjugate guide for global variables.
        topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics),
            constraint=constraints.positive)
        topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words),
            constraint=constraints.greater_than(0.5))

        # Is this needed?
        label_posterior = pyro.param("label_posterior",
                                     lambda: torch.ones(2, args.num_topics),
                                     constraint=constraints.positive)

        with pyro.plate("topics", args.num_topics):
            pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior,
                                                    1.))
            pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
            pyro.sample("label_prior", dist.Beta(*label_posterior))

        # Use an amortized guide for local variables.
        pyro.module("predictor", predictor)
        with pyro.plate("documents", args.num_docs, batch_size) as ind:
            data = data[:, ind]
            labels = labels[ind]
            counts = (torch.zeros(args.num_words, ind.size(0)).scatter_add(
                0, data, torch.ones(data.shape)))

            augmented_input = torch.zeros(batch_size,
                                          args.num_words + args.num_topics)
            augmented_input[:, :args.num_words] = counts.transpose(0, 1)
            augmented_input[:, args.num_words:] = labels

            doc_topics = predictor(augmented_input)
            pyro.sample("doc_topics", dist.Delta(doc_topics).to_event(1))
Example #21
    def guide(self, docs=None, doc_sum=None):
        # Use a conjugate guide for global variables.
        topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(self.num_topics, device=self.device),
            constraint=constraints.positive)

        topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(
                self.num_topics, self.vocab_size, device=self.device),
            constraint=constraints.greater_than(0.5))

        with pyro.plate("topics", self.num_topics):
            pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior,
                                                    1.))
            pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

        # Use an amortized guide for local variables.
        pyro.module("inference_net", self.inference_net)
        with pyro.plate("documents", doc_sum.shape[0]):
            doc_topics = self.inference_net(doc_sum)
            pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
Example #22
def guide(n_ice, n_obs, floe_size, cover_subp):
    a_floe_loc = pyro.param('a_floe_loc', torch.tensor(0.))
    b_floe_loc = pyro.param('b_floe_loc', torch.tensor(0.))
    a_floe_scale = pyro.param('a_floe_scale',
                              torch.tensor(1.),
                              constraint=constraints.greater_than(0.01))
    b_floe_scale = pyro.param('b_floe_scale',
                              torch.tensor(1.),
                              constraint=constraints.greater_than(0.01))
    a_cover_loc = pyro.param('a_cover_loc', torch.tensor(1.))
    b_cover_loc = pyro.param('b_cover_loc', torch.tensor(1.))
    a_cover_b_loc = pyro.param('a_cover_b_loc', torch.tensor(1.))
    b_cover_b_loc = pyro.param('b_cover_b_loc', torch.tensor(1.))
    a_cover_scale = pyro.param('a_cover_scale',
                               torch.tensor(1.),
                               constraint=constraints.greater_than(0.01))
    b_cover_scale = pyro.param('b_cover_scale',
                               torch.tensor(1.),
                               constraint=constraints.greater_than(0.01))
    a_cover_b_scale = pyro.param('a_cover_b_scale',
                                 torch.tensor(1.),
                                 constraint=constraints.greater_than(0.01))
    b_cover_b_scale = pyro.param('b_cover_b_scale',
                                 torch.tensor(1.),
                                 constraint=constraints.greater_than(0.01))
    a_floe = pyro.sample("a_floe", dist.LogNormal(a_floe_loc, a_floe_scale))
    b_floe = pyro.sample("b_floe", dist.LogNormal(b_floe_loc, b_floe_scale))
    a_cover = pyro.sample("a_cover", dist.LogNormal(a_cover_loc,
                                                    a_cover_scale))
    b_cover = pyro.sample("b_cover", dist.LogNormal(b_cover_loc,
                                                    b_cover_scale))
    a_cover_b = pyro.sample("a_cover_b",
                            dist.LogNormal(a_cover_b_loc, a_cover_b_scale))
    b_cover_b = pyro.sample("b_cover_b",
                            dist.LogNormal(b_cover_b_loc, b_cover_b_scale))
    lambda_ice = sigmoid((a_floe * floe_size + b_floe))
    alpha_det = sigmoid((a_cover * cover_subp + b_cover))
    beta_det = sigmoid((a_cover_b * cover_subp + b_cover_b))
    with pyro.plate('subp', size=len(floe_size), subsample_size=10):
        phi_det = pyro.sample('phi_det', Beta(alpha_det, beta_det))
Example #23
 def support(self):
     return constraints.greater_than(self.scale)
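
A support property like the one above normally lives on a Distribution subclass; torch.distributions declares such parameter-dependent supports with constraints.dependent_property so the constraint is rebuilt from the current parameter values. A toy sketch of that pattern (illustration only, not from the projects listed here):

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution

class ShiftedPositive(Distribution):
    """Toy distribution whose support starts at `scale`; illustration only."""
    arg_constraints = {"scale": constraints.real}

    def __init__(self, scale, validate_args=None):
        self.scale = torch.as_tensor(scale)
        super().__init__(batch_shape=self.scale.shape, validate_args=validate_args)

    @constraints.dependent_property
    def support(self):
        # The support depends on the scale parameter, as in the snippet above.
        return constraints.greater_than(self.scale)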
Example #24
def guide(data, args, batch_size=None):
    data = torch.reshape(data, [64, 1000])
    
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        # WL: edited. =====
        # # lambda: torch.ones(8) / 8,
        # torch.ones(8) / 8,
        torch.ones(8),
        # =================
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # WL: edited. =====
        # # lambda: torch.ones(8, 1024) / 1024,
        # torch.ones(8, 1024) / 1024,
        # constraint=constraints.positive)
        torch.ones(8, 1024),
        constraint=constraints.greater_than(0.5))
        # =================
    """
    # wy: dummy param for word_topics
    word_topics_posterior = pyro.param(
        "word_topics_posterior",
        torch.ones(64, 1024, 8) / 8,
        constraint=constraints.positive)
    """
    
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape =  [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_\
            (counts, 0, data,
             torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        # shape = [32] + [8]
        # WL: edited. =====
        # # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1))
        # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1))
        # doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w))
        doc_topics = softmax(doc_topics_w)
        pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
        # =================
        """
Example #25
 def support(self):
     return constraints.greater_than(self.scale)
Example #26
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
                        # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
        df (float or Tensor): real-valued parameter larger than the (dimension of the square matrix) - 1
    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

    **References**

    [1] `On equivalence of the LKJ distribution and the restricted Wishart distribution`,
    Zhenxun Wang, Yunan Wu, Haitao Chu.
    """
    arg_constraints = {
        'covariance_matrix': constraints.positive_definite,
        'precision_matrix': constraints.positive_definite,
        'scale_tril': constraints.lower_cholesky,
        'df': constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(self,
                 df: Union[torch.Tensor, Number],
                 covariance_matrix: torch.Tensor = None,
                 precision_matrix: torch.Tensor = None,
                 scale_tril: torch.Tensor = None,
                 validate_args=None):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(p
                     for p in (covariance_matrix, precision_matrix, scale_tril)
                     if p is not None)

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] -
                                                              1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super(Wishart, self).__init__(batch_shape,
                                      event_shape,
                                      validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(
                covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(
                precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(self.df.unsqueeze(-1) - torch.arange(
                self._event_shape[-1],
                dtype=self._unbroadcasted_scale_tril.dtype,
                device=self._unbroadcasted_scale_tril.device,
            ).expand(batch_shape + (-1, ))))

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(
            cov_shape)
        new.df = self.df.expand(batch_shape)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(new.df.unsqueeze(-1) - torch.arange(
                self.event_shape[-1],
                dtype=new._unbroadcasted_scale_tril.dtype,
                device=new._unbroadcasted_scale_tril.device,
            ).expand(batch_shape + (-1, ))))

        super(Wishart, new).__init__(batch_shape,
                                     self.event_shape,
                                     validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(self._batch_shape +
                                                     self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        return (self._unbroadcasted_scale_tril
                @ self._unbroadcasted_scale_tril.transpose(
                    -2, -1)).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(
            identity,
            self._unbroadcasted_scale_tril).expand(self._batch_shape +
                                                   self._event_shape)

    @property
    def mean(self):
        return self.df.view(self._batch_shape +
                            (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V))

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()).diag_embed(dim1=-2,
                                                                     dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape +
            (int(p * (p - 1) / 2), ),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        .. warning::
            In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples.
            Several attempts to correct singular samples are made by default, but the method may still end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust the `max_try_correction` argument of `.rsample()` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporarily and should be removed in the future
        is_singular = self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(
                        is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(
                            self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (-nu *
                (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(
                    dim1=-2, dim2=-1).log().sum(-1)) -
                torch.mvlgamma(nu / 2, p=p) +
                (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet -
                torch.cholesky_solve(value, self._unbroadcasted_scale_tril).
                diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2)

    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        return ((p + 1) *
                (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(
                    dim1=-2, dim2=-1).log().sum(-1)) +
                torch.mvlgamma(nu / 2, p=p) -
                (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p) + nu * p / 2)

    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return ((y + (p + 1) / 2) *
                (-torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p) +
                torch.mvlgamma(y + (p + 1) / 2, p=p))
Example #27
    def _init_parameters(self):
        """
        Parameters shared between different models.
        """
        device = self.device
        data = self.data

        pyro.param(
            "proximity_loc",
            lambda: torch.tensor(0.5, device=device),
            constraint=constraints.interval(
                0,
                (self.data.P + 1) / math.sqrt(12) -
                torch.finfo(self.dtype).eps,
            ),
        )
        pyro.param(
            "proximity_size",
            lambda: torch.tensor(100, device=device),
            constraint=constraints.greater_than(2.0),
        )
        pyro.param(
            "lamda_loc",
            lambda: torch.tensor(0.5, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "lamda_beta",
            lambda: torch.tensor(100, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "gain_loc",
            lambda: torch.tensor(5, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "gain_beta",
            lambda: torch.tensor(100, device=device),
            constraint=constraints.positive,
        )

        pyro.param(
            "background_mean_loc",
            lambda: torch.full(
                (data.Nt, 1),
                data.median - self.data.offset.mean,
                device=device,
            ),
            constraint=constraints.positive,
        )
        pyro.param(
            "background_std_loc",
            lambda: torch.ones(data.Nt, 1, device=device),
            constraint=constraints.positive,
        )

        pyro.param(
            "b_loc",
            lambda: torch.full(
                (data.Nt, data.F),
                data.median - self.data.offset.mean,
                device=device,
            ),
            constraint=constraints.positive,
        )
        pyro.param(
            "b_beta",
            lambda: torch.ones(data.Nt, data.F, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "h_loc",
            lambda: torch.full((self.K, data.Nt, data.F), 2000, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "h_beta",
            lambda: torch.full(
                (self.K, data.Nt, data.F), 0.001, device=device),
            constraint=constraints.positive,
        )
        pyro.param(
            "w_mean",
            lambda: torch.full((self.K, data.Nt, data.F), 1.5, device=device),
            constraint=constraints.interval(
                0.75 + torch.finfo(self.dtype).eps,
                2.25 - torch.finfo(self.dtype).eps,
            ),
        )
        pyro.param(
            "w_size",
            lambda: torch.full((self.K, data.Nt, data.F), 100, device=device),
            constraint=constraints.greater_than(2.0),
        )
        pyro.param(
            "x_mean",
            lambda: torch.zeros(self.K, data.Nt, data.F, device=device),
            constraint=constraints.interval(
                -(data.P + 1) / 2 + torch.finfo(self.dtype).eps,
                (data.P + 1) / 2 - torch.finfo(self.dtype).eps,
            ),
        )
        pyro.param(
            "y_mean",
            lambda: torch.zeros(self.K, data.Nt, data.F, device=device),
            constraint=constraints.interval(
                -(data.P + 1) / 2 + torch.finfo(self.dtype).eps,
                (data.P + 1) / 2 - torch.finfo(self.dtype).eps,
            ),
        )
        pyro.param(
            "size",
            lambda: torch.full((self.K, data.Nt, data.F), 200, device=device),
            constraint=constraints.greater_than(2.0),
        )
Example #28
    return x


def _constraint_hash(constraint: constraints.Constraint) -> int:
    assert isinstance(constraint, constraints.Constraint)
    out = hash(type(constraint))
    out ^= hash(frozenset(constraint.__dict__.items()))
    return out


# useful if the distribution's init has multiple ways of specifying (e.g. both logits or probs)
_distribution_to_param_names = {NegativeBinomial: ['probs', 'total_count']}

# torch.distributions has a whole 'transforms' module but I don't know if they provide a mapping
_constraint_to_ilink = {
    _constraint_hash(constraints.positive): torch.exp,
    _constraint_hash(constraints.greater_than(0)): torch.exp,
    _constraint_hash(constraints.unit_interval): torch.sigmoid,
    _constraint_hash(constraints.real): identity,
    # TODO: is there a way to make these work?
    _constraint_hash(constraints.greater_than_eq(0)): torch.exp,
    _constraint_hash(constraints.half_open_interval(0, 1)): torch.sigmoid,
    # TODO: constraints.interval
}
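
Regarding the TODO above: torch.distributions does ship a constraint registry that provides this kind of mapping. transform_to() (and biject_to()) return a Transform whose image satisfies a given Constraint, including interval constraints. A small sketch of that approach (inputs here are illustrative):

import torch
from torch.distributions import constraints, transform_to

t = transform_to(constraints.greater_than(0.0))    # roughly exp(.) shifted by the lower bound
positive = t(torch.randn(5))
print(constraints.greater_than(0.0).check(positive).all())   # expected: tensor(True)

t = transform_to(constraints.interval(0.0, 1.0))   # roughly a scaled/shifted sigmoid
print(t(torch.randn(5)))                           # values strictly inside (0, 1)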
Example #29
SHAPE_CONSTRAINT = [
    ((), constraints.real),
    ((4,), constraints.real),
    ((3, 2), constraints.real),
    ((), constraints.positive),
    ((4,), constraints.positive),
    ((3, 2), constraints.positive),
    ((5,), constraints.simplex),
    ((2, 5), constraints.simplex),
    ((5, 5), constraints.lower_cholesky),
    ((2, 5, 5), constraints.lower_cholesky),
    ((10,), constraints.greater_than(-torch.randn(10).exp())),
    ((4, 10), constraints.greater_than(-torch.randn(10).exp())),
    ((4, 10), constraints.greater_than(-torch.randn(4, 10).exp())),
    ((3, 2, 10), constraints.greater_than(-torch.randn(10).exp())),
    ((3, 2, 10), constraints.greater_than(-torch.randn(2, 10).exp())),
    ((3, 2, 10), constraints.greater_than(-torch.randn(3, 1, 10).exp())),
    ((3, 2, 10), constraints.greater_than(-torch.randn(3, 2, 10).exp())),
    ((5,), constraints.real_vector),
    ((2, 5), constraints.real_vector),
    ((), constraints.unit_interval),
    ((4,), constraints.unit_interval),
    ((3, 2), constraints.unit_interval),
    ((10,), constraints.interval(-torch.randn(10).exp(),
Example #30
 def guide(self, images, labels=None, kl_factor=1.0):
     images = images.view(-1, 784)
     n_images = images.size(0)
     # Set-up parameters to be optimized to approximate the true posterior
     # Mean parameters are randomly initialized to small values around 0, and scale parameters
     # are initialized to be 0.1 to be closer to the expected posterior value which we assume is stronger than
     # the prior scale of 1.
     # Scale parameters must be positive, so we constrain them to be larger than some epsilon value (0.01).
     # Variational dropout parameters are initialized as in the prior model, and constrained to be between 0.1 and 1 (so the
     # dropout rate is between 0.1 and 0.5) as suggested in the local reparametrization paper
     a1_mean = pyro.param('a1_mean', 0.01 * torch.randn(784, self.n_hidden))
     a1_scale = pyro.param('a1_scale',
                           0.1 * torch.ones(784, self.n_hidden),
                           constraint=constraints.greater_than(0.01))
     a1_dropout = pyro.param('a1_dropout',
                             torch.tensor(0.25),
                             constraint=constraints.interval(0.1, 1.0))
     a2_mean = pyro.param(
         'a2_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_hidden))
     a2_scale = pyro.param('a2_scale',
                           0.1 *
                           torch.ones(self.n_hidden + 1, self.n_hidden),
                           constraint=constraints.greater_than(0.01))
     a2_dropout = pyro.param('a2_dropout',
                             torch.tensor(1.0),
                             constraint=constraints.interval(0.1, 1.0))
     a3_mean = pyro.param(
         'a3_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_hidden))
     a3_scale = pyro.param('a3_scale',
                           0.1 *
                           torch.ones(self.n_hidden + 1, self.n_hidden),
                           constraint=constraints.greater_than(0.01))
     a3_dropout = pyro.param('a3_dropout',
                             torch.tensor(1.0),
                             constraint=constraints.interval(0.1, 1.0))
     a4_mean = pyro.param(
         'a4_mean', 0.01 * torch.randn(self.n_hidden + 1, self.n_classes))
     a4_scale = pyro.param('a4_scale',
                           0.1 *
                           torch.ones(self.n_hidden + 1, self.n_classes),
                           constraint=constraints.greater_than(0.01))
     # Sample latent values using the variational parameters that are set-up above.
     # Notice how there is no conditioning on labels in the guide!
     with pyro.plate('data', size=n_images):
         h1 = pyro.sample(
             'h1',
             bnn.HiddenLayer(images,
                             a1_mean,
                             a1_dropout * a1_scale,
                             non_linearity=nnf.leaky_relu,
                             KL_factor=kl_factor))
         h2 = pyro.sample(
             'h2',
             bnn.HiddenLayer(h1,
                             a2_mean,
                             a2_dropout * a2_scale,
                             non_linearity=nnf.leaky_relu,
                             KL_factor=kl_factor))
         h3 = pyro.sample(
             'h3',
             bnn.HiddenLayer(h2,
                             a3_mean,
                             a3_dropout * a3_scale,
                             non_linearity=nnf.leaky_relu,
                             KL_factor=kl_factor))
         logits = pyro.sample(
             'logits',
             bnn.HiddenLayer(
                 h3,
                 a4_mean,
                 a4_scale,
                 non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
                 KL_factor=kl_factor,
                 include_hidden_bias=False))
Example #31
class Normal(ExponentialFamily):
    r"""
    Creates a normal (also called Gaussian) distribution parameterized by
    :attr:`loc` and :attr:`scale`.

    Example::

        >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # normally distributed with loc=0 and scale=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution (often referred to as mu)
        scale (float or Tensor): standard deviation of the distribution
            (often referred to as sigma)
    """
    arg_constraints = {'loc': constraints.real, 'scale': constraints.greater_than(8.)}
    support = constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        return self.loc

    @property
    def stddev(self):
        return self.scale

    @property
    def variance(self):
        return self.stddev.pow(2)

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super(Normal, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Normal, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(Normal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.normal(self.loc.expand(shape), self.scale.expand(shape))

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + eps * self.scale

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # compute the variance
        var = (self.scale ** 2)
        log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
        return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))

    def icdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)

    def entropy(self):
        return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)

    @property
    def _natural_params(self):
        return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

    def _log_normalizer(self, x, y):
        return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
Example #32
 def z(self):
     return constraints.greater_than(self.y)
Example #33
 def y(self):
     return constraints.greater_than(self.x)