Example #1
def ZeroInflatedPoisson_loss_function(recon_x,
                                      x,
                                      latent_loss,
                                      data_shape=None,
                                      act_choice=5):

    x_shape = x.size()
    # recon_x[0]: gate probability of the Poisson (count) component
    # recon_x[1]: Poisson rate
    recon_x_0_bin = recon_x[0]
    recon_x_0_count = recon_x[1]

    # Poisson log-probability, masked for x == 0 and x > 0 respectively
    poisson_0 = (x == 0).float() * Poisson(recon_x_0_count).log_prob(x)
    poisson_greater0 = (x > 0).float() * Poisson(recon_x_0_count).log_prob(x)

    # x == 0 can come from the structural-zero component (prob 1 - gate) or
    # from the Poisson component (prob gate); mix the two in log space.
    zero_inf = torch.cat(
        (torch.log((1 - recon_x_0_bin) + 1e-9).view(x_shape[0], x_shape[1], -1),
         (torch.log(recon_x_0_bin + 1e-9) + poisson_0).view(
             x_shape[0], x_shape[1], -1)),
        dim=2)

    log_l = (x == 0).float() * torch.logsumexp(zero_inf, dim=2)
    log_l += (x > 0).float() * (torch.log(recon_x_0_bin + 1e-9) +
                                poisson_greater0)

    # KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Sum the log-likelihood over all positions except indices 7*5 - 1 and
    # 7*5 (two excluded time steps), then add the latent (KL) loss.
    return -(log_l[:, :7 * 5 - 1].sum() +
             log_l[:, 7 * 5 + 1:].sum()) + latent_loss
Example #2
 def __init__(self, num_bootstraps=1, bootstrap=True, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.weights = torch.empty(self.max_len,
                                num_bootstraps,
                                dtype=torch.int)
     self.mask_distribution = Poisson(torch.ones(num_bootstraps))
     self.bootstrap = bootstrap
Example #3
def test_zip_0_gate(rate):
    # if gate is 0 ZIP is Poisson
    zip_ = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate))
    pois = Poisson(torch.tensor(rate))
    s = pois.sample((20, ))
    zip_prob = zip_.log_prob(s)
    pois_prob = pois.log_prob(s)
    assert_close(zip_prob, pois_prob, atol=1e-06)
Example #4
 def sample(self, m, a):
     # Draw a negative-binomial count with mean m and dispersion a as a
     # Gamma-Poisson mixture: lambda ~ Gamma(1/a, (1-p)/p), z ~ Poisson(lambda).
     r = 1 / a
     p = (m * a) / (1 + (m * a))
     b = (1 - p) / p
     g = Gamma(r, b)
     g = g.sample()
     p = Poisson(g)
     z = p.sample()
     return z
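
A quick check of the parameterization above, not part of the original: the Gamma-Poisson mixture with r = 1/a and rate (1 - p)/p yields a negative binomial with mean m and variance m + a*m**2 (values below are illustrative):

import torch
from torch.distributions import Gamma, Poisson

m, a = 5.0, 0.3                     # target mean and dispersion
r = 1 / a
p = (m * a) / (1 + m * a)
lam = Gamma(r, (1 - p) / p).sample((100000,))
z = Poisson(lam).sample()
print(z.mean().item())              # ~ 5.0  (= m)
print(z.var().item())               # ~ 12.5 (= m + a * m**2)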
Example #5
def poisson_spike(x, time_bins):
    shape_org = list(x.shape)
    y = x.reshape(-1)
    samples = []
    for yy in y:
        m1 = Poisson(yy)
        samples.append(m1.sample(sample_shape=(time_bins,)) > 0)
    output = torch.stack(samples, dim=0).float()
    return output.reshape(shape_org + [time_bins])
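
Poisson broadcasts over a tensor of rates, so the per-element Python loop above can usually be replaced with a single batched sample; a sketch of an equivalent vectorized version (same assumed semantics):

import torch
from torch.distributions import Poisson

def poisson_spike_vectorized(x, time_bins):
    # One batched distribution over all rates; draw time_bins samples at once.
    samples = Poisson(x.reshape(-1)).sample(sample_shape=(time_bins,)) > 0
    # (time_bins, N) -> (N, time_bins), then restore the original leading shape.
    return samples.t().float().reshape(list(x.shape) + [time_bins])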
Example #6
    def _sample_n_sources(self, batch_size, n_tiles_h, n_tiles_w):
        # returns number of sources for each batch x tile
        # output dimension is batch_size x n_tiles_h x n_tiles_w

        # always poisson distributed.
        p = torch.full((1, ),
                       self.mean_sources,
                       device=self.device,
                       dtype=torch.float)
        m = Poisson(p)
        n_sources = m.sample([batch_size, n_tiles_h, n_tiles_w])

        # long() here is necessary because used for indexing and one_hot encoding.
        n_sources = n_sources.clamp(max=self.max_sources, min=self.min_sources)
        return rearrange(n_sources.long(), "b nth ntw 1 -> b nth ntw")
Example #7
    def _reconstruction_loss(self,
                             x,
                             px_rate,
                             px_r,
                             px_dropout,
                             batch_index,
                             y,
                             mode="scRNA",
                             weighting=1):
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(x), dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(x),
                                          dim=1)
        return reconst_loss
Example #8
 def _length_log_probs_with_rates(self, log_rates):
     n_classes = log_rates.size(-1)
     max_length = self.max_k
     # max_length x n_classes
     time_steps = torch.arange(max_length, device=log_rates.device).unsqueeze(-1).expand(max_length,
                                                                                         n_classes).float()
     if max_length == 1:
         return torch.FloatTensor([0, -1000]).unsqueeze(-1).expand(2, n_classes).to(log_rates.device)
         # return torch.zeros(max_length, n_classes).to(log_rates.device)
     poissons = Poisson(torch.exp(log_rates))
     if log_rates.dim() == 2:
         time_steps = time_steps.unsqueeze(1).expand(max_length, log_rates.size(0), n_classes)
         return poissons.log_prob(time_steps).transpose(0, 1)
     else:
         assert log_rates.dim() == 1
         return poissons.log_prob(time_steps)
Example #9
def cascvi_baseline_z(tree,
                    latent,
                    model,
                    library_size=10000):
    """
    :param tree: ete3 phylogenetic tree
    :param latent: dict: latent representations of internal nodes
    :param model: VAE or variant
    :param weighted: True if the average is weighted
    :param library_size:
    :return: imputed Gene Expression
    """

    imputed = {}
    for n in tree.traverse('levelorder'):
        if n.is_leaf():
            continue
        else:
            px_scale, px_r, px_rate, px_dropout = model.decoder.forward(model.dispersion,
                                                                        latent[n.name].float().view(1, -1),
                                                                        torch.from_numpy(np.array([np.log(library_size)])),
                                                                        0)

            l_train = torch.clamp(torch.mean(px_rate, axis=0), max=1e8)

            data = Poisson(l_train).sample().cpu().numpy()

            imputed[n.name] = data

    return imputed
Example #10
    def test_Poisson(self):
        shape = 10, 100

        a = 1e-2 * torch.ones((shape[0], 1))
        dt = 1e-2
        dist = Poisson(dt * 0.1)

        init = Normal(a, 1.)
        sde = AffineEulerMaruyama((f_sde, g_sde), (a, 0.15),
                                  init,
                                  dist,
                                  dt=dt,
                                  num_steps=10)

        # ===== Initialize ===== #
        x = sde.i_sample(shape)

        # ===== Propagate ===== #
        num = 1000
        samps = [x]
        for t in range(num):
            samps.append(sde.propagate(samps[-1]))

        samps = torch.stack(samps)
        self.assertEqual(samps.size(), torch.Size([num + 1, *shape]))

        # ===== Sample path ===== #
        path = sde.sample_path(num + 1, shape)
        self.assertEqual(samps.shape, path.shape)
Example #11
 def _sampler(self, samples=1000):
     d_ = torch.ones(samples)
     # NOTE: `d` (the scenario index) is not defined in this method; it is
     # presumably a module-level variable or closure in the source project.
     if d == 1:
         # If SZ is adopted, then some Districts and Schools buy in
         dist = Poisson(self.n_districts)\
                 .sample([samples])\
                 .reshape([samples])
         schools = NegativeBinomial(tensor([3.]),
                                    tensor([0.8]))\
                     .sample([samples, self.n_districts.int()])\
                     .sum(dim=1)\
                     .reshape([samples])
         sz = 15000. * dist + 2430 * schools
     else:
         dist, schools, sz = torch.zeros(samples),\
                             torch.zeros(samples),\
                             torch.zeros(samples)
     if d < 2:
         sf = LogNormal(
                 *self._lognormal_params(300000., 10000.))\
                     .sample([samples])
     else:
         sf = torch.zeros(samples)
     # System & Infrastructure
     az = LogNormal(self.az_means[d], self.az_sds[d]).sample([samples])
     salary_estimate = Normal(70000., 5000.).sample([samples])
     fa = Beta(self.fa_ms[d], self.fa_ks[d]).sample([samples])
     dt = Beta(self.dt_ms[d], self.dt_ks[d]).sample([samples])
     return d_, dist, schools, sz, az, sf, fa, dt
Example #12
    def get_reconstruction_loss(self,
                                x,
                                px_rate,
                                px_r,
                                px_dropout,
                                mode="scRNA",
                                weighting=1):

        # Reconstruction Loss
        if mode == "scRNA":
            if self.reconstruction_loss == 'zinb':
                reconst_loss = -log_zinb_positive(x, px_rate, torch.exp(px_r),
                                                  px_dropout)
            elif self.reconstruction_loss == 'nb':
                reconst_loss = -log_nb_positive(x, px_rate, torch.exp(px_r))

        else:
            if self.reconstruction_loss_fish == 'poisson':
                reconst_loss = -torch.sum(Poisson(px_rate).log_prob(
                    x[:, self.indexes_to_keep]),
                                          dim=1)
            elif self.reconstruction_loss_fish == 'gaussian':
                reconst_loss = -torch.sum(Normal(px_rate, 10).log_prob(
                    x[:, self.indexes_to_keep]),
                                          dim=1)
        return reconst_loss
Example #13
    def __call__(self, tensor):
        # Calculate photons per pixel
        photons_per_pixel = np.random.negative_binomial(
            self.mean / self.dispersion,
            1 / self.mean) / self.mean * self.dispersion

        # Sample poisson
        poisson_params = torch.mul(torch.div(tensor, 255.0), photons_per_pixel)
        poisson_dist = Poisson(poisson_params)
        poisson_samples = poisson_dist.sample()
        result = torch.mul(torch.div(poisson_samples, photons_per_pixel),
                           255.0)

        # Clamp
        result = torch.clamp(result, 0, 255)

        return result.type(torch.uint8)
Example #14
    def generate(
        self,
        n_samples: int = 100,
        batch_size: int = 64
    ):  # with n_samples>1 return original list/ otherwise sequential
        """
        Return samples from posterior predictive. Proteins are concatenated to genes.

        :param n_samples: Number of posterior predictive samples
        :return: Tuple of posterior samples, original data
        """
        original_list = []
        posterior_list = []
        for tensors in self.update({"batch_size": batch_size}):
            x, _, _, batch_index, labels, y = tensors
            with torch.no_grad():
                outputs = self.model.inference(x,
                                               y,
                                               batch_index=batch_index,
                                               label=labels,
                                               n_samples=n_samples)
            px_ = outputs["px_"]
            py_ = outputs["py_"]

            pi = 1 / (1 + torch.exp(-py_["mixing"]))
            mixing_sample = Bernoulli(pi).sample()
            protein_rate = (py_["rate_fore"] * (1 - mixing_sample) +
                            py_["rate_back"] * mixing_sample)
            rate = torch.cat((px_["rate"], protein_rate), dim=-1)
            if len(px_["r"].size()) == 2:
                px_dispersion = px_["r"]
            else:
                px_dispersion = torch.ones_like(x) * px_["r"]
            if len(py_["r"].size()) == 2:
                py_dispersion = py_["r"]
            else:
                py_dispersion = torch.ones_like(y) * py_["r"]

            dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)

            # This gamma is really l*w using scVI manuscript notation
            p = rate / (rate + dispersion)
            r = dispersion
            l_train = Gamma(r, (1 - p) / p).sample()
            data = Poisson(l_train).sample().cpu().numpy()
            # In numpy, (shape, scale) => (concentration, rate) with
            # scale = p / (1 - p); torch parameterizes by rate = (1 - p) / p.
            original_list += [np.array(torch.cat((x, y), dim=-1).cpu())]
            posterior_list += [data]

            posterior_list[-1] = np.transpose(posterior_list[-1], (1, 2, 0))

        return (
            np.concatenate(posterior_list, axis=0),
            np.concatenate(original_list, axis=0),
        )
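
The parameterization comment in the example above can be verified directly: with torch's Gamma(concentration, rate), rate = (1 - p)/p (the reciprocal of numpy's scale = p/(1 - p)), and compounding with Poisson reproduces NegativeBinomial(total_count=r, probs=p). A small check with illustrative values:

import torch
from torch.distributions import Gamma, NegativeBinomial, Poisson

r, p = torch.tensor(4.0), torch.tensor(0.3)
lam = Gamma(r, (1 - p) / p).sample((200000,))
z = Poisson(lam).sample()
nb = NegativeBinomial(total_count=r, probs=p)
print(z.mean().item(), nb.mean.item())     # both ~ r * p / (1 - p)
print(z.var().item(), nb.variance.item())  # both ~ r * p / (1 - p)**2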
Example #15
    def distribution(self,
                     distr_args,
                     scale: Optional[torch.Tensor] = None) -> Distribution:
        (rate, ) = distr_args

        if scale is not None:
            rate *= scale

        return self.independent(Poisson(rate))
Example #16
 def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout, **kwargs):
     # Reconstruction Loss
     if self.reconstruction_loss == "zinb":
         reconst_loss = -log_zinb_positive(x, px_rate, px_r, px_dropout).sum(dim=-1)
     elif self.reconstruction_loss == "nb":
         reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1)
     elif self.reconstruction_loss == "poisson":
         reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return reconst_loss
Example #17
def main():
    file = h5py.File('/notebooks/data/static20892-3-14-preproc0.h5', "r")
    dat = StaticImageSet('/notebooks/data/static20892-3-14-preproc0.h5',
                         'images', 'responses')
    img_shape = dat.img_shape[2:]
    gabor_rf = gen_gabor_RF(img_shape, rf_shape, n=n_neurons, seed=random_seed)
    images = dat[()].images
    images = images.reshape(6000, 36, 64)
    firing_rates = compute_activity_simple(images / 255, gabor_rf)
    dist = Poisson(torch.tensor(firing_rates))
    responses = dist.sample().numpy()
    data_file = h5py.File("toy_dataset.hdf5", "w")

    im_set = data_file.create_dataset('images', data=dat[()].images)
    response_set = data_file.create_dataset('responses', data=responses)
    tier_set = data_file.create_dataset('tiers', data=file['tiers'])

    data_file.close()
Example #18
def add_poisson(
    tensor: Tensor,
    lam: Union[Number, Tuple[Number, Number]],
    inplace: bool = False,
    clip: bool = True,
) -> Tuple[Tensor, Union[Number, Tensor]]:
    """Adds Poisson noise to a batch of input images.

    Args:
        tensor (Tensor): Tensor to add noise to; this should be in a B*** format, e.g. BCHW.
        lam (Union[Number, Tuple[Number, Number]]): Distribution rate parameter (lambda) for
            the noise being added. If a tuple is provided, a lambda is drawn uniformly
            between the two values for each batched input (B***).
        inplace (bool, optional): Whether to add the noise in-place. Defaults to False.
        clip (bool, optional): Whether to clip between image bounds (0.0-1.0 or 0-255).
            Defaults to True.

    Returns:
        Tuple[Tensor, Union[Number, Tensor]]: Tuple containing:
            * Copy of or reference to input tensor with noise added.
            * Lambda used for noise generation. This will be an array of the different
            lambda used if a range of lambda are being used.
    """
    if not inplace:
        tensor = tensor.clone()

    if isinstance(lam, (list, tuple)):
        if len(lam) == 1:
            lam = lam[0]
        else:
            assert len(lam) == 2
            (min_lam, max_lam) = lam
            uniform_generator = Uniform(min_lam, max_lam)
            shape = [tensor.shape[0]] + [1] * (len(tensor.shape) - 1)
            lam = uniform_generator.sample(shape)
    tensor.mul_(lam)
    # Poisson noise is signal-dependent: resample each scaled intensity from a
    # Poisson whose rate is that intensity, then rescale back to image range.
    # (A tiny floor keeps the rate strictly positive for the constraint check.)
    poisson_generator = Poisson(tensor.clamp(min=1e-8))
    tensor.copy_(poisson_generator.sample())
    tensor.div_(lam)
    if clip:
        tensor = ssdn.utils.clip_img(tensor, inplace=True)

    return tensor, lam
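
Hypothetical usage of add_poisson on a random batch (clip=False avoids the ssdn.utils dependency; shapes and values are illustrative):

import torch

batch = torch.rand(8, 3, 32, 32)  # BCHW images in [0, 1]
noisy, lam = add_poisson(batch, (30.0, 50.0), clip=False)
print(noisy.shape)  # torch.Size([8, 3, 32, 32])
print(lam.shape)    # torch.Size([8, 1, 1, 1]); one lambda per batched input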
Example #19
    def sample(self, sample_shape=torch.Size()):
        gamma_d = self._gamma()
        p_means = gamma_d.sample(sample_shape)

        # Clamping as distributions objects can have buggy behaviors when
        # their parameters are too high
        l_train = torch.clamp(p_means, max=1e8)
        counts = Poisson(
            l_train).sample()  # Shape : (n_samples, n_cells_batch, n_genes)
        return counts
Example #20
    def k_new(self, X, Z, A, i, truncation):
        '''
        i: The loop calling this function is asking this function
        "how many new features (k_new) should data point i draw?"

        truncation: When computing the un-normalized posterior for k_new|X,Z,A, we cannot
        compute the posterior for the infinite amount of values k_new could take on. So instead
        we compute from 0 up to some high number, truncation, and then normalize. In practice,
        the posterior probability for k_new is so low that it underflows past truncation=20.
        '''

        log_likelihood = torch.zeros(truncation)
        log_poisson_probs = torch.zeros(truncation)
        N, K = Z.size()
        D = X.size()[1]
        p_k_new = Pois(torch.tensor([self.alpha / N]))
        cur_X_minus_ZA = X - Z @ A

        for j in range(truncation):

            # Compute the log likelihood of k_new equaling j
            log_likelihood[j] = self.log_likelihood_given_k_new(
                cur_X_minus_ZA, Z, D, i, j)

            # Compute the prior probability of k_new equaling j
            # (log_prob expects a tensor, not a Python int)
            log_poisson_probs[j] = p_k_new.log_prob(torch.tensor(float(j)))

            # Add a new column to Z for the next feature
            Z = torch.cat((Z, torch.zeros(N, 1)), 1)
            Z[i][-1] = 1

        # Compute log posterior of k_new and exp/normalize
        log_sample_probs = log_likelihood + log_poisson_probs
        sample_probs = self.renormalize_log_probs(log_sample_probs)

        # Important: we changed Z while calculating p(k_new|...),
        # so we must take off the extra columns (one added per loop iteration)
        Z = Z[:, :-truncation]
        assert Z.size()[1] == K
        posterior_k_new = Categorical(sample_probs)
        return posterior_k_new.sample()
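
renormalize_log_probs is not shown above; a minimal sketch of what it presumably does, i.e. exp-normalize with a logsumexp shift for numerical stability:

import torch

def renormalize_log_probs(log_probs):
    # Shift by logsumexp so the largest term exponentiates to 1, then normalize.
    return torch.exp(log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True))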
Example #21
 def get_reconstruction_loss(self, x, px_rate, px_r,
                             px_dropout) -> torch.Tensor:
     if self.gene_likelihood == "zinb":
         reconst_loss = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "nb":
         reconst_loss = (-NegativeBinomial(
             mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "poisson":
         reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return reconst_loss
Example #22
def get_poisson_pair_by_maxval(img, max_val, seed=1):
    '''
    Input image must be in (0,1] range.
    '''

    if max_val == 255:
        return img, img

    img = th.FloatTensor(img)
    noisy = Poisson(img * max_val).sample()

    return img * max_val, noisy
Example #23
def tile_map_prior(prior: ImagePrior, tile_map):
    # Source probabilities
    dist_sources = Poisson(torch.tensor(prior.mean_sources))
    log_prob_no_source = dist_sources.log_prob(torch.tensor(0))
    log_prob_one_source = dist_sources.log_prob(torch.tensor(1))
    log_prob_source = (tile_map["n_sources"] == 0) * log_prob_no_source + (
        tile_map["n_sources"] == 1) * log_prob_one_source

    # Binary probabilities
    galaxy_log_prob = torch.tensor(0.7).log()
    star_log_prob = torch.tensor(0.3).log()
    log_prob_binary = (galaxy_log_prob * tile_map["galaxy_bools"] +
                       star_log_prob * tile_map["star_bools"])

    # Galaxy probabilities
    gal_dist = Normal(0.0, 1.0)
    galaxy_probs = gal_dist.log_prob(
        tile_map["galaxy_params"]) * tile_map["galaxy_bools"]

    return log_prob_source.sum() + log_prob_binary.sum() + galaxy_probs.sum()
Example #24
def get_poisson_pair(img, k, seed=1):
    '''
    Input image must be in (0,1] range.
    '''

    if k == 0:
        return img, img

    img = th.FloatTensor(img)
    noisy = Poisson(img / k).sample()

    return img / k, noisy
Example #25
 def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout):
     rl = 0.0
     if self.gene_likelihood == "zinb":
         rl = (-ZeroInflatedNegativeBinomial(
             mu=px_rate, theta=px_r,
             zi_logits=px_dropout).log_prob(x).sum(dim=-1))
     elif self.gene_likelihood == "nb":
         rl = -NegativeBinomial(mu=px_rate,
                                theta=px_r).log_prob(x).sum(dim=-1)
     elif self.gene_likelihood == "poisson":
         rl = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return rl
Example #26
def test_poisson(rate: float) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters
    """
    # generate samples
    rates = torch.zeros((NUM_SAMPLES, )) + rate

    poisson_distr = Poisson(rate=rates)
    samples = poisson_distr.sample()

    init_biases = [inv_softplus(rate - START_TOL_MULTIPLE * TOL * rate)]

    (rate_hat, ) = maximum_likelihood_estimate_sgd(
        PoissonOutput(),
        samples,
        init_biases=init_biases,
        num_epochs=20,
        learning_rate=0.05,
    )

    assert (np.abs(rate_hat - rate) < TOL *
            rate), f"rate did not match: rate = {rate}, rate_hat = {rate_hat}"
Example #27
File: vae.py Project: zhuy16/scVI
 def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout,
                             **kwargs) -> torch.Tensor:
     """Return the reconstruction loss (for a minibatch)
     """
     # Reconstruction Loss
     if self.reconstruction_loss == "zinb":
         reconst_loss = -log_zinb_positive(x, px_rate, px_r,
                                           px_dropout).sum(dim=-1)
     elif self.reconstruction_loss == "nb":
         reconst_loss = -log_nb_positive(x, px_rate, px_r).sum(dim=-1)
     elif self.reconstruction_loss == "poisson":
         reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1)
     return reconst_loss
Example #28
    def generate_joint(self,
                       x,
                       local_l_mean,
                       local_l_var,
                       batch_index,
                       y=None,
                       zero_inflated=True):
        """
        :param x: used only for shape match
        """
        n_batches, _ = x.shape
        device = "cuda" if torch.cuda.is_available() else "cpu"
        z_mean = torch.zeros(n_batches, self.n_latent, device=device)
        # Standard-normal prior over z; a zero scale would make the prior
        # degenerate (and fail Normal's argument validation).
        z_std = torch.ones(n_batches, self.n_latent, device=device)
        z_prior_dist = Normal(z_mean, z_std)
        z_sim = z_prior_dist.sample()

        l_prior_dist = Normal(local_l_mean, torch.sqrt(local_l_var))
        l_sim = l_prior_dist.sample()

        # Decoder pass
        px_scale, px_r, px_rate, px_dropout = self.decoder(
            self.dispersion, z_sim, l_sim, batch_index, y)

        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels),
                self.px_r)  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        # Data generation
        p = px_rate / (px_rate + px_r)
        r = px_r
        # Important remark: Gamma is parametrized by the rate = 1/scale!
        l_train = Gamma(concentration=r, rate=(1 - p) / p).sample()

        # Clamping as distributions objects can have buggy behaviors when
        # their parameters are too high
        l_train = torch.clamp(l_train, max=1e8)
        gene_expressions = Poisson(
            l_train).sample()  # Shape : (n_samples, n_cells_batch, n_genes)
        if zero_inflated:
            p_zero = (1.0 + torch.exp(-px_dropout)).pow(-1)
            random_prob = torch.rand_like(p_zero)
            gene_expressions[random_prob <= p_zero] = 0

        return gene_expressions, z_sim, l_sim
Example #29
 def init_Z(self, N=20):
     '''
     Samples from the Indian Buffet Process:
     The first customer (i=1) takes the first Poisson(alpha/1) dishes.
     Each subsequent customer i>1 takes each previously sampled dish k
     independently with probability m_k/i, where m_k is the number of
     customers who have already sampled dish k. Z_ik=1 if the ith customer
     sampled the kth dish and 0 otherwise.
     '''
     Z = torch.zeros(N, self.K)
     K = int(self.K.item())
     total_dishes_sampled = 0
     for i in range(N):
         selected = torch.rand(total_dishes_sampled) < \
             Z[:,:total_dishes_sampled].sum(dim=0) / (i+1.)
         Z[i][:total_dishes_sampled][selected] = 1.0
         p_new_dishes = Pois(torch.tensor([self.alpha / (i + 1)]))
         new_dishes = int(p_new_dishes.sample().item())
         if total_dishes_sampled + new_dishes >= K:
             new_dishes = K - total_dishes_sampled
         Z[i][total_dishes_sampled:total_dishes_sampled + new_dishes] = 1.0
         total_dishes_sampled += new_dishes
     return self.left_order_form(Z)
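
A property worth checking against init_Z: under the Indian Buffet Process the expected total number of dishes after N customers is alpha * H_N, with H_N the N-th harmonic number. A standalone sketch of just the new-dish draws (without the truncation at self.K):

import torch
from torch.distributions import Poisson

alpha, N, trials = 2.0, 20, 2000
totals = torch.zeros(trials)
for t in range(trials):
    for i in range(N):
        totals[t] += Poisson(torch.tensor(alpha / (i + 1))).sample()
harmonic = sum(1.0 / (i + 1) for i in range(N))
print(totals.mean().item(), alpha * harmonic)  # the two should be close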
Example #30
File: data.py Project: weiyumou/PLL
class TextTrainDataset(TextDataset):
    def __init__(self, text_data, ngram, TEXT, rate) -> None:
        super().__init__(text_data, ngram, TEXT)
        self.pdist = Poisson(rate=rate)
        self.ngram = ngram
        self.ds = calc_dn(ngram)

    def __getitem__(self, index: int):
        k = self.pdist.sample().int().item()
        perm = torch.tensor(k_permute(self.ngram, k, self.ds),
                            dtype=torch.long)
        return self.data[index, perm], perm

    def set_poisson_rate(self, rate):
        self.pdist = Poisson(rate)
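
The Poisson rate controls how many permutation steps k each item receives; a small illustration (independent of the dataset class) of how set_poisson_rate shifts that distribution:

import torch
from torch.distributions import Poisson

for rate in (1.0, 3.0, 5.0):
    k = Poisson(rate=torch.tensor(rate)).sample((10000,)).int()
    print(rate, k.float().mean().item(), k.max().item())  # mean k grows with rate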