def guide_3DA(data):
    """Mean-field guide over four backbone angles.

    For each angle in ``psi_1, phi_2, psi_2, phi_3`` three learnable scalars are
    registered: ``a_<angle>`` (init -pi, constrained > -3.15), ``b_<angle>``
    (init pi, constrained < 3.15) and ``x_<angle>`` (init 2, positive).  They
    parameterize ``mu_<angle> ~ Uniform(a, b)`` and
    ``inv_kappa_<angle> ~ HalfNormal(x)``.  ``data`` is unused here.
    """
    angles = ("psi_1", "phi_2", "psi_2", "phi_3")

    # Register every hyperparameter first (same order as before), then sample.
    hyper = {}
    for angle in angles:
        hyper[angle] = (
            pyro.param(f"a_{angle}", torch.tensor(-np.pi),
                       constraint=constraints.greater_than(-3.15)),
            pyro.param(f"b_{angle}", torch.tensor(np.pi),
                       constraint=constraints.less_than(3.15)),
            pyro.param(f"x_{angle}", torch.tensor(2.),
                       constraint=constraints.positive),
        )

    # Sampling mu and kappa
    for angle in angles:
        low, high, inv_kappa_scale = hyper[angle]
        pyro.sample(f"mu_{angle}", dist.Uniform(low, high))
        pyro.sample(f"inv_kappa_{angle}", dist.HalfNormal(inv_kappa_scale))
