Example #1
class TestMultivariateNormal(TestCase):
    """
    Tests that the gradients of batch_log_pdf are the same regardless of normalization. The test is run once for a
    distribution that is parameterized by the full covariance matrix and once for one that is parameterized by the
    Cholesky decomposition of the covariance matrix.
    """

    def setUp(self):
        N = 400
        self.L_tensor = torch.tril(1e-3 * torch.ones(N, N)).t()
        self.mu = Variable(torch.rand(N))
        self.L = Variable(self.L_tensor, requires_grad=True)
        self.sigma = Variable(torch.mm(self.L_tensor.t(), self.L_tensor), requires_grad=True)
        # Draw from an unrelated distribution so as not to interfere with the gradients
        self.sample = Variable(torch.randn(N))

        self.cholesky_mv_normalized = MultivariateNormal(self.mu, scale_tril=self.L, normalized=True)
        self.cholesky_mv = MultivariateNormal(self.mu, scale_tril=self.L, normalized=False)

        self.full_mv_normalized = MultivariateNormal(self.mu, self.sigma, normalized=True)
        self.full_mv = MultivariateNormal(self.mu, self.sigma, normalized=False)

    def test_log_pdf_gradients_cholesky(self):
        grad1 = grad([self.cholesky_mv.log_pdf(self.sample)], [self.L])[0].data
        grad2 = grad([self.cholesky_mv_normalized.log_pdf(self.sample)], [self.L])[0].data
        assert_equal(grad1, grad2)

    def test_log_pdf_gradients(self):
        grad1 = grad([self.full_mv.log_pdf(self.sample)], [self.sigma])[0].data
        grad2 = grad([self.full_mv_normalized.log_pdf(self.sample)], [self.sigma])[0].data
        assert_equal(grad1, grad2)
Example #2
    def posterior(
            self,
            potentials: MultivariateNormal) -> 'MultivariateNormalMixture':
        means = potentials.mean.unsqueeze(1)  # (N, 1, D)
        precs = potentials.precision_matrix.unsqueeze(1)  # (N, 1, D, D)
        covs = potentials.covariance_matrix.unsqueeze(1)  # (N, 1, D, D)

        prior_means = self.components.mean.unsqueeze(0)  # (1, K, D)
        prior_precs = self.components.precision_matrix.unsqueeze(
            0)  # (1, K, D, D)
        prior_covs = self.components.covariance_matrix.unsqueeze(
            0)  # (1, K, D, D)

        post_precs = precs + prior_precs
        post_means = posdef_solve(
            precs @ means[..., None] + prior_precs @ prior_means[..., None],
            post_precs)[0].squeeze(-1)
        post_components = MultivariateNormal(post_means,
                                             precision_matrix=post_precs)

        post_lognorm = MultivariateNormal(prior_means,
                                          covs + prior_covs).log_prob(means)
        post_logits = self.mixing.logits + post_lognorm

        return MultivariateNormalMixture(Categorical(logits=post_logits),
                                         post_components)
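
For reference, the update above is the standard product-of-Gaussians rule: the posterior precision is the sum of the two precisions, and the posterior mean solves post_prec @ post_mean = prec1 @ mu1 + prec2 @ mu2. A minimal non-batched sketch in plain torch (the posdef_solve helper from the snippet is replaced by torch.linalg.solve purely for illustration):

import torch

def gaussian_product(mu1, prec1, mu2, prec2):
    # Combine two Gaussian messages over the same variable:
    # posterior precision is the sum, posterior mean solves the linear system.
    post_prec = prec1 + prec2
    rhs = prec1 @ mu1.unsqueeze(-1) + prec2 @ mu2.unsqueeze(-1)
    post_mean = torch.linalg.solve(post_prec, rhs).squeeze(-1)
    return post_mean, post_prec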
Example #3
def test_log_prob():
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    cov = D.diag() + W.t().matmul(W)

    mvn = MultivariateNormal(loc, cov)
    lowrank_mvn = LowRankMultivariateNormal(loc, W, D)

    assert_equal(mvn.log_prob(x), lowrank_mvn.log_prob(x))
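
A standalone sketch of the same identity using torch.distributions, which I believe uses the transposed factor convention (covariance = cov_factor @ cov_factor.T + diag(cov_diag)), so the factor passed is W.t():

import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])

full = MultivariateNormal(loc, covariance_matrix=D.diag() + W.t() @ W)
lowrank = LowRankMultivariateNormal(loc, cov_factor=W.t(), cov_diag=D)
print(torch.allclose(full.log_prob(x), lowrank.log_prob(x)))  # expected: True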
Example #4
def test_log_prob():
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    L = D.diag() + torch.tril(W.t().matmul(W))
    cov = torch.mm(L, L.t())

    mvn = MultivariateNormal(loc, cov)
    omt_mvn = OMTMultivariateNormal(loc, L)
    assert_equal(mvn.log_prob(x), omt_mvn.log_prob(x))
Example #5
    def get_loglikelihood(self, X, U, mask, X_init, U_init, mask_init):
        with torch.no_grad():

            total_num = mask.sum()
            loglik = torch.zeros((total_num))
            counter = 0
            T_max = U.size(1)

            # If no initial U image, compute training LL
            if U_init is None:
                # Feed forward
                _, locs, mix, covs = self.model(X=X, U=U, mask=mask)

                # For each time interval
                for t in range(0, T_max):
                    # Distributions for current time interval
                    fn_dist = MN(loc=locs[:, t, :, :],
                                 scale_tril=covs[:, t, :, :, :])
                    # Compute LL for all data points
                    for tt in range(0, mask[t]):
                        loglik[counter] = torch.log((mix[:, t, :] * torch.exp(
                            fn_dist.log_prob(X[:, t, tt, :].squeeze()))
                                                     ).sum() + 1e-16)
                        counter += 1
            else:
                # Concatenate
                U_cat = torch.cat((U_init, U), 1)
                X_cat = torch.cat((X_init, X), 1)
                mask_cat = np.hstack((mask_init, mask))
                # Feed all data through network
                _, locs, mix, covs = self.model(X=X_cat,
                                                U=U_cat,
                                                mask=mask_cat)
                # Discard initialisation parameters
                T_init = U_init.size(1)
                locs = locs[:, T_init:, :, :]
                mix = mix[:, T_init:, :]
                covs = covs[:, T_init:, :, :, :]

                # For each time interval
                for t in range(0, T_max):
                    # Distributions for current time interval
                    fn_dist = MN(loc=locs[:, t, :, :],
                                 scale_tril=covs[:, t, :, :, :])
                    # Compute LL for all data points
                    for tt in range(0, mask[t]):
                        loglik[counter] = torch.log((mix[:, t, :] * torch.exp(
                            fn_dist.log_prob(X[:, t, tt, :].squeeze()))
                                                     ).sum() + 1e-16)
                        counter += 1

        return loglik
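
The inner loops above compute log(sum_k mix_k * exp(log_prob_k) + 1e-16), which can underflow for very negative component log-probabilities. A sketch of an equivalent but numerically stabler formulation via torch.logsumexp (assuming log mixture weights and per-component log-probabilities of matching shape):

import torch

def mixture_log_prob(log_weights, component_log_probs):
    # log sum_k w_k * exp(lp_k) == logsumexp_k(log w_k + lp_k)
    return torch.logsumexp(log_weights + component_log_probs, dim=-1)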
Example #6
def test_log_prob(mvn_dist):
    loc = torch.tensor([2.0, 1.0, 1.0, 2.0, 2.0])
    D = torch.tensor([1.0, 2.0, 3.0, 1.0, 3.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 2.0, 4.0], [2.0, 1.0, 1.0, 2.0, 6.0]])
    x = torch.tensor([2.0, 3.0, 4.0, 1.0, 7.0])
    L = D.diag() + torch.tril(W.t().matmul(W))
    cov = torch.mm(L, L.t())

    mvn = MultivariateNormal(loc, cov)
    if mvn_dist == OMTMultivariateNormal:
        mvn_prime = OMTMultivariateNormal(loc, L)
    elif mvn_dist == AVFMultivariateNormal:
        CV = 0.2 * torch.rand(2, 2, 5)
        mvn_prime = AVFMultivariateNormal(loc, L, CV)
    assert_equal(mvn.log_prob(x), mvn_prime.log_prob(x))
Example #7
    def forward(self, t):
        # t, yt: (n,)
        n = len(t)
        dt = t[1] - t[0]
        nzero = torch.zeros(n)

        # sample the prior
        a = pyro.sample("a", Uniform(*self.a_bounds))
        log_lscale = pyro.sample("log_lscale",
                                 Uniform(*self.log_lscale_bounds))
        lscale = torch.exp(log_lscale)
        log_sigma = pyro.sample("log_sigma", Uniform(*self.logs_bounds))
        tdist = t.unsqueeze(-1) - t  # (n,n)
        b_sigma = pyro.sample("b_sigma", Uniform(*self.b_bounds))
        b_cov = b_sigma * b_sigma * torch.exp(
            -tdist * tdist / (2 * lscale * lscale)) + torch.eye(n) * 1e-5
        b = pyro.sample("b", MultivariateNormal(nzero, b_cov))

        # calculate the rate
        int_bdt = torch.cumsum(b, dim=0) * dt
        mu = a + int_bdt  # (n,)

        # simulate the observation
        logysim = pyro.sample("logyt", Normal(mu, torch.exp(log_sigma)))

        return logysim
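
A sketch of how a model like this might be conditioned and sampled with Pyro's NUTS; gp (an instance of the class defining forward above) and observed_log_y are assumptions, not part of the snippet:

import torch
import pyro
from pyro.infer import MCMC, NUTS

t = torch.linspace(0.0, 1.0, 50)
# condition the "logyt" site on observed data (observed_log_y is hypothetical)
conditioned = pyro.condition(gp.forward, data={"logyt": observed_log_y})
mcmc = MCMC(NUTS(conditioned), num_samples=500, warmup_steps=200)
mcmc.run(t)
samples = mcmc.get_samples()  # dict of posterior draws keyed by site name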
Example #8
def guide_t0(data):
    # T-1 alpha params for beta sampling
    kappa = pyro.param('kappa',
                       lambda: Uniform(0, 2).sample([T - 1]),
                       constraint=constraints.positive)

    # concentration params for q_theta #[T,C]
    tau = pyro.param('tau',
                     lambda: MultivariateNormal(0.5 * torch.ones(C), 0.25 *
                                                torch.eye(C)).sample([T]),
                     constraint=constraints.unit_interval)

    # N params for categorical dist; topic weights; symmetric prior
    phi = pyro.param('phi',
                     lambda: Dirichlet(1 / T * torch.ones(T)).sample([N]),
                     constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1):
        q_beta = 0
        q_beta += pyro.sample("beta", Beta(torch.ones(T - 1), kappa))
        # q_beta *= 1

    # sample probs for multinomial distributions
    with pyro.plate("theta_plate", T):
        # outputs multinomial probabilities for each topic
        q_theta = 0
        q_theta += pyro.sample("theta", Dirichlet(tau))
        # q_theta *= 1

    with pyro.plate("data", N):
        z = 0
        z += pyro.sample("z", Categorical(phi))
Example #9
def model():
    prior = MultivariateNormal(torch.zeros(m), torch.Tensor(K))
    fs = pyro.sample("fs", prior)
    likelihood = Bernoulli(probs=(fs > 0).float())
    # softprobs = torch.sigmoid(fs)
    # likelihood = Bernoulli(probs=softprobs)
    ys = pyro.sample("ys", likelihood)
    return ys
Example #10
def random_mvn(loc_shape, cov_shape, dim):
    """
    Generate a random MultivariateNormal distribution for testing.
    """
    rank = dim + dim
    loc = torch.randn(loc_shape + (dim, ), requires_grad=True)
    cov = torch.randn(cov_shape + (dim, rank), requires_grad=True)
    cov = cov.matmul(cov.transpose(-1, -2))
    return MultivariateNormal(loc, cov)
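
A brief usage note for the helper above (shapes chosen arbitrarily):

d = random_mvn(loc_shape=(2,), cov_shape=(2,), dim=3)
print(d.batch_shape, d.event_shape)  # torch.Size([2]) torch.Size([3])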
Example #11
def test_scale_tril():
    loc = torch.tensor([1.0, 2.0, 1.0, 2.0, 0.0])
    D = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    W = torch.tensor([[1.0, -1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 1.0, 2.0, 4.0]])
    cov = D.diag() + W.t().matmul(W)

    mvn = MultivariateNormal(loc, cov)
    lowrank_mvn = LowRankMultivariateNormal(loc, W, D)

    assert_equal(mvn.scale_tril, lowrank_mvn.scale_tril)
Example #12
def test_variance():
    loc = torch.tensor([1.0, 1.0, 1.0, 2.0, 0.0])
    D = torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])
    W = torch.tensor([[3.0, 2.0], [-1.0, 3.0], [3.0, 1.0], [3.0, 3.0], [4.0, 4.0]])
    cov = D.diag() + W.matmul(W.t())

    mvn = MultivariateNormal(loc, cov)
    lowrank_mvn = LowRankMultivariateNormal(loc, W, D)

    assert_equal(mvn.variance, lowrank_mvn.variance)
Example #13
    def _update_r_dist(self):
        loc = self._inverse_mass_matrix.new_zeros(self._inverse_mass_matrix.size(0))
        if self.is_diag_mass:
            self._r_dist = Normal(
                loc.repeat(self.batch_size, 1),
                self._inverse_mass_matrix.rsqrt())
        else:
            r_dist_dim = loc.shape[0] * self.batch_size
            self._r_dist = MultivariateNormal(
                loc.repeat(self.batch_size, 1),
                precision_matrix=self._inverse_mass_matrix)
Example #14
def gaussian_model(data):
    mu = tensor([0., 0.])
    diag1 = pyro.sample("diag1", Normal(0., 2.))
    diag2 = pyro.sample("diag2", Normal(0., 2.))
    L = torch.tensor([[diag1, 0.], [0., diag2]], requires_grad=True)

    #pdb.set_trace()

    gaussian = MultivariateNormal(tensor([0., 0.]), scale_tril=L)

    for i in range(data.size(0)):
        pyro.sample("obs_{}".format(i), gaussian, obs=data[i])
Example #15
    def setUp(self):
        N = 400
        self.L_tensor = torch.tril(1e-3 * torch.ones(N, N)).t()
        self.mu = Variable(torch.rand(N))
        self.L = Variable(self.L_tensor, requires_grad=True)
        self.sigma = Variable(torch.mm(self.L_tensor.t(), self.L_tensor), requires_grad=True)
        # Draw from an unrelated distribution so as not to interfere with the gradients
        self.sample = Variable(torch.randn(N))

        self.cholesky_mv_normalized = MultivariateNormal(self.mu, scale_tril=self.L, normalized=True)
        self.cholesky_mv = MultivariateNormal(self.mu, scale_tril=self.L, normalized=False)

        self.full_mv_normalized = MultivariateNormal(self.mu, self.sigma, normalized=True)
        self.full_mv = MultivariateNormal(self.mu, self.sigma, normalized=False)
Example #16
    def _guide(self):
        """
        Pyro variational posterior ("guide")
        """

        q_mu = pyro.param("q_mu", torch.zeros(self.num_parameters))
        q_sqrt = pyro.param(
            "q_sqrt",
            torch.eye(self.num_parameters),
            constraint=torch.distributions.constraints.lower_cholesky,
        )

        pyro.sample("theta", MultivariateNormal(q_mu, scale_tril=q_sqrt))
Example #17
    def _prior_model(self):
        scales, variance, jitter, bias = self._get_samples()
        if self.n > 0:
            kyy = _rbf(self.x, self.x, scales, variance) + jitter * eye(self.n)
            try:
                ckyy = _jitchol(kyy)
                sample(
                    "output",
                    MultivariateNormal(bias + zeros(self.n), scale_tril=ckyy),
                    obs=self.y,
                )
            except RuntimeError:  # Cholesky fails?
                # "No chance"
                sample("output", Delta(zeros(1)), obs=ones(1))
Example #18
    def get_loglikelihood(self, X, U, mask):
        with torch.no_grad():

            total_num = mask.sum()
            loglik = torch.zeros((total_num))
            counter = 0
            T_max = U.size(1)

            # Feed forward
            _, locs, mix, covs = self.model(X=X, U=U, mask=mask)

            # For each time interval
            for t in range(0, T_max):
                # Distributions for current time interval
                fn_dist = MN(loc=locs[:, t, :, :],
                             scale_tril=covs[:, t, :, :, :])
                # Compute LL for all data points
                for tt in range(0, mask[t]):
                    loglik[counter] = torch.log((mix[:, t, :] * torch.exp(
                        fn_dist.log_prob(X[:, t, tt, :].squeeze()))).sum() +
                                                1e-16)
                    counter += 1

        return loglik
Example #19
    def get_results(self, detach_posterior=True) -> CalibrationResults:
        """
        Distill the results of the calibration

        :param detach_posterior: If True, make sure that the posterior's parameters are
            detached (leaving them attached can cause issues if you try deepcopying)
        """

        loc, scale_tril = self.q_mu, self.q_sqrt
        if detach_posterior:
            loc, scale_tril = loc.detach(), scale_tril.detach()

        return CalibrationResults(
            self.experiment,
            self.get_loss(),
            MultivariateNormal(loc, scale_tril=scale_tril),
        )
Example #20
class ProbabilisticGloveLayer(nn.Embedding):

    # TODO: Is there a way to express constraints on weights other than
    # the nn.Functional.gradient_clipping ?
    # ANSWER: Yes, if you use Pyro with constraints.
    def __init__(self, num_embeddings, embedding_dim, co_occurrence,
                 # glove learning options
                 x_max=100, alpha=0.75,
                 seed=None, # if None means don't set
                 # whether or not to use the wi and wj
                 # if set to False, will use only wi
                 double=False,
                 # nn.Embedding options go here
                 padding_idx=None,
                 scale_grad_by_freq=None,
                 max_norm=None, norm_type=2,
                 sparse=False   # not supported- just here to keep interface
                 ):
        self.seed = seed
        if seed is not None:
            # internal import to allow setting seed before any other imports
            import pyro
            pyro.set_rng_seed(seed)
        # Internal import because we need to set seed first
        from pyro.distributions import MultivariateNormal
        # This is spurious; we won't actually be using any of the superclass
        # attributes, but we have to do this to get other things like the
        # registration of parameters to work.
        super(ProbabilisticGloveLayer, self).__init__(num_embeddings, embedding_dim,
                                                      padding_idx=None, max_norm=None,
                                                      norm_type=2.0, scale_grad_by_freq=False,
                                                      sparse=False, _weight=None)
        if sparse:
            raise NotImplementedError("`sparse` is not implemented for this class")
        # for the total weight to have a max norm of K, the embeddings
        # that are summed to make them up need a max norm of K/2
        used_norm = max_norm / 2 if max_norm else None
        kws = {}
        if used_norm:
            kws['max_norm'] = used_norm
            kws['norm_type'] = norm_type
        if padding_idx:
            kws['padding_idx'] = padding_idx
        if scale_grad_by_freq is not None:
            kws['scale_grad_by_freq'] = scale_grad_by_freq
        # double is not supported, but we keep the same API.
        assert not double, "Probabilistic embedding can only be used in single mode"
        # This assumes each dimension is independent of the others,
        # and that all the embeddings are independent of each other.
        # We express it as MV normal because this allows us to use a
        # non diagonal covariance matrix
        # try setting the variance low here instead
        # The output of these needs to be moved to GPU before use,
        # because there is currently no nice way to move a distribution to GPU.
        self.wi_dist = MultivariateNormal(
            torch.zeros((num_embeddings, embedding_dim)),
            # changing this to 1e-9 makes the embeddings converge to something
            # else than the pretrained
            torch.eye(embedding_dim)# * 1e-9
        )
        self.bi_dist = MultivariateNormal(
            torch.zeros((num_embeddings, 1)),
            torch.eye(1)
        )
        # Deterministic means for the weights and bias, that will be learnt
        # means will be used to transform the samples from the above wi/bi
        # samples.
        # Express them as embeddings because that sets up all the gradients
        # for backprop, and allows for easy indexing.
        self.wi_mu = nn.Embedding(num_embeddings, embedding_dim)
        self.wi_mu.weight.data.uniform_(-1, 1)
        self.bi_mu = nn.Embedding(num_embeddings, 1)
        self.bi_mu.weight.data.zero_()
        # wi_sigma = log(1 + exp(wi_rho)) to enforce positivity.
        self.wi_rho = nn.Embedding(num_embeddings, embedding_dim)
        # initialise the rho so softplus results in small values
        # 1e-9 - this appears to be about -20 so we have to re-center around -19?
        # except it doesn't work- re-centering just makes the means nowhere near
        self.wi_rho.weight.data.uniform_(-1, 1)
        self.bi_rho = nn.Embedding(num_embeddings, 1)
        self.bi_rho.weight.data.zero_()
        # using torch functions should ensure backprop is set up right
        self.softplus = nn.Softplus()
        #self.wi_sigma = softplus(self.wi_rho.weight) #torch.log(1 + torch.exp(self.wi_rho))
        #self.bi_sigma = softplus(self.bi_rho.weight) #torch.log(1 + torch.exp(self.bi_rho))

        self.co_occurrence = co_occurrence.coalesce()
        # it is not very big
        self.coo_dense = self.co_occurrence.to_dense()
        self.x_max = x_max
        self.alpha = alpha
        # Placeholder. In future, we will make an abstract base class
        # which will have the below attribute so that all instances
        # carry their own loss.
        self.losses = []
        self._setup_indices()

    def _setup_indices(self):
        # Do some preprocessing to make looking up indices faster.
        # The co-occurrence matrix is a large array of pairs of indices
        # In the course of training, we will be given a list of
        # indices, and we need to find the pairs that are present.
        self.allindices = self.co_occurrence.indices()
        N = self.allindices.max() + 1
        # Store a dense array of which pairs are active
        # It is booleans so should be small even if there are a lot of tokens
        self.allpairs = torch.zeros((N, N), dtype=bool)
        self.allpairs[(self.allindices[0], self.allindices[1])] = True
        self.N = N

    @property
    def weight(self):
        return self.weights()

    def weights(self, n=1, squeeze=True):
        # we are taking one sample from each embedding distribution
        sample_shape = torch.Size([n])
        wi_eps = self.wi_dist.sample(sample_shape).type_as(self.wi_mu.weight.data)
        # TODO: Only because we have assumed a diagonal covariance matrix,
        # is the below elementwise multiplication (* rather than @).
        # If it was not diagonal, we would have to do matrix multiplication
        #wi = self.wi_mu + wi_eps * self.wi_sigma
        wi = (
                self.wi_mu.weight +
                # multiplying by 1e-9 below should have the same effect
                # as changing the wi_eps variance to 1e-9, but it doesn't.
                # multiplying here results in wi_mu converging very closely
                # to the deterministic embeddings, but the wi_sigma variance remains
                # the same as in the other case.
                wi_eps * self.softplus(self.wi_rho.weight) #* 1e-9
        )
        if squeeze:
            return wi.squeeze()
        else:
            return wi

    # implemented as such to be consistent with nn.Embeddings interface
    def forward(self, indices):
        return self.weights()[indices]

    def _update(self, i_indices, j_indices):
        # we need to do all the sampling here.
        # TODO: Not sure what to do with j_indices. Do we update the j_indices
        # TODO: Only because we have assumed a diagonal covariance matrix,
        # is the below elementwise multiplication (* rather than @).
        # If it was not diagonal, we would have to do matrix multiplication
        w_i = (
                self.wi_mu(i_indices) +
                self.wi_eps[i_indices] * self.softplus(self.wi_rho(i_indices))
        )

        b_i = (
                self.bi_mu(i_indices) +

                self.bi_eps[i_indices] * self.softplus(self.bi_rho(i_indices))
        ).squeeze()
        # If the double updating is not done, it takes a long time to converge.
        w_j = (
                self.wi_mu(j_indices) +
                self.wi_eps[j_indices] * self.softplus(self.wi_rho(j_indices))
        )
        b_j = (
                self.bi_mu(j_indices) +
                self.bi_eps[j_indices] * self.softplus(self.bi_rho(j_indices))
        ).squeeze()

        x = torch.sum(w_i * w_j, dim=1) + b_i + b_j
        return x

    def _init_samples(self):
        # On every 0th batch in an epoch, sample everything.
        sample_shape = torch.Size([])
        self.wi_eps = self.wi_dist.sample(sample_shape) #* 1e-9
        self.bi_eps = self.bi_dist.sample(sample_shape) #* 1e-9
        # This has to be done because there is currently no nice way to move
        # a Pyro distribution to GPU.
        # So we move whatever we sampled from it.
        template = self.wi_mu.weight.data
        self.wi_eps = self.wi_eps.type_as(template)
        self.bi_eps = self.bi_eps.type_as(template)

    def _loss_weights(self, x):
        # x: co_occurrence values
        wx = (x/self.x_max)**self.alpha
        wx = torch.min(wx, torch.ones_like(wx))
        return wx

    def loss(self, indices):
        # inputs are indexes, targets are actual embeddings
        # In the actual algorithm, "inputs" is the weight_func run on the
        # co-occurrence statistics.
        # loss = wmse_loss(weights_x, outputs, torch.log(x_ij))
        # not sure what it should be replaced by here
        # "targets" are the log of the co-occurrence statistics.
        # need to make every pair of indices that exist in co-occurrence file
        # Not every index will be represented in the co_occurrence matrix
        # To calculate glove loss, we will take all the pairs in the co-occurrence
        # that contain anything in the current set of indices.
        # There is a disconnect between the indices that are passed in here,
        # and the indices of all pairs in the co-occurrence matrix
        # containing those indices.
        indices = indices.sort()[0]
        subset = self.allpairs[indices]
        if not torch.any(subset):
            self.losses = [0]
            return self.losses
        # now look up the indices of the existing pairs
        # it is faster to do the indexing into an array of bools
        # instead of the dense array
        subset_indices = torch.nonzero(subset).type_as(indices)
        i_indices = indices[subset_indices[:, 0]]
        j_indices = subset_indices[:, 1]

        targets = self.coo_dense[(i_indices, j_indices)]
        weights_x = self._loss_weights(targets)
        current_weights = self._update(i_indices, j_indices)
        # put everything on the right device
        weights_x = weights_x.type_as(current_weights)
        targets = targets.type_as(current_weights)

        loss = weights_x * F.mse_loss(
            current_weights, torch.log(targets), reduction='none')
        # This is a feasible strategy for mapping indices -> pairs
        # Second strategy: Loop over all possible pairs
        # More computationally intensive
        # Allow this to be configurable?
        # Degrees of separation but only allow 1 or -1
        # -1 is use all indices
        # 1 is use indices
        # We may want to save this loss as an attribute on the layer object
        # Does Lightning have a way of representing layer specific losses
        # Define an interface by which we can return loss objects
        # probably stick with self.losses = [loss]
        # - a list - because then we can do +=
        loss = torch.mean(loss)
        self.losses = [loss]
        return self.losses

    def entropy(self):
        # Calculate entropy based on the learnt rho (which must be transformed
        # to a variance using softplus).
        # Entropy of a multivariate Gaussian:
        # 0.5 * N * ln(2 * pi * e) + 0.5 * ln(det C)
        # (a verification sketch against torch.distributions follows after this class)
        N = self.embedding_dim
        C = self.softplus(self.wi_rho.weight)
        # diagonal covariance so just multiply all the items to get determinant
        # convert to log space so we can add across the dimensions
        # We don't need to convert back to exp because we need log det C anyway
        logdetC = torch.log(C).sum(axis=1)
        entropy = 0.5*(N*np.log(2*np.pi * np.e) + logdetC)
        return entropy

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # eventually, load this from a .pt file that contains all the
        # wi/wj/etc.
        raise NotImplementedError('Not yet implemented')
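
A small standalone check of the closed-form entropy used in entropy() above against torch.distributions, assuming (as in the class) a diagonal covariance whose per-dimension variances come from a softplus of the rho weights:

import math
import torch
from torch.distributions import MultivariateNormal

emb_dim, n_emb = 4, 3
rho = torch.randn(n_emb, emb_dim)
var = torch.nn.functional.softplus(rho)  # per-dimension variances
closed_form = 0.5 * (emb_dim * math.log(2 * math.pi * math.e)
                     + torch.log(var).sum(dim=1))
reference = MultivariateNormal(torch.zeros(emb_dim),
                               covariance_matrix=torch.diag_embed(var)).entropy()
print(torch.allclose(closed_form, reference))  # expected: True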
Example #21
def guide(data):
    # pyro params
    new_topic_prob = pyro.param("new_topic_prob",
                                lambda: Uniform(0, 1).sample([T]),
                                constraint=constraints.unit_interval)

    linked_prob = pyro.param("linked_prob",
                             lambda: Uniform(0, 1).sample([T]),
                             constraint=constraints.unit_interval)

    which_topic_probs = pyro.param("which_topic_probs",
                                   lambda: Uniform(0, 1).sample([T_prev]),
                                   constraint=constraints.simplex)

    kappa = pyro.param('kappa',
                       lambda: Uniform(0, 2).sample([T - 1]),
                       constraint=constraints.positive)

    tau = pyro.param('tau',
                     lambda: MultivariateNormal(0.5 * torch.ones(C), 0.25 *
                                                torch.eye(C)).sample([T]),
                     constraint=constraints.unit_interval)

    # N params for categorical dist; topic weights; symmetric prior
    phi = pyro.param('phi',
                     lambda: Dirichlet(1 / T * torch.ones(T)).sample([N]),
                     constraint=constraints.simplex)

    # model params
    with pyro.plate("new_topic_plate", T):
        # print(new_topic_prob)
        new_topic = pyro.sample("new_topic", Binomial(probs=new_topic_prob))

    # if new topic, if linked to old topic, prior=0.5
    with pyro.plate("linked_plate", T):
        linked = pyro.sample("linked", Binomial(probs=linked_prob))

    # if old topic, which old topic
    with pyro.plate("old_topic_plate", T):
        which_old_topic = pyro.sample("which_old_topic",
                                      Multinomial(probs=which_topic_probs))

    with pyro.plate("beta_plate", T - 1):
        q_beta = 0
        q_beta += pyro.sample("beta", Beta(torch.ones(T - 1), kappa))

    # new topic with symmetric prior
    with pyro.plate("theta_plate", T):
        theta = pyro.sample("theta", Dirichlet(tau))

    # new topic linked to old topic
    with pyro.plate("gamma_plate", T_prev):
        gamma = pyro.sample("gamma", Dirichlet(prev_taus))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))
        old = get_old_topics(which_old_topic)
        a = ((new_topic) * (linked))
        b = (1 - new_topic)
        c = ((new_topic) * (1 - linked))
        a = a[z].reshape(N, 1)
        b = b[z].reshape(N, 1)
        c = c[z].reshape(N, 1)
        mult_probs = 0
        mult_probs += a * gamma[old[z]] + b * prev_theta[old[z]] + c * theta[z]
Example #22
    def __init__(self, num_embeddings, embedding_dim, co_occurrence,
                 # glove learning options
                 x_max=100, alpha=0.75,
                 seed=None, # if None means don't set
                 # whether or not to use the wi and wj
                 # if set to False, will use only wi
                 double=False,
                 # nn.Embedding options go here
                 padding_idx=None,
                 scale_grad_by_freq=None,
                 max_norm=None, norm_type=2,
                 sparse=False   # not supported- just here to keep interface
                 ):
        self.seed = seed
        if seed is not None:
            # internal import to allow setting seed before any other imports
            import pyro
            pyro.set_rng_seed(seed)
        # Internal import because we need to set seed first
        from pyro.distributions import MultivariateNormal
        # This is spurious; we won't actually be using any of the superclass
        # attributes, but we have to do this to get other things like the
        # registration of parameters to work.
        super(ProbabilisticGloveLayer, self).__init__(num_embeddings, embedding_dim,
                                                      padding_idx=None, max_norm=None,
                                                      norm_type=2.0, scale_grad_by_freq=False,
                                                      sparse=False, _weight=None)
        if sparse:
            raise NotImplementedError("`sparse` is not implemented for this class")
        # for the total weight to have a max norm of K, the embeddings
        # that are summed to make them up need a max norm of K/2
        used_norm = max_norm / 2 if max_norm else None
        kws = {}
        if used_norm:
            kws['max_norm'] = used_norm
            kws['norm_type'] = norm_type
        if padding_idx:
            kws['padding_idx'] = padding_idx
        if scale_grad_by_freq is not None:
            kws['scale_grad_by_freq'] = scale_grad_by_freq
        # double is not supported, but we keep the same API.
        assert not double, "Probabilistic embedding can only be used in single mode"
        # This assumes each dimension is independent of the others,
        # and that all the embeddings are independent of each other.
        # We express it as MV normal because this allows us to use a
        # non diagonal covariance matrix
        # try setting the variance low here instead
        # The output of these needs to be moved to GPU before use,
        # because there is currently no nice way to move a distribution to GPU.
        self.wi_dist = MultivariateNormal(
            torch.zeros((num_embeddings, embedding_dim)),
            # changing this to 1e-9 makes the embeddings converge to something
            # else than the pretrained
            torch.eye(embedding_dim)# * 1e-9
        )
        self.bi_dist = MultivariateNormal(
            torch.zeros((num_embeddings, 1)),
            torch.eye(1)
        )
        # Deterministic means for the weights and bias, that will be learnt
        # means will be used to transform the samples from the above wi/bi
        # samples.
        # Express them as embeddings because that sets up all the gradients
        # for backprop, and allows for easy indexing.
        self.wi_mu = nn.Embedding(num_embeddings, embedding_dim)
        self.wi_mu.weight.data.uniform_(-1, 1)
        self.bi_mu = nn.Embedding(num_embeddings, 1)
        self.bi_mu.weight.data.zero_()
        # wi_sigma = log(1 + exp(wi_rho)) to enforce positivity.
        self.wi_rho = nn.Embedding(num_embeddings, embedding_dim)
        # initialise the rho so softplus results in small values
        # 1e-9 - this appears to be about -20 so we have to re-center around -19?
        # except it doesn't work- re-centering just makes the means nowhere near
        self.wi_rho.weight.data.uniform_(-1, 1)
        self.bi_rho = nn.Embedding(num_embeddings, 1)
        self.bi_rho.weight.data.zero_()
        # using torch functions should ensure backprop is set up right
        self.softplus = nn.Softplus()
        #self.wi_sigma = softplus(self.wi_rho.weight) #torch.log(1 + torch.exp(self.wi_rho))
        #self.bi_sigma = softplus(self.bi_rho.weight) #torch.log(1 + torch.exp(self.bi_rho))

        self.co_occurrence = co_occurrence.coalesce()
        # it is not very big
        self.coo_dense = self.co_occurrence.to_dense()
        self.x_max = x_max
        self.alpha = alpha
        # Placeholder. In future, we will make an abstract base class
        # which will have the below attribute so that all instances
        # carry their own loss.
        self.losses = []
        self._setup_indices()
Example #23

def eval_grid(xx, yy, fcn):
    xy = torch.stack([xx.flatten(), yy.flatten()], dim=1)
    return fcn(xy).reshape_as(xx)


if __name__ == '__main__':
    from pyro.distributions import Dirichlet

    N, K, D = 200, 4, 2
    props = Dirichlet(5 * torch.ones(K)).sample()
    mean = torch.arange(K).float().view(K, 1).expand(K, D)
    var = .1 * torch.eye(D).expand(K, -1, -1)
    mixing = Categorical(props)
    components = MultivariateNormal(mean, var)
    print("mixing", mixing.batch_shape, mixing.event_shape)
    print("components", components.batch_shape, components.event_shape)
    mixture = Mixture(mixing,
                      NaturalMultivariateNormal.from_standard(components))
    mixture.rename(['x', 'y'])
    print("mixture names", mixture.variable_names)
    print("mixture", mixture.batch_shape, mixture.event_shape)
    probe = MultivariateNormal(mean[:3] + 1 * torch.tensor([1., -1.]),
                               .2 * var[:3])
    post_mixture = mixture.posterior(probe)
    print("post_mixture names", post_mixture.variable_names)
    samples = mixture.sample([N])
    n = 1
    post_samples = post_mixture.sample([N])[:, n]
    print("post_mixture", post_mixture.batch_shape, post_mixture.event_shape)
Example #24
    def logjoint(self):

        distA = LogNormal(-0.5, 0.5)
        distd = HalfNormal(1.)
        distMC = NFdist(model=model, M=0., c=0.)

        A = pyro.sample(
            "A", distA)  #normal.Normal(-0.5, 0.5).log_prob(torch.log(A))
        d = pyro.sample("d", distd)
        Mc = pyro.sample('Mc', distMC)

        #c = Mc[:,1]*ss.scale_[1] + ss.mean_[1]
        #M = Mc[:,0]*ss.scale_[0] + ss.mean_[0]

        Mc_scaled = Mc * torch.tensor(self.ss.scale_) + torch.tensor(
            self.ss.mean_)
        #print(Mc_scaled, Mc)
        try:
            c = Mc_scaled[:, 0:len(self.chat)]
            M = Mc_scaled[:, len(self.chat):]
        except IndexError:
            c = Mc_scaled[0:len(self.chat)]
            M = Mc_scaled[len(self.chat):]
        #print(M, c, A, d)
        covc = torch.eye(len(self.sigmac))
        for i, s in enumerate(self.sigmac):
            covc[i, i] *= s**2
        covm = torch.eye(len(self.sigmam))
        for i, s in enumerate(self.sigmam):
            covm[i, i] *= s**2
        #print(A*torch.tensor(dustco_c))
        #print(self.predicted_color(A, c, dustco_c))
        lnp_c = pyro.sample("chat",
                            MultivariateNormal(
                                self.predicted_color(A, c, self.dustco_c),
                                covc),
                            obs=torch.Tensor(self.chat))
        #lnp_c = pyro.sample("chat", Normal(self.predicted_color(A, c, dustco_c), self.sigmac), obs = torch.Tensor([self.chat]))
        lnp_m = pyro.sample("mhat",
                            MultivariateNormal(
                                self.predicted_magnitude(
                                    A, M, d, self.dustco_m), covm),
                            obs=torch.Tensor(self.mhat))
        #lnp_m = pyro.sample("mhat", Normal(self.predicted_magnitude(A, M, d, dustco_m), self.sigmam), obs=torch.Tensor([self.mhat]))
        lnp_varpi = pyro.sample("varpihat",
                                Normal(1. / d, self.sigmavarpi),
                                obs=torch.Tensor([self.varpihat]))
        lnp_Mc = distMC.log_prob(Mc)
        lnp_A = distA.log_prob(A)
        lnp_d = distd.log_prob(d)
        #print(lnp_c, lnp_m, lnp_varpi, lnp_Mc, lnp_A, lnp_d)
        #print(lnp_c[0], lnp_m[0], lnp_varpi[0], lnp_Mc[0][0], lnp_A, lnp_d, lnp_M[0], lnp_c[0])
        logp = torch.stack([
            lnp_c.sum(),
            lnp_m.sum(),
            lnp_varpi.sum(),
            lnp_Mc.sum(),
            lnp_A.sum(),
            lnp_d.sum()
        ],
                           dim=0).sum()
        return logp
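
As an aside, the two eye-then-scale loops that build covc and covm above can be written directly as diagonal matrices; a sketch (sigmac here is an illustrative 1-D tensor of standard deviations):

import torch

sigmac = torch.tensor([0.1, 0.2, 0.3])
covc = torch.diag(sigmac ** 2)  # equivalent to scaling the identity row by row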
Example #25
    def predict(self, x) -> MultivariateNormal:
        mean, scale_tril = self(x)
        return MultivariateNormal(mean, scale_tril=scale_tril)
Example #26
@register_kl(Factorised, Factorised)
def _kl_factorised_factorised(p: Factorised, q: Factorised):
    return sum(
        kl_divergence(p_factor, q_factor)
        for p_factor, q_factor in zip(p.factors, q.factors))


if __name__ == '__main__':
    from pyro.distributions import Dirichlet, MultivariateNormal
    from torch.distributions import kl_divergence
    from distributions.mixture import Mixture

    B, D1, D2 = 5, 3, 4
    N = 1000

    dist1 = MultivariateNormal(torch.zeros(D1), torch.eye(D1)).expand((B, ))
    dist2 = Dirichlet(torch.ones(D2)).expand((B, ))
    print(dist1.batch_shape, dist1.event_shape)
    print(dist2.batch_shape, dist2.event_shape)
    fact = Factorised([dist1, dist2])
    print(fact.batch_shape, fact.event_shape)
    samples = fact.rsample((N, ))
    print(samples[0])
    print(samples.shape)
    logp = fact.log_prob(samples)
    print(logp.shape)
    entropy = fact.entropy()
    print(entropy.shape)
    print(entropy, -logp.mean())
    print()