def _length_log_probs_with_rates(self, log_rates):
    """Score every candidate length under ``Poisson(exp(log_rates))``.

    The candidate lengths are ``0 .. self.max_k - 1``.  ``log_rates`` must be
    1-D (``n_classes``) or 2-D (``batch x n_classes``); the last axis is
    always the class axis.

    Returns:
        * 1-D input: ``max_k x n_classes`` log-probabilities.
        * 2-D input: ``batch x max_k x n_classes`` log-probabilities.
        * ``self.max_k == 1``: a fixed ``2 x n_classes`` table whose first
          row is 0 and second row is -1000, independent of the rates and of
          the rank of ``log_rates``.
    """
    num_classes = log_rates.size(-1)
    num_lengths = self.max_k

    # Degenerate case: no Poisson is evaluated, a constant table is returned.
    if num_lengths == 1:
        constant_column = torch.FloatTensor([0, -1000]).unsqueeze(-1)
        return constant_column.expand(2, num_classes).to(log_rates.device)

    # num_lengths x num_classes grid; row t is filled with the length t.
    lengths = torch.arange(num_lengths, device=log_rates.device)
    lengths = lengths.unsqueeze(-1).expand(num_lengths, num_classes).float()

    length_dist = Poisson(torch.exp(log_rates))

    if log_rates.dim() != 2:
        assert log_rates.dim() == 1
        return length_dist.log_prob(lengths)

    # Insert a batch axis so the grid broadcasts against batch x n_classes
    # rates, then move the batch axis to the front of the result.
    batch_size = log_rates.size(0)
    batched_lengths = lengths.unsqueeze(1).expand(num_lengths, batch_size, num_classes)
    return length_dist.log_prob(batched_lengths).transpose(0, 1)
Exemplo n.º 2
0
def test_zip_0_gate(rate):
    """Check that a ZeroInflatedPoisson with gate 0 scores like a Poisson.

    Twenty samples are drawn from ``Poisson(rate)`` and both distributions
    must assign them the same log-probabilities (absolute tolerance 1e-06).
    """
    zero_gate_zip = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate))
    plain_poisson = Poisson(torch.tensor(rate))

    draws = plain_poisson.sample((20, ))

    assert_close(
        zero_gate_zip.log_prob(draws),
        plain_poisson.log_prob(draws),
        atol=1e-06,
    )
Exemplo n.º 3
0
def tile_map_prior(prior: ImagePrior, tile_map):
    """Return the summed log prior of a tile map as a single scalar.

    Three terms are added together:

    * source counts: entries of ``tile_map["n_sources"]`` equal to 0 or 1
      receive the ``Poisson(prior.mean_sources)`` log-probability of that
      count; any other count contributes 0 to this term;
    * source type: ``log(0.7)`` weighted by ``tile_map["galaxy_bools"]`` plus
      ``log(0.3)`` weighted by ``tile_map["star_bools"]``;
    * galaxy parameters: standard-normal log-density of
      ``tile_map["galaxy_params"]``, masked by ``tile_map["galaxy_bools"]``.
    """
    # --- number of sources per tile ---
    count_dist = Poisson(torch.tensor(prior.mean_sources))
    log_p_zero = count_dist.log_prob(torch.tensor(0))
    log_p_one = count_dist.log_prob(torch.tensor(1))

    n_sources = tile_map["n_sources"]
    empty_term = (n_sources == 0) * log_p_zero
    single_term = (n_sources == 1) * log_p_one
    source_term = empty_term + single_term

    # --- galaxy vs. star label ---
    log_p_galaxy = torch.tensor(0.7).log()
    log_p_star = torch.tensor(0.3).log()
    is_galaxy = tile_map["galaxy_bools"]
    is_star = tile_map["star_bools"]
    type_term = log_p_galaxy * is_galaxy + log_p_star * is_star

    # --- galaxy parameters, counted only where a galaxy is flagged ---
    standard_normal = Normal(0.0, 1.0)
    param_term = standard_normal.log_prob(tile_map["galaxy_params"]) * is_galaxy

    return source_term.sum() + type_term.sum() + param_term.sum()
Exemplo n.º 4
0
    def k_new(self, X, Z, A, i, truncation):
        '''
        Sample how many new features (k_new) data point ``i`` should draw.

        i: The loop calling this function is asking this function
        "how many new features (k_new) should data point i draw?"

        truncation: When computing the un-normalized posterior for k_new|X,Z,A, we cannot
        compute the posterior for the infinite amount of values k_new could take on. So instead
        we compute from 0 up to some high number, truncation, and then normalize. In practice,
        the posterior probability for k_new is so low that it underflows past truncation=20.
        Must be at least 1.

        Returns a sample from the categorical posterior over 0 .. truncation - 1.

        The caller's ``Z`` is never modified: candidate columns are appended
        with ``torch.cat``, which builds a new tensor each time.

        Raises ValueError if ``truncation`` is smaller than 1.
        '''
        if truncation < 1:
            raise ValueError("truncation must be at least 1")

        N, K = Z.size()
        D = X.size()[1]
        cur_X_minus_ZA = X - Z @ A

        log_likelihood = torch.zeros(truncation)

        # Prior probability of k_new = 0 .. truncation - 1, evaluated in one call.
        # The values are passed as a tensor: Distribution.log_prob rejects plain
        # Python numbers when argument validation is enabled.
        p_k_new = Pois(torch.tensor([self.alpha / N]))
        candidates = torch.arange(truncation, dtype=p_k_new.rate.dtype)
        # Cast so the sum below keeps the dtype of the log_likelihood buffer.
        log_poisson_probs = p_k_new.log_prob(candidates).to(log_likelihood.dtype)

        Z_aug = Z
        for j in range(truncation):
            if j > 0:
                # Hypothesize one more feature owned only by data point i.
                new_column = torch.zeros(N, 1)
                new_column[i] = 1
                Z_aug = torch.cat((Z_aug, new_column), 1)

            # Compute the log likelihood of k_new equaling j
            # (Z_aug has K + j columns at this point).
            log_likelihood[j] = self.log_likelihood_given_k_new(
                cur_X_minus_ZA, Z_aug, D, i, j)

        # Compute log posterior of k_new and exp/normalize
        log_sample_probs = log_likelihood + log_poisson_probs
        sample_probs = self.renormalize_log_probs(log_sample_probs)

        posterior_k_new = Categorical(sample_probs)
        return posterior_k_new.sample()
Exemplo n.º 5
0
 def log_prob(self, likelihood: Poisson, x):
     """Return the log-probability of ``x`` under ``likelihood``.

     ``x`` is first converted with ``self.make_counts``; the element-wise
     log-probabilities of the result are then summed over the last axis.
     """
     # [B, V]
     count_matrix = self.make_counts(x)
     elementwise = likelihood.log_prob(count_matrix)
     # [B, V] -> [B]
     return elementwise.sum(dim=-1)
Exemplo n.º 6
0
 def step1_logprob(self, μs):
     """Return the log-probability of ``len(μs)`` under ``Poisson(self.λ)``.

     Only the number of elements in ``μs`` is used; their values are not read.
     """
     poisson = Poisson(self.λ)
     # Distribution.log_prob requires a Tensor when argument validation is
     # enabled; a bare Python int raises ValueError. Build the count with the
     # rate's dtype and device so the two combine without conversion.
     count = torch.tensor(len(μs), dtype=poisson.rate.dtype, device=poisson.rate.device)
     return poisson.log_prob(count)