def parametrized_guide(doc_word_data, category_data, args, batch_size=None):
    """Conjugate mean-field guide for a topic model with document categories.

    Topic and category weights get Gamma factors; topic-word and
    category-topic distributions get Dirichlet factors.  Each document's
    category gets a ``Categorical`` factor over ``args.num_categories``
    outcomes.  ``doc_word_data`` and ``category_data`` are unused here.

    Args:
        args: namespace providing ``num_topics``, ``num_words``,
            ``num_categories`` and ``num_docs``.
        batch_size: optional subsample size for the ``documents`` plate.
    """
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    category_weights_posterior = pyro.param(
        "category_weights_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    category_topics_posterior = pyro.param(
        "category_topics_posterior",
        lambda: torch.ones(args.num_categories, args.num_topics),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("categories", args.num_categories):
        pyro.sample("category_weights", dist.Gamma(category_weights_posterior, 1.))
        pyro.sample("category_topics", dist.Dirichlet(category_topics_posterior))

    # One (unnormalized) probability per *category*: this parameterizes the
    # distribution of a document's category, so sizing it by num_topics (as
    # before) yields indices outside [0, num_categories) whenever the two differ.
    doc_category_posterior = pyro.param(
        "doc_category_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    with pyro.plate("documents", args.num_docs, batch_size):
        pyro.sample("doc_categories", dist.Categorical(doc_category_posterior))
def build_support(
        lower_bound: Optional[Tensor] = None,
        upper_bound: Optional[Tensor] = None) -> constraints.Constraint:
    """Build the support constraint of a prior from its optional bounds.

    Args:
        lower_bound: lower bound of the prior support, or None.
        upper_bound: upper bound of the prior support, or None.

    Returns:
        ``constraints.real`` when no bound is given (a warning is emitted);
        otherwise ``greater_than`` / ``less_than`` / ``interval`` built from the
        given bound(s).  A bound with more than one element is wrapped so that
        its last dimension is treated as one event dimension.

    Raises:
        AssertionError: if both bounds are given with different ``numel()``.
    """
    if lower_bound is None and upper_bound is None:
        warnings.warn(
            """No prior bounds were passed, consider passing lower_bound and / or upper_bound if your prior has bounded support.""")
        return constraints.real

    if upper_bound is None:
        num_dimensions = lower_bound.numel()  # type: ignore
        support = constraints.greater_than(lower_bound)
    elif lower_bound is None:
        num_dimensions = upper_bound.numel()
        support = constraints.less_than(upper_bound)
    else:
        num_dimensions = lower_bound.numel()
        assert (num_dimensions == upper_bound.numel()
                ), "There must be an equal number of independent bounds."
        support = constraints.interval(lower_bound, upper_bound)

    # ``constraints.independent`` is the public alias of _IndependentConstraint.
    return constraints.independent(support, 1) if num_dimensions > 1 else support
def guide(self, data):
    """Mean-field LDA guide with per-document and per-word local parameters.

    Globals: ``alpha ~ Gamma(alpha_posterior, 1)`` and
    ``beta ~ Dirichlet(beta_posterior)`` inside a ``topics`` plate.  For every
    document ``d`` a ``theta_d ~ Dirichlet(gamma_d)`` is sampled, and for every
    word ``w`` of ``data[d]`` a ``z{d}_{w} ~ Categorical(phi{d}_{w})``.

    Returns:
        ``(theta, z, alpha, betas)`` where ``theta`` / ``z`` are the values from
        the last document / word visited (``None`` if there were none).
    """
    alpha_posterior = pyro.param(
        "alpha_posterior", lambda: torch.ones(self.n_topics), constraint=positive)
    beta_posterior = pyro.param(
        "beta_posterior",
        lambda: torch.ones(self.n_topics, self.vocab_size),
        constraint=greater_than(0.5))
    with pyro.plate("topics", self.n_topics):
        alpha = pyro.sample("alpha", dist.Gamma(alpha_posterior, 1.))
        betas = pyro.sample("beta", dist.Dirichlet(beta_posterior))

    theta, z = None, None
    for d in pyro.plate("doc_loop", len(data)):
        gamma_q = pyro.param(f"gamma_{d}", torch.ones(self.n_topics), constraint=positive)
        theta = pyro.sample(f"theta_{d}", dist.Dirichlet(gamma_q))
        for w in pyro.plate(f"word_loop_{d}", len(data[d])):
            phi_q = pyro.param(f"phi{d}_{w}", torch.ones(self.n_topics), constraint=positive)
            z = pyro.sample(f"z{d}_{w}", dist.Categorical(phi_q))
    return theta, z, alpha, betas
def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6, name="GPR"):
    """Set up a GP regression model with a learnable noise parameter.

    ``X, y, kernel, mean_function, jitter, name`` are forwarded to the parent
    constructor.  ``noise`` defaults to a scalar one built with
    ``self.X.new_ones(())`` and is constrained to stay above ``self.jitter``.
    """
    super(GPRegression, self).__init__(X, y, kernel, mean_function, jitter, name)

    if noise is None:
        noise = self.X.new_ones(())
    self.noise = Parameter(noise)
    self.set_constraint("noise", constraints.greater_than(self.jitter))
def parametrized_guide(predictor, data, num_words_per_doc, args):
    """Guide with conjugate global factors and an amortized per-document net.

    Documents are visited sequentially (``for doc in pyro.plate(...)`` with
    subsample size ``args.batch_size``).  Each document's word indices
    ``data[doc]`` are turned into a ``(args.num_words, 1)`` count histogram and
    fed (transposed) to ``predictor``; its output is used both as a Delta site
    ``doc_topics_<doc>`` and as the ``Categorical`` probabilities for the
    ``num_words_per_doc[doc]`` word-topic sites.
    """
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    for doc in pyro.plate("documents", args.num_docs, args.batch_size):
        # Histogram of this document's word indices (the net expects counts).
        histogram = torch.zeros(args.num_words, 1)
        for word in data[doc]:
            histogram[word] += 1
        doc_topics = predictor(histogram.transpose(0, 1))
        pyro.sample("doc_topics_{}".format(doc), dist.Delta(doc_topics, event_dim=1))
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            pyro.sample("word_topics_{}".format(doc), dist.Categorical(doc_topics))
def parametrized_guide(predictor, data, args, batch_size=None):
    """ProdLDA guide: conjugate global factors plus an amortized local net.

    Args:
        predictor: network mapping word-count histograms to topic proportions;
            registered with ``pyro.module``.
        data: word-index tensor indexed as ``data[:, doc]`` and used as a
            ``scatter_add`` index along dim 0.
        args: namespace with ``num_topics``, ``num_words`` and ``num_docs``.
        batch_size: optional subsample size for the ``documents`` plate.
    """
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive,
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5),
    )
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.0))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        batch = data[:, ind]
        # Word-count histogram of shape (num_words, len(ind)).
        histogram = torch.zeros(args.num_words, ind.size(0))
        histogram = histogram.scatter_add(0, batch, torch.ones(batch.shape))
        doc_topics = predictor(histogram.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
def __init__(self, X, y, kernel, Xu, noise=None, mean_function=None, approx=None,
             jitter=1e-6, name="SGPR"):
    """Set up a sparse GP regression model.

    ``Xu`` (inducing inputs) and ``noise`` become Parameters; ``noise`` defaults
    to a scalar one and is constrained above ``self.jitter``.  ``approx`` selects
    the sparse approximation (``"VFE"`` when None).

    Raises:
        ValueError: if ``approx`` is not one of ``"DTC"``, ``"FITC"``, ``"VFE"``.
    """
    super(SparseGPRegression, self).__init__(X, y, kernel, mean_function, jitter, name)

    self.Xu = Parameter(Xu)

    if noise is None:
        noise = self.X.new_ones(())
    self.noise = Parameter(noise)
    self.set_constraint("noise", constraints.greater_than(self.jitter))

    chosen = "VFE" if approx is None else approx
    if chosen not in ("DTC", "FITC", "VFE"):
        raise ValueError(
            "The sparse approximation method should be one of "
            "'DTC', 'FITC', 'VFE'.")
    self.approx = chosen
def model(loc, cov):
    """Score three parameter tensors as observations of ``MVN(loc, cov)``.

    ``x`` (shape 2) is observed once; ``y`` (3 x 2) and ``z`` (4 x 2, initialized
    non-negative and constrained > -1) are observed row-wise inside plates of
    size 3 and 4 respectively.
    """
    x_value = pyro.param("x", torch.randn(2))
    y_rows = pyro.param("y", torch.randn(3, 2))
    z_rows = pyro.param("z", torch.randn(4, 2).abs(),
                        constraint=constraints.greater_than(-1))

    pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x_value)
    with pyro.plate("y_plate", 3):
        pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y_rows)
    with pyro.plate("z_plate", 4):
        pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z_rows)
def __init__(self, dof, lambda1, lambda2, nu, validate_args=None):
    """Store the parameters and specialize the ``dof`` constraint.

    ``batch_shape`` / ``event_shape`` are taken from ``lambda1`` (all but the
    last dim / the last dim).  With ``D = event_shape[0]``, ``dof`` must be
    greater than ``2 * D + 1``.
    """
    self.dof = dof
    self.lambda1 = lambda1
    self.lambda2 = lambda2
    self.nu = nu
    batch_shape, event_shape = lambda1.shape[:-1], lambda1.shape[-1:]
    D = event_shape[0]
    # Specialize on a per-instance copy: assigning into
    # ``self.arg_constraints[...]`` would mutate the class-level dict shared by
    # every instance, so one instance's D would leak into all others.
    # NOTE(review): assumes ``arg_constraints`` is a class attribute dict
    # (not a read-only property) on this class — confirm.
    self.arg_constraints = {
        **self.arg_constraints,
        'dof': constraints.greater_than(2. * D + 1.),
    }
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
def __init__(self,
             df: Union[torch.Tensor, Number],
             covariance_matrix: torch.Tensor = None,
             precision_matrix: torch.Tensor = None,
             scale_tril: torch.Tensor = None,
             validate_args=None):
    """Construct a Wishart distribution from ``df`` and one matrix parameter.

    Exactly one of ``covariance_matrix``, ``precision_matrix`` or ``scale_tril``
    must be given.  ``df`` must exceed ``ndim - 1``; a warning is emitted when
    ``df < ndim`` because singular samples become likely.

    Raises:
        ValueError: if the matrix parameter has fewer than 2 dims or ``df`` is
            out of range.
    """
    assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
        "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

    param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)

    if param.dim() < 2:
        raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")

    if isinstance(df, Number):
        batch_shape = torch.Size(param.shape[:-2])
        self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
    else:
        batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
        self.df = df.expand(batch_shape)
    event_shape = param.shape[-2:]

    if self.df.le(event_shape[-1] - 1).any():
        raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.")

    if scale_tril is not None:
        self.scale_tril = param.expand(batch_shape + (-1, -1))
    elif covariance_matrix is not None:
        self.covariance_matrix = param.expand(batch_shape + (-1, -1))
    elif precision_matrix is not None:
        self.precision_matrix = param.expand(batch_shape + (-1, -1))

    # Specialize the ``df`` constraint on a per-instance copy; writing into
    # ``self.arg_constraints[...]`` would mutate the class-level dict shared by
    # every Wishart, so each new instance would change what all others report.
    self.arg_constraints = {
        **self.arg_constraints,
        'df': constraints.greater_than(event_shape[-1] - 1),
    }
    if self.df.lt(event_shape[-1]).any():
        warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")

    super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
    self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

    if scale_tril is not None:
        self._unbroadcasted_scale_tril = scale_tril
    elif covariance_matrix is not None:
        self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
    else:  # precision_matrix is not None
        self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

    # Chi2 distribution is needed for Bartlett decomposition sampling
    self._dist_chi2 = torch.distributions.chi2.Chi2(
        df=(
            self.df.unsqueeze(-1)
            - torch.arange(
                self._event_shape[-1],
                dtype=self._unbroadcasted_scale_tril.dtype,
                device=self._unbroadcasted_scale_tril.device,
            ).expand(batch_shape + (-1,))
        )
    )
def model(loc, cov):
    """Score three parameter tensors as observations of ``MVN(loc, cov)``.

    Uses ``pyro.plate`` in place of the deprecated ``pyro.iarange``; the plate
    names ``y_iarange`` / ``z_iarange`` are kept so trace site names are stable.
    """
    x = pyro.param("x", torch.randn(2))
    y = pyro.param("y", torch.randn(3, 2))
    z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
    pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
    with pyro.plate("y_iarange", 3):
        pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
    with pyro.plate("z_iarange", 4):
        pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)
def __init__(self, concentration, scale, validate_args=None):
    """Store the parameters and specialize the ``concentration`` constraint.

    ``batch_shape`` is ``concentration.shape``; ``event_shape`` is the last two
    dims of ``scale``.  ``concentration`` must exceed ``(p - 1) / 2`` where
    ``p = event_shape[-1]``.
    """
    self.concentration = torch.as_tensor(concentration)
    self.scale = torch.as_tensor(scale)
    batch_shape = self.concentration.shape
    event_shape = self.scale.shape[-2:]
    # Specialize on a per-instance copy; mutating ``self.arg_constraints[...]``
    # would rewrite the class-level dict shared by every instance.
    self.arg_constraints = {
        **self.arg_constraints,
        'concentration': constraints.greater_than(.5 * (event_shape[-1] - 1)),
    }
    super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
def prior(self):
    """Define priors, initial values and constraints for the model's sites.

    Rebuilds ``self._sites`` as an ``OrderedDict`` with two entries:

    * ``'weights'`` (``4 * (self.num_features + 1)`` values), constrained to be
      greater than ``self.epsilon``.  Init means are drawn uniformly from
      ``(1 + epsilon, 2)``.  The prior is ``Normal(mean, 10)``, except for
      ``self.method`` in ``'variational'`` / ``'mcmc'``, where the mean is mapped
      through ``log(exp(m) - 1)`` and a ``LogNormal`` prior is used so samples
      are positive.
    * ``'bias'`` (one value), unconstrained with a ``Normal(0, 10)`` prior.
    """
    num_weights = 4 * (self.num_features + 1)
    self._sites = OrderedDict()

    init_mean = torch.ones(num_weights).uniform_(1. + self.epsilon, 2.)
    init_scale = torch.ones(num_weights)
    prior = dist.Normal(init_mean, 10 * init_scale, validate_args=True)

    # MLE optimizers can enforce positivity directly; MCMC and VI cannot, so a
    # LogNormal prior (with an inverse-softplus-shifted mean) is used instead.
    if self.method in ['variational', 'mcmc']:
        init_mean = torch.log(torch.exp(init_mean) - 1)
        prior = dist.LogNormal(init_mean, 10 * init_scale, validate_args=True)

    weight_site = dict(
        values=None,
        constraint=constraints.greater_than(self.epsilon),
        init=dict(mean=init_mean, scale=init_scale),
        prior=prior,
    )
    bias_site = dict(
        values=None,
        constraint=constraints.real,
        init=dict(mean=torch.ones(1) * self.epsilon, scale=torch.ones(1)),
        prior=dist.Normal(torch.zeros(1), 10 * torch.ones(1), validate_args=True),
    )
    self._sites['weights'] = weight_site
    self._sites['bias'] = bias_site
def guide(dataset_total_length, x_data, y_data, kl_factor):
    """Two-layer Bayesian-NN guide built from ``bnn.HiddenLayer`` sites.

    Reads the module-level sizes ``n_features``, ``n_hidden`` and ``n_out``.
    The ``map`` plate has size ``dataset_total_length`` with ``x_data`` as the
    subsample.  ``y_data`` is unused here.
    """
    a1_mean = pyro.param('a1_mean', 0.01 * torch.randn(n_features, n_hidden))
    a1_scale = pyro.param('a1_scale', 0.1 * torch.ones(n_features, n_hidden),
                          constraint=constraints.greater_than(0.01))
    a2_mean = pyro.param('a2_mean', 0.01 * torch.randn(n_hidden + 1, n_out))
    a2_scale = pyro.param('a2_scale', 0.1 * torch.ones(n_hidden + 1, n_out),
                          constraint=constraints.greater_than(0.01))

    def shifted_relu(x):
        # relu >= 0, so adding 1e-3 keeps the output strictly positive.
        return nnf.relu(x) + 1e-3

    with pyro.plate('map', dataset_total_length, subsample=x_data):
        hidden = pyro.sample('h1', bnn.HiddenLayer(
            x_data, a1_mean, a1_scale,
            non_linearity=nnf.leaky_relu, KL_factor=kl_factor))
        pyro.sample('rate', bnn.HiddenLayer(
            hidden, a2_mean, a2_scale,
            non_linearity=shifted_relu, KL_factor=kl_factor,
            include_hidden_bias=False))
def __init__(self, df, scale=None, scale_tril=None, validate_args=None):
    """Construct from ``df`` and exactly one of ``scale`` / ``scale_tril``.

    ``batch_shape`` is ``df.shape``; ``event_shape`` is the last two dims of
    ``self.scale_tril``.  NOTE(review): when only ``scale`` is given this relies
    on ``scale_tril`` being derivable elsewhere (e.g. a lazy property) — confirm.

    Raises:
        ValueError: unless exactly one of ``scale`` / ``scale_tril`` is given.
    """
    self.df = torch.as_tensor(df)
    if (scale is None) + (scale_tril is None) != 1:
        raise ValueError("Exactly one of scale or scale_tril must be specified.")
    if scale is not None:
        self.scale = torch.as_tensor(scale)
    if scale_tril is not None:
        self.scale_tril = torch.as_tensor(scale_tril)
    batch_shape = self.df.shape
    event_shape = self.scale_tril.shape[-2:]
    # Specialize on a per-instance copy; mutating ``self.arg_constraints[...]``
    # would rewrite the class-level dict shared by every instance.
    self.arg_constraints = {
        **self.arg_constraints,
        'df': constraints.greater_than(event_shape[-1] - 1),
    }
    super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6, name="GPR"):
    """Initialize GP regression; ``noise`` is learnable and kept above jitter.

    All arguments except ``noise`` are forwarded to the parent constructor.
    When ``noise`` is None it starts as the scalar ``self.X.new_ones(())``.
    """
    super(GPRegression, self).__init__(X, y, kernel, mean_function, jitter, name)

    initial_noise = self.X.new_ones(()) if noise is None else noise
    self.noise = Parameter(initial_noise)
    self.set_constraint("noise", constraints.greater_than(self.jitter))
def __init__(self, X, y, kernel, Xu, noise=None, mean_function=None, approx=None,
             jitter=1e-6, name="SGPR"):
    """Initialize sparse GP regression with inducing inputs ``Xu``.

    ``noise`` (default: scalar one) is a Parameter constrained above
    ``self.jitter``.  ``approx`` defaults to ``"VFE"``.

    Raises:
        ValueError: if ``approx`` is not ``"DTC"``, ``"FITC"`` or ``"VFE"``.
    """
    super(SparseGPRegression, self).__init__(X, y, kernel, mean_function, jitter, name)

    self.Xu = Parameter(Xu)

    if noise is None:
        noise = self.X.new_ones(())
    self.noise = Parameter(noise)
    self.set_constraint("noise", constraints.greater_than(self.jitter))

    if approx is None:
        approx = "VFE"
    if approx not in ("DTC", "FITC", "VFE"):
        raise ValueError("The sparse approximation method should be one of "
                         "'DTC', 'FITC', 'VFE'.")
    self.approx = approx
def __init__(self,
             dag: DirectedAcyclicGraph,
             noise_std_dict: "Optional[Dict[str, float]]" = None,
             default_noise_std_bounds: Tuple[float, float] = (0.5, 1.0),
             seed=None):
    """
    Args:
        dag: DAG, defining SEM
        noise_std_dict: optional mapping from node name to its noise std.
            Nodes missing from it get a random std (see below).  ``None``
            (the default) behaves like an empty mapping.
        default_noise_std_bounds: ``(low, high)`` passed to
            ``np.random.uniform`` to draw the noise std of nodes absent from
            ``noise_std_dict``.
        seed: if not None, passed to ``np.random.seed`` (this seeds NumPy's
            global RNG).
    """
    super(PostNonLinearMultiplicativeHalfNormalSEM, self).__init__(dag)
    # ``None`` instead of a shared mutable ``{}`` default.
    if noise_std_dict is None:
        noise_std_dict = {}
    if seed is not None:
        np.random.seed(seed)

    for node in self.topological_order:
        if node in noise_std_dict:
            self.model[node]['noise_std'] = noise_std_dict[node]
        else:
            self.model[node]['noise_std'] = np.random.uniform(*default_noise_std_bounds)

    self.parents_conditional_support = constraints.greater_than(0.0)
def _parametrized_guide(self, predictor, data, labels, batch_size=None):
    """Guide for a label-augmented ProdLDA-style topic model.

    Global topic weights/words get Gamma/Dirichlet factors, and a per-topic
    ``label_prior`` gets a Beta factor whose two concentration rows come from
    ``label_posterior`` (shape ``(2, num_topics)``).  Per-document topic
    proportions are amortized by ``predictor``, which receives each document's
    word histogram concatenated with its labels.

    Args:
        predictor: network registered with ``pyro.module``.
        data: word-index tensor indexed as ``data[:, doc]``.
        labels: indexed as ``labels[doc]``; after subsampling it is assigned
            into a ``(batch, args.num_topics)`` slot.  NOTE(review): confirm
            its shape against the model.
        batch_size: optional subsample size; ``None`` uses every document.
    """
    args = self.args
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    # Is this needed?
    label_posterior = pyro.param(
        "label_posterior",
        lambda: torch.ones(2, args.num_topics),
        constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
        pyro.sample("label_prior", dist.Beta(*label_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        data = data[:, ind]
        labels = labels[ind]
        counts = (torch.zeros(args.num_words, ind.size(0)).scatter_add(
            0, data, torch.ones(data.shape)))
        # Size by the actual subsample: ``batch_size`` defaults to None, and
        # ``torch.zeros(None, ...)`` would raise.
        augmented_input = torch.zeros(ind.size(0), args.num_words + args.num_topics)
        augmented_input[:, :args.num_words] = counts.transpose(0, 1)
        augmented_input[:, args.num_words:] = labels
        doc_topics = predictor(augmented_input)
        pyro.sample("doc_topics", dist.Delta(doc_topics).to_event(1))
def guide(self, docs=None, doc_sum=None):
    """Conjugate global factors plus an amortized ``inference_net`` guide.

    Args:
        docs: unused here.
        doc_sum: input to ``self.inference_net``; its first dimension sizes
            the ``documents`` plate, so it must not be None.
    """
    def init_topic_weights():
        return torch.ones(self.num_topics, device=self.device)

    def init_topic_words():
        return torch.ones(self.num_topics, self.vocab_size, device=self.device)

    # Conjugate factors for the global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior", init_topic_weights,
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior", init_topic_words,
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", self.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Amortized factor for the per-document topic proportions.
    pyro.module("inference_net", self.inference_net)
    num_docs = doc_sum.shape[0]
    with pyro.plate("documents", num_docs):
        doc_topics = self.inference_net(doc_sum)
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
def guide(n_ice, n_obs, floe_size, cover_subp):
    """LogNormal mean-field guide for the floe / cover coefficients.

    Six positive coefficients are sampled from LogNormals with learnable locs
    and scales (scales constrained > 0.01).  ``phi_det`` is sampled per
    subsampled sub-population from ``Beta(sigmoid(a_cover * cover_subp +
    b_cover), sigmoid(a_cover_b * cover_subp + b_cover_b))``.
    ``n_ice`` and ``n_obs`` are unused here.

    NOTE(review): assumes ``cover_subp`` holds one value per element of
    ``floe_size`` (the ``subp`` plate is sized by ``len(floe_size)``).
    """
    a_floe_loc = pyro.param('a_floe_loc', torch.tensor(0.))
    b_floe_loc = pyro.param('b_floe_loc', torch.tensor(0.))
    a_floe_scale = pyro.param('a_floe_scale', torch.tensor(1.), constraint=constraints.greater_than(0.01))
    b_floe_scale = pyro.param('b_floe_scale', torch.tensor(1.), constraint=constraints.greater_than(0.01))
    a_cover_loc = pyro.param('a_cover_loc', torch.tensor(1.))
    b_cover_loc = pyro.param('b_cover_loc', torch.tensor(1.))
    a_cover_b_loc = pyro.param('a_cover_b_loc', torch.tensor(1.))
    b_cover_b_loc = pyro.param('b_cover_b_loc', torch.tensor(1.))
    a_cover_scale = pyro.param('a_cover_scale', torch.tensor(1.), constraint=constraints.greater_than(0.01))
    b_cover_scale = pyro.param('b_cover_scale', torch.tensor(1.), constraint=constraints.greater_than(0.01))
    a_cover_b_scale = pyro.param('a_cover_b_scale', torch.tensor(1.), constraint=constraints.greater_than(0.01))
    b_cover_b_scale = pyro.param('b_cover_b_scale', torch.tensor(1.), constraint=constraints.greater_than(0.01))

    pyro.sample("a_floe", dist.LogNormal(a_floe_loc, a_floe_scale))
    pyro.sample("b_floe", dist.LogNormal(b_floe_loc, b_floe_scale))
    a_cover = pyro.sample("a_cover", dist.LogNormal(a_cover_loc, a_cover_scale))
    b_cover = pyro.sample("b_cover", dist.LogNormal(b_cover_loc, b_cover_scale))
    a_cover_b = pyro.sample("a_cover_b", dist.LogNormal(a_cover_b_loc, a_cover_b_scale))
    b_cover_b = pyro.sample("b_cover_b", dist.LogNormal(b_cover_b_loc, b_cover_b_scale))

    alpha_det = sigmoid(a_cover * cover_subp + b_cover)
    beta_det = sigmoid(a_cover_b * cover_subp + b_cover_b)

    with pyro.plate('subp', size=len(floe_size), subsample_size=10) as ind:
        # Only the subsampled entries belong to this plate: using the full-length
        # parameters gives the Beta a batch shape that disagrees with the plate.
        pyro.sample('phi_det', Beta(alpha_det[ind], beta_det[ind]))
def support(self):
    """Return a constraint admitting values strictly greater than ``self.scale``."""
    lower = self.scale
    return constraints.greater_than(lower_bound=lower)
def guide(data, args, batch_size=None):
    """Amortized ProdLDA-style guide with hard-coded sizes.

    NOTE(review): ``args`` and ``batch_size`` are accepted but unused.  All
    sizes are fixed in the code: ``data`` is reshaped to ``[64, 1000]`` (64 word
    slots x 1000 documents), there are 8 topics, a 1024-word vocabulary and
    mini-batches of 32 documents.  Uses module-level ``layer1..layer3``.
    """
    data = torch.reshape(data, [64, 1000])
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Conjugate guide for the global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior", torch.ones(8),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior", torch.ones(8, 1024),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", 8):
        pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.))  # [8]
        pyro.sample("topic_words", Dirichlet(topic_words_posterior))  # [8] + [1024]

    # Amortized guide for the local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        batch = data[:, ind]  # [64, 32]
        # Word-count histogram per document: [1024, 32].
        counts = torch.zeros(1024, 32)
        counts.scatter_add_(0, batch, torch.tensor(1.).expand([1024, 32]))
        hidden = sigmoid(layer2(sigmoid(layer1(counts.transpose(0, 1)))))
        doc_topics = softmax(sigmoid(layer3(hidden)))  # [32] + [8]
        pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
def support(self):
    """Support is the open half-line ``(self.scale, inf)``."""
    return constraints.greater_than(lower_bound=self.scale)
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

    **References**

    [1] `On equivalence of the LKJ distribution and the restricted Wishart distribution`,
    Zhenxun Wang, Yunan Wu, Haitao Chu.
    """
    arg_constraints = {
        'covariance_matrix': constraints.positive_definite,
        'precision_matrix': constraints.positive_definite,
        'scale_tril': constraints.lower_cholesky,
        'df': constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(self,
                 df: Union[torch.Tensor, Number],
                 covariance_matrix: torch.Tensor = None,
                 precision_matrix: torch.Tensor = None,
                 scale_tril: torch.Tensor = None,
                 validate_args=None):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # Specialize the ``df`` constraint on a per-instance copy. Writing into
        # ``self.arg_constraints[...]`` would mutate the class-level dict shared
        # by every Wishart, so each new instance would change the constraint
        # that all other instances report.
        self.arg_constraints = {
            **self.arg_constraints,
            'df': constraints.greater_than(event_shape[-1] - 1),
        }
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(self.df.unsqueeze(-1) - torch.arange(
                self._event_shape[-1],
                dtype=self._unbroadcasted_scale_tril.dtype,
                device=self._unbroadcasted_scale_tril.device,
            ).expand(batch_shape + (-1, ))))

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        # Carry over this instance's specialized constraints (they are no longer
        # stored on the class).
        new.arg_constraints = dict(self.arg_constraints)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(new.df.unsqueeze(-1) - torch.arange(
                self.event_shape[-1],
                dtype=new._unbroadcasted_scale_tril.dtype,
                device=new._unbroadcasted_scale_tril.device,
            ).expand(batch_shape + (-1, ))))

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(self._batch_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        return (self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(
            -2, -1)).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(
            identity, self._unbroadcasted_scale_tril).expand(self._batch_shape + self._event_shape)

    @property
    def mean(self):
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V))

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2), ),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporarily and should be removed in the future.
        # ``support.check`` is True for *valid* (positive-definite) samples, so it
        # must be negated to flag singular ones (as the retry loops below do).
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                # ``is_singular`` has shape ``sample_shape``; pad it on the right
                # so ``torch.where`` aligns it with the leading (sample) dims of
                # ``sample`` instead of broadcasting it against the matrix dims.
                mask = is_singular.reshape(
                    is_singular.shape + (1, ) * (sample.dim() - is_singular.dim()))
                sample = torch.where(mask, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (-nu * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(
            dim1=-2, dim2=-1).log().sum(-1)) - torch.mvlgamma(nu / 2, p=p) +
                (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet -
                torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(
                    dim1=-2, dim2=-1).sum(dim=-1) / 2)

    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return ((p + 1) * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(
            dim1=-2, dim2=-1).log().sum(-1)) + torch.mvlgamma(nu / 2, p=p) -
                (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p) + nu * p / 2)

    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return ((y + (p + 1) / 2) * (-torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p) +
                torch.mvlgamma(y + (p + 1) / 2, p=p))
def _init_parameters(self):
    """
    Register the variational parameters shared between different models.

    Every parameter is declared with ``pyro.param`` using a lazy initializer
    (a zero-argument callable) and a constraint. The declarations are listed
    in a single table and registered in table order.
    """
    device = self.device
    data = self.data
    eps = torch.finfo(self.dtype).eps
    positive = constraints.positive

    # Bounds for the x/y position means, pulled inwards by one machine epsilon.
    position_lo = -(data.P + 1) / 2 + eps
    position_hi = (data.P + 1) / 2 - eps

    def scalar(value):
        # Lazy 0-d tensor holding ``value`` on ``device``.
        return lambda: torch.tensor(value, device=device)

    def per_spot(value):
        # Lazy (K, Nt, F) tensor filled with ``value`` on ``device``.
        return lambda: torch.full((self.K, data.Nt, data.F), value, device=device)

    def per_spot_zeros():
        # Lazy (K, Nt, F) tensor of zeros on ``device``.
        return lambda: torch.zeros(self.K, data.Nt, data.F, device=device)

    specs = (
        (
            "proximity_loc",
            scalar(0.5),
            constraints.interval(0, (data.P + 1) / math.sqrt(12) - eps),
        ),
        ("proximity_size", scalar(100), constraints.greater_than(2.0)),
        ("lamda_loc", scalar(0.5), positive),
        ("lamda_beta", scalar(100), positive),
        ("gain_loc", scalar(5), positive),
        ("gain_beta", scalar(100), positive),
        (
            "background_mean_loc",
            lambda: torch.full(
                (data.Nt, 1), data.median - data.offset.mean, device=device
            ),
            positive,
        ),
        (
            "background_std_loc",
            lambda: torch.ones(data.Nt, 1, device=device),
            positive,
        ),
        (
            "b_loc",
            lambda: torch.full(
                (data.Nt, data.F), data.median - data.offset.mean, device=device
            ),
            positive,
        ),
        (
            "b_beta",
            lambda: torch.ones(data.Nt, data.F, device=device),
            positive,
        ),
        ("h_loc", per_spot(2000), positive),
        ("h_beta", per_spot(0.001), positive),
        (
            "w_mean",
            per_spot(1.5),
            constraints.interval(0.75 + eps, 2.25 - eps),
        ),
        ("w_size", per_spot(100), constraints.greater_than(2.0)),
        (
            "x_mean",
            per_spot_zeros(),
            constraints.interval(position_lo, position_hi),
        ),
        (
            "y_mean",
            per_spot_zeros(),
            constraints.interval(position_lo, position_hi),
        ),
        ("size", per_spot(200), constraints.greater_than(2.0)),
    )

    for name, init, constraint in specs:
        pyro.param(name, init, constraint=constraint)
return x def _constraint_hash(constraint: constraints.Constraint) -> int: assert isinstance(constraint, constraints.Constraint) out = hash(type(constraint)) out ^= hash(frozenset(constraint.__dict__.items())) return out # useful if the distribution's init has multiple ways of specifying (e.g. both logits or probs) _distribution_to_param_names = {NegativeBinomial: ['probs', 'total_count']} # torch.distributions has a whole 'transforms' module but I don't know if they provide a mapping _constraint_to_ilink = { _constraint_hash(constraints.positive): torch.exp, _constraint_hash(constraints.greater_than(0)): torch.exp, _constraint_hash(constraints.unit_interval): torch.sigmoid, _constraint_hash(constraints.real): identity, # TODO: is there a way to make these work? _constraint_hash(constraints.greater_than_eq(0)): torch.exp, _constraint_hash(constraints.half_open_interval(0, 1)): torch.sigmoid, # TODO: constraints.interval }
SHAPE_CONSTRAINT = [ ((), constraints.real), ((4, ), constraints.real), ((3, 2), constraints.real), ((), constraints.positive), ((4, ), constraints.positive), ((3, 2), constraints.positive), ((5, ), constraints.simplex), (( 2, 5, ), constraints.simplex), ((5, 5), constraints.lower_cholesky), ((2, 5, 5), constraints.lower_cholesky), ((10, ), constraints.greater_than(-torch.randn(10).exp())), ((4, 10), constraints.greater_than(-torch.randn(10).exp())), ((4, 10), constraints.greater_than(-torch.randn(4, 10).exp())), ((3, 2, 10), constraints.greater_than(-torch.randn(10).exp())), ((3, 2, 10), constraints.greater_than(-torch.randn(2, 10).exp())), ((3, 2, 10), constraints.greater_than(-torch.randn(3, 1, 10).exp())), ((3, 2, 10), constraints.greater_than(-torch.randn(3, 2, 10).exp())), ((5, ), constraints.real_vector), (( 2, 5, ), constraints.real_vector), ((), constraints.unit_interval), ((4, ), constraints.unit_interval), ((3, 2), constraints.unit_interval), ((10, ), constraints.interval(-torch.randn(10).exp(),
def guide(self, images, labels=None, kl_factor=1.0):
    """
    Variational guide for the Bayesian neural network classifier.

    Registers the variational parameters of the four layers. It then samples
    the hidden activations ``h1``, ``h2``, ``h3`` and the output ``logits``
    from ``bnn.HiddenLayer`` inside a ``'data'`` plate of size ``n_images``.

    Args:
        images: input batch, reshaped to ``(-1, 784)`` before use.
        labels: not used; there is no conditioning on labels in the guide.
        kl_factor: forwarded as ``KL_factor`` to every ``bnn.HiddenLayer``.
    """
    images = images.view(-1, 784)
    n_images = images.size(0)

    # Set up the parameters that are optimized to approximate the true posterior.
    # Mean parameters are randomly initialized to small values around 0. Scale
    # parameters are initialized to 0.1, closer to the expected posterior
    # value, which we assume is stronger than the prior scale of 1.
    # Scale parameters must be positive, so we constrain them to be larger
    # than some epsilon value (0.01).
    # Variational dropout parameters are initialized as in the prior model and
    # constrained to lie between 0.1 and 1 (so the dropout rate is between 0.1
    # and 0.5), as suggested in the local reparametrization paper.
    #
    # Initial values are passed as lazy callables. pyro.param only evaluates
    # them when a parameter is first registered, so the initial tensors are not
    # rebuilt on every guide call, and torch.randn does not advance the global
    # RNG on every call either.

    def layer_params(prefix, n_in, n_out):
        # Register (or fetch) the weight mean and scale of one layer.
        mean = pyro.param(prefix + '_mean',
                          lambda: 0.01 * torch.randn(n_in, n_out))
        scale = pyro.param(prefix + '_scale',
                           lambda: 0.1 * torch.ones(n_in, n_out),
                           constraint=constraints.greater_than(0.01))
        return mean, scale

    def dropout_param(name, init):
        # Register (or fetch) a scalar variational-dropout multiplier.
        return pyro.param(name, lambda: torch.tensor(init),
                          constraint=constraints.interval(0.1, 1.0))

    # Registration order (and therefore the order of the initial random draws)
    # is layer by layer: mean, scale, then dropout.
    a1_mean, a1_scale = layer_params('a1', 784, self.n_hidden)
    a1_dropout = dropout_param('a1_dropout', 0.25)
    # NOTE(review): a2/a3 dropout start exactly at the upper bound (1.0) of
    # their interval(0.1, 1.0) constraint -- confirm this is intended.
    a2_mean, a2_scale = layer_params('a2', self.n_hidden + 1, self.n_hidden)
    a2_dropout = dropout_param('a2_dropout', 1.0)
    a3_mean, a3_scale = layer_params('a3', self.n_hidden + 1, self.n_hidden)
    a3_dropout = dropout_param('a3_dropout', 1.0)
    a4_mean, a4_scale = layer_params('a4', self.n_hidden + 1, self.n_classes)

    # Sample latent values using the variational parameters set up above.
    # Notice how there is no conditioning on labels in the guide!
    with pyro.plate('data', size=n_images):
        h1 = pyro.sample(
            'h1',
            bnn.HiddenLayer(images, a1_mean, a1_dropout * a1_scale,
                            non_linearity=nnf.leaky_relu,
                            KL_factor=kl_factor))
        h2 = pyro.sample(
            'h2',
            bnn.HiddenLayer(h1, a2_mean, a2_dropout * a2_scale,
                            non_linearity=nnf.leaky_relu,
                            KL_factor=kl_factor))
        h3 = pyro.sample(
            'h3',
            bnn.HiddenLayer(h2, a3_mean, a3_dropout * a3_scale,
                            non_linearity=nnf.leaky_relu,
                            KL_factor=kl_factor))
        logits = pyro.sample(
            'logits',
            bnn.HiddenLayer(h3, a4_mean, a4_scale,
                            non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
                            KL_factor=kl_factor,
                            include_hidden_bias=False))
class Normal(ExponentialFamily):
    r"""
    Creates a normal (also called Gaussian) distribution parameterized by
    :attr:`loc` and :attr:`scale`.

    Example::

        >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # normally distributed with loc=0 and scale=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution (often referred to as mu)
        scale (float or Tensor): standard deviation of the distribution
            (often referred to as sigma); must be strictly positive
    """
    # ``scale`` is a standard deviation, so every strictly positive value is
    # valid. A lower bound such as 8 would reject ordinary inputs, e.g. the
    # ``scale=1`` of the example above, whenever argument validation is on.
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        """Mean of the distribution, i.e. ``loc``."""
        return self.loc

    @property
    def stddev(self):
        """Standard deviation of the distribution, i.e. ``scale``."""
        return self.scale

    @property
    def variance(self):
        """Variance of the distribution, ``scale ** 2``."""
        return self.stddev.pow(2)

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        # Two Python numbers describe a single scalar distribution; otherwise
        # the batch shape is the broadcast shape of ``loc`` and ``scale``.
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super(Normal, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a ``Normal`` whose parameters are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Normal, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        # The parameters were already validated on ``self``: skip re-validation
        # here, but carry the validation flag over to the new instance.
        super(Normal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        """Draw a sample of shape ``sample_shape + batch_shape`` without gradients."""
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.normal(self.loc.expand(shape), self.scale.expand(shape))

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample ``loc + eps * scale`` with ``eps ~ N(0, 1)``."""
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + eps * self.scale

    def log_prob(self, value):
        """Log density ``-(value - loc)^2 / (2 scale^2) - log(scale) - log(sqrt(2 pi))``."""
        if self._validate_args:
            self._validate_sample(value)
        # compute the variance
        var = (self.scale ** 2)
        log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
        return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

    def cdf(self, value):
        """Cumulative distribution function, computed with ``torch.erf``."""
        if self._validate_args:
            self._validate_sample(value)
        return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))

    def icdf(self, value):
        """Inverse CDF (quantile function), computed with ``torch.erfinv``."""
        if self._validate_args:
            self._validate_sample(value)
        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)

    def entropy(self):
        """Differential entropy ``0.5 + 0.5 * log(2 pi) + log(scale)``."""
        return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)

    @property
    def _natural_params(self):
        # Natural parameters (loc / scale^2, -1 / (2 scale^2)).
        return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

    def _log_normalizer(self, x, y):
        # Log normalizer written in terms of the natural parameters (x, y).
        return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
def z(self):
    """Return a ``greater_than`` constraint whose lower bound is ``self.y``."""
    bound = self.y
    return constraints.greater_than(bound)
def y(self):
    """Return a ``greater_than`` constraint whose lower bound is ``self.x``."""
    bound = self.x
    return constraints.greater_than(bound)