Example No. 1
    def guide(self, x, temp=1, anneal_id=1.0, anneal_t=1.0):
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        pyro.module('VDSM_EncDec', self)
        num_individuals, num_timepoints, pixels = x.view(
            x.shape[0], x.shape[1], self.imsize**2 * self.nc).shape

        id_plate = pyro.plate("individuals", num_individuals, dim=-2)
        time_plate = pyro.plate("time", num_timepoints, dim=-1)

        # pass all sequences in and generate mean, sigma and ID
        x = x.view(num_individuals * num_timepoints, self.nc, self.imsize,
                   self.imsize)
        z_loc, z_scale, ID_loc, ID_scale = self.enc(x)
        ID_loc, ID_scale = self.id_layers(ID_loc,
                                          ID_scale)  # extra trainable layer
        ID_loc = torch.mean(ID_loc.view(num_individuals, num_timepoints, -1),
                            1).unsqueeze(1)
        ID_scale = torch.mean(
            ID_scale.view(num_individuals, num_timepoints, -1), 1).unsqueeze(1)
        z_loc = z_loc.view(num_individuals, num_timepoints, -1)
        z_scale = z_scale.view(num_individuals, num_timepoints, -1)

        # within the individuals plate:
        with id_plate:
            IDdist = dist.Normal(ID_loc, ID_scale).to_event(1)
            with poutine.scale(scale=anneal_id):
                ID = pyro.sample('ID', IDdist) * temp

            # within the individuals and timepoint plates:
            with time_plate:
                zdist = dist.Normal(z_loc, z_scale).to_event(1)
                with poutine.scale(scale=anneal_t):
                    z = pyro.sample('z', zdist)
        return z_loc, z_scale
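A note on the anneal_id / anneal_t arguments above: they are KL-annealing weights applied through poutine.scale. As a rough sketch (ours, not from the original repository; warmup_epochs and min_factor are hypothetical hyperparameters), such factors are usually ramped from a small value up to 1.0 by the training loop:

def kl_annealing_factor(epoch, warmup_epochs=50, min_factor=0.1):
    # linear ramp from min_factor to 1.0 over the first warmup_epochs epochs
    if epoch >= warmup_epochs:
        return 1.0
    return min_factor + (1.0 - min_factor) * epoch / warmup_epochs

# e.g. svi.step(x, temp=1, anneal_id=kl_annealing_factor(epoch),
#               anneal_t=kl_annealing_factor(epoch))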
Example No. 2
    def model(self,
              src,
              trg,
              src_mask,
              trg_mask,
              src_lengths,
              trg_lengths,
              y_trg,
              kl=1.0):
        pyro.module('VNMT', self)
        encoder_hidden, encoder_final = self.encoder_hidden_x, self.encoder_final
        X = self.X_avg

        if self.posterior is not None:
            #regular VNMT
            z_mean, z_sig = self.prior(X)
        else:
            # match our own posterior parameters, which should just mean KL(...) = 0 every time
            mu_post, sig_post = self.get_batch_params(ret_posterior=True)
            z_mean, z_sig = mu_post, sig_post

        self.prior_params = {'mu': z_mean, 'sig': z_sig}
        with pyro.plate('data'):
            #TODO FYI: technically, the correct scaling is 1./ size_of_data
            with poutine.scale(scale=self.get_model_kl_const(scale=kl)):
                #TODO probably...a good idea to test this with flows also on prior...you know, so it's correct?
                use_flows = True
                dist = self.getDistribution(z_mean,
                                            z_sig,
                                            use_cached_flows=True,
                                            extra_cond=use_flows,
                                            cond_input=None)
                z = pyro.sample('z', dist)
            #TODO, need to add the latent z as input to decoder

            z = z if self.projection is None else self.projection(z)

            inputs = self.getWordEmbeddingsWithWordDropout(
                self.trg_embed, trg, trg_mask)
            # key thing is HERE: I am directly calling our decoder
            _, _, pre_output = self.decoder(inputs,
                                            encoder_hidden,
                                            encoder_final,
                                            src_mask,
                                            trg_mask,
                                            additional_input=z)
            logits = self.generator(pre_output)
            obs = y_trg.contiguous().view(-1)
            mask = trg_mask.contiguous().view(-1)
            try:
                mask = mask.bool()
            except AttributeError:
                # older PyTorch versions lack .bool(); the existing mask works as-is
                pass
            #My assumption is this will usually just sum the loss so we need to average it ourselves
            with poutine.scale(scale=self.get_reconstruction_const(scale=kl)):
                pyro.sample('preds',
                            Categorical(logits=logits.contiguous().view(
                                -1, logits.size(-1))).mask(mask),
                            obs=obs)
Example No. 3
    def guide(self, response, mask, annealing_factor=1):
        pyro.module("item_encoder", self.item_encoder)
        pyro.module("ability_encoder", self.ability_encoder)
        device = response.device

        item_domain = torch.arange(self.num_item).unsqueeze(1).to(device)
        item_feat_mu, item_feat_logvar = self.item_encoder(item_domain)
        item_feat_scale = torch.exp(0.5 * item_feat_logvar)

        with poutine.scale(scale=annealing_factor):
            item_feat = pyro.sample(
                "item_feat",
                dist.Normal(item_feat_mu, item_feat_scale),
            )

        if self.conditional_posterior:
            ability_mu, ability_logvar = self.ability_encoder(
                response, mask, item_feat)
        else:
            ability_mu, ability_logvar = self.ability_encoder(response, mask)

        ability_scale = torch.exp(0.5 * ability_logvar)
        ability_dist = dist.Normal(ability_mu, ability_scale)

        if self.num_iafs > 0:
            ability_dist = TransformedDistribution(ability_dist, self.iafs)

        with poutine.scale(scale=annealing_factor):
            ability = pyro.sample("ability", ability_dist)

        return ability_mu, ability_logvar, item_feat_mu, item_feat_logvar
Example No. 4
    def forward(
        self,
        x: torch.Tensor,
        _library: torch.Tensor,
        n_obs: Optional[int] = None,
        kl_weight: float = 1.0,
    ):
        # Topic feature distributions.
        with pyro.plate("topics",
                        self.n_topics), poutine.scale(None, kl_weight):
            pyro.sample(
                "log_topic_feature_dist",
                dist.Normal(
                    self.topic_feature_posterior_mu,
                    self.topic_feature_posterior_sigma,
                ).to_event(1),
            )

        # Cell topic distributions guide.
        with pyro.plate("cells",
                        size=n_obs or self.n_obs,
                        subsample_size=x.shape[0]), poutine.scale(
                            None, kl_weight):
            cell_topic_posterior_mu, cell_topic_posterior_sigma, _ = self.encoder(
                x)
            pyro.sample(
                "log_cell_topic_dist",
                dist.Normal(
                    cell_topic_posterior_mu,
                    F.softplus(cell_topic_posterior_sigma)).to_event(1),
            )
Example No. 5
    def model_krein(data, node_ind, edge_ind, edge_list):
        r"""Defines a probabilistic model for the observed network data."""
        # Define priors on the regression coefficients
        mu = pyro.sample(
            'mu',
            dist.Normal(torch.tensor([0.0]), torch.tensor([2.0])).to_event(1))

        beta = pyro.sample(
            'beta',
            dist.Normal(loc=torch.zeros(embed_dim),
                        scale=torch.tensor(2.0)).to_event(1))

        # Define prior on the embedding vectors, with subsampling
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=torch.zeros(embed_dim, len(node_ind)),
                            scale=omega_model_scale).to_event(2))

        # Before proceeding further, define a list t which acts as the
        # inverse function of node_ind - i.e it takes a number in node_ind
        # to its index location
        t = torch.zeros(node_ind.max() + 1, dtype=torch.long)
        t[node_ind] = torch.arange(len(node_ind))

        # Create mask corresponding to entries of ind which lie within the
        # training set (i.e data.train_nodes)
        gt_data = data.gt[node_ind]
        obs_mask = np.isin(node_ind, data.nodes_train).tolist()
        gt_data[gt_data != gt_data] = 0.0
        obs_mask = torch.tensor(obs_mask, dtype=torch.bool)

        # Compute logits, compute relevant parts of sample
        if sum(obs_mask) != 0:
            logit_prob = mu + torch.mv(omega.t(), beta)
            with poutine.scale(scale=len(data.nodes_train) / sum(obs_mask)):
                pyro.sample(
                    'trust',
                    dist.Bernoulli(logits=logit_prob[obs_mask]).independent(1),
                    obs=gt_data[obs_mask])

        # Begin extracting the relevant components of the gram matrix
        # formed by omega. Note that to extract the relevant indices,
        # we need to account for the change in indexing induced by
        # subsampling omega
        gram_pos = torch.mm(omega[:int(embed_dim / 2), :].t(),
                            omega[:int(embed_dim / 2), :])
        gram_neg = torch.mm(omega[int(embed_dim / 2):, :].t(),
                            omega[int(embed_dim / 2):, :])
        gram = gram_pos - gram_neg
        gram_sample = gram[t[edge_list[0, :]], t[edge_list[1, :]]]

        # Finally draw terms corresponding to the edges
        with poutine.scale(scale=data.num_edges / len(edge_ind)):
            pyro.sample('a',
                        dist.Normal(loc=gram_sample,
                                    scale=obs_scale).to_event(1),
                        obs=data.edge_weight_logit[edge_ind])
Example No. 6
    def model(self,
              x,
              temp_id=None,
              anneal_id=None,
              anneal_t=None,
              anneal_dynamics=None):
        pyro.module('vdsm_seq', self)
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        bs, seq_len, pixels = x.view(x.shape[0], x.shape[1],
                                     self.imsize**2 * self.nc).shape

        id_prior = x.new_zeros([bs, self.n_e_w])
        dynamics_prior = x.new_zeros([bs, self.dynamics_dim])

        # sample dynamics and identity from prior
        with pyro.plate('ID_plate', bs):
            IDdist = dist.Normal(id_prior, 1 / self.n_e_w).to_event(1)
            dz_dist = dist.Normal(dynamics_prior,
                                  1.0 / self.dynamics_dim).to_event(1)
            with poutine.scale(scale=anneal_id):
                ID = pyro.sample("ID", IDdist).to(
                    x.device) * temp_id  # static factors
            with poutine.scale(scale=anneal_dynamics):
                dz = pyro.sample("dz", dz_dist)  # dynamics factors
            ID_exp = torch.exp(ID)
            ID = ID_exp / ID_exp.sum(-1).unsqueeze(-1)

        zs = torch.zeros(bs, seq_len, self.input_dim)

        for i in pyro.plate('batch_loop', bs):
            z_prev = pyro.sample(
                'z_{}_0'.format(i),
                dist.Normal(torch.zeros(self.input_dim), 1).to_event(1))
            zs[i, 0] = z_prev
            for t in pyro.markov(range(1, seq_len)):
                z_loc, z_scale = self.transitions(
                    z_prev, dz[None, i].expand(1, self.dynamics_dim))
                z_dist = dist.Normal(z_loc, z_scale).to_event(1)
                with poutine.scale(scale=anneal_t):
                    z = pyro.sample('z_{}_{}'.format(i, t), z_dist)
                zs[i, t] = z
                z_prev = z

        x = torch.flatten(x, 2)
        recon = torch.zeros(bs,
                            seq_len,
                            self.imsize**2 * self.nc,
                            device=x.device)

        for ind in range(bs):
            recon[ind] = self.image_dec(zs[ind], ID[ind].unsqueeze(1))

        with pyro.plate('timepoints_ims', seq_len):
            with pyro.plate('inds_ims', bs):
                image_d = dist.Bernoulli(recon).to_event(1)
                with poutine.scale(scale=1.0):
                    pyro.sample('images', image_d, obs=x)
Example No. 7
def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local_samples, scale):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model(subsample):
        with pyro.plate("data", len(data), subsample_size, subsample) as ind:
            x = data[ind]
            z = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(z, 1), obs=x)

    def guide(subsample):
        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data), subsample_size, subsample):
            loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
            z_dist = Normal(loc, scale)
            if has_rsample is not None:
                z_dist.has_rsample_(has_rsample)
            pyro.sample("z", z_dist)

    if scale != 1.0:
        model = poutine.scale(model, scale=scale)
        guide = poutine.scale(guide, scale=scale)

    num_particles = 50000
    if local_samples:
        guide = config_enumerate(guide, num_samples=num_particles)
        num_particles = 1

    optim = Adam({"lr": 0.1})
    elbo = Elbo(max_plate_nesting=1,  # set this to ensure rng agrees across runs
                num_particles=num_particles,
                vectorize_particles=True,
                strict_enumeration_warning=False)
    inference = SVI(model, guide, optim, loss=elbo)
    with xfail_if_not_implemented():
        if subsample_size == 1:
            inference.loss_and_grads(model, guide, subsample=torch.tensor([0], dtype=torch.long))
            inference.loss_and_grads(model, guide, subsample=torch.tensor([1], dtype=torch.long))
        else:
            inference.loss_and_grads(model, guide, subsample=torch.tensor([0, 1], dtype=torch.long))
    params = dict(pyro.get_param_store().named_parameters())
    normalizer = 2 if subsample else 1
    actual_grads = {name: param.grad.detach().cpu().numpy() / normalizer for name, param in params.items()}

    expected_grads = {'loc': scale * np.array([0.5, -2.0]), 'scale': scale * np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual   {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
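As the test above shows (model = poutine.scale(model, scale=scale)), poutine.scale can either wrap a whole model/guide callable or scale individual sites as a context manager, and nested scales compose multiplicatively. A small self-contained sketch of ours illustrating both forms:

import pyro
import pyro.distributions as dist
from pyro import poutine

def toy_model():
    # context-manager form: scales only the sites inside the block
    with poutine.scale(scale=2.0):
        pyro.sample("z", dist.Normal(0., 1.))

# handler form: scales every site in the wrapped callable
scaled_model = poutine.scale(toy_model, scale=2.0)

tr = poutine.trace(scaled_model).get_trace()
print(tr.nodes["z"]["scale"])  # should print 4.0: the two scales multiply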
Example No. 8
    def model(self,
              diurnality,
              viirs_observed,
              land_cover,
              latitude,
              longitude,
              meteorology,
              annealing_factor=1.0):
        # land_cover.shape: [128, 17, 30, 30]
        pyro.module("vae", self)
        batch_size = diurnality.shape[0]
        T_max = viirs_observed.size(1)
        z_prev = self.z_0.expand(batch_size, self.z_0.size(0))
        alpha0 = torch.tensor(10.0, device=diurnality.device)
        beta0 = torch.tensor(10.0, device=diurnality.device)
        diurnal_ratio = pyro.sample("diurnal_ratio", dist.Beta(alpha0, beta0))

        with pyro.plate("data", batch_size):
            diurnal_ = pyro.sample("diurnal_",
                                   dist.Bernoulli(diurnal_ratio),
                                   obs=diurnality).long()
            for t in pyro.markov(range(1, T_max + 1)):
                z_loc, z_scale = self.trans(z_prev, land_cover)
                with poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample(
                        f"z_{t}",
                        dist.Normal(z_loc[diurnal_, :],
                                    z_scale[diurnal_, :]).to_event(1))
                image_p = self.emitter(z_t)
                image = pyro.sample(f"image_{t}",
                                    dist.Bernoulli(image_p).to_event(1),
                                    obs=viirs_observed[:, t - 1, :, :].reshape(
                                        -1, self.image_flatten_dim))
                z_prev = z_t
            return image_p, image
Example No. 9
    def guide(self, imgs=None):
        """ 1. run the inference 
            2. sample latent variables 
        """
        #-----------------------#
        #--------  Trick -------#
        #-----------------------#
        if (imgs is None):
            observed = False
            imgs = torch.zeros(8, self.ch, self.height, self.width)
            if (self.use_cuda):
                imgs = imgs.cuda()
        else:
            observed = True
        #-----------------------#
        #----- End of Trick ----#
        #-----------------------#

        pyro.module("encoder", self.encoder)

        batch_size, ch, height, width = imgs.shape

        with pyro.plate('batch_size', batch_size, dim=-1):
            z = self.encoder(imgs)
            with poutine.scale(scale=self.scale):
                pyro.sample('z_latent',
                            dist.Normal(z.z_mu, z.z_std).to_event(1))
Example No. 10
    def model(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        log_data_split: torch.Tensor,
        log_data_split_complement: torch.Tensor,
    ):
        # register modules with Pyro
        pyro.module("mcv_nbvae", self)

        with pyro.plate("data", len(x0)), poutine.scale(scale=self.scale_factor):
            z = pyro.sample(
                "latent", dist.Normal(0, x0.new_ones(self.n_latent)).to_event(1)
            )

            lib = pyro.sample(
                "library", dist.Normal(self.lib_loc, self.lib_scale).to_event(1)
            )

            log_r, logit = self.decoder(z, lib)

            # adjust for data split
            log_r += log_data_split_complement - log_data_split

            pyro.sample(
                "obs",
                dist.NegativeBinomial(
                    total_count=torch.exp(log_r) + self.epsilon, logits=logit
                ).to_event(1),
                obs=x1,
            )
Example No. 11
    def model(self, x, y=None):
        # Register various nn.Modules with Pyro
        pyro.module("scanvi", self)

        # This gene-level parameter modulates the variance of the observation distribution
        theta = pyro.param("inverse_dispersion", 10.0 * x.new_ones(self.num_genes),
                           constraint=constraints.positive)

        # We scale all sample statements by scale_factor so that the ELBO is normalized
        # wrt the number of datapoints and genes
        with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
            z1 = pyro.sample("z1", dist.Normal(0, x.new_ones(self.latent_dim)).to_event(1))
            # Note that if y is None (i.e. y is unobserved) then y will be sampled;
            # otherwise y will be treated as observed.
            y = pyro.sample("y", dist.OneHotCategorical(logits=x.new_zeros(self.num_labels)),
                            obs=y)

            z2_loc, z2_scale = self.z2_decoder(z1, y)
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            l_scale = self.l_scale * x.new_ones(1)
            l = pyro.sample("l", dist.LogNormal(self.l_loc, l_scale).to_event(1))

            # Note that by construction mu is normalized (i.e. mu.sum(-1) == 1) and the
            # total scale of counts for each cell is determined by `l`
            gate_logits, mu = self.x_decoder(z2)
            # TODO revisit this parameterization if torch.distributions.NegativeBinomial changes
            # from failure to success parametrization;
            # see https://github.com/pytorch/pytorch/issues/42449
            nb_logits = (l * mu + self.epsilon).log() - (theta + self.epsilon).log()
            x_dist = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits, total_count=theta,
                                                       logits=nb_logits)
            # Observe the datapoint x using the observation distribution x_dist
            pyro.sample("x", x_dist.to_event(1), obs=x)
Example No. 12
def speaker1(state, threshold, prior):
    s1Optimality = 5.
    utterance = utterance_prior()
    L0 = listener0(utterance, threshold, prior)
    with poutine.scale(scale=torch.tensor(s1Optimality)):
        pyro.sample("L0_score", L0, obs=state)
    return utterance
Example No. 13
    def languageModelOptimization(self, z, z_hid, src, src_lengths,
                                  src_input_mask, kl):
        src = src.clone()  #pretty sure that's a bug anyways...
        #need to redo src side as batch doesn't handle it
        #TODO...probably should be handled in rebatch
        src_indx = src[:, :-1]
        src_trgs = src[:, 1:]
        self.src_tok_count = (src_trgs != self.pad_index).data.sum().item()
        src_output_mask = (src_trgs != self.pad_index
                           )  #similar to what is done in Batch class for trg
        z_x = self.resize_z(z_hid, self.num_layers)

        inputs = self.getWordEmbeddingsWithWordDropout(self.src_embed,
                                                       src_indx,
                                                       src_output_mask)
        _, _, pre_output = self.lang_model(inputs,
                                           src_input_mask,
                                           src_output_mask,
                                           hidden=z_x,
                                           z=z)
        logits = self.lm_generator(pre_output)
        logits = logits.contiguous().view(-1, logits.size(-1))
        obs = src_trgs.contiguous().view(-1)
        mask = src_output_mask.contiguous().view(-1)
        try:
            mask = mask.bool()
        except AttributeError:
            # versioning issue on older PyTorch (.bool() unavailable); just suppresses a warning
            pass

        with poutine.scale(scale=self.get_reconstruction_const(scale=kl)):
            pyro.sample('lm_preds',
                        Categorical(logits=logits).mask(mask),
                        obs=obs)
Example No. 14
    def _model(
        self,
        z0: Tensor,
        batch_dim: int,
        time_steps: int,
        x: Optional[Tensor] = None,
        seq_mask: Optional[Tensor] = None,
        annealing: float = 1.0,
    ) -> None:
        pyro.module("dmm", self)
        seq_mask = (seq_mask if seq_mask is not None else torch.ones(
            z0.size(0), time_steps))

        z = z0.expand(batch_dim, z0.size(-1))
        with pyro.plate("data", batch_dim):
            for t in pyro.markov(range(time_steps)):
                m = seq_mask[:, t:t + 1]
                z_loc, z_scale = self.transition(z)
                with poutine.scale(None, annealing):
                    z = pyro.sample(f"z_{t+1}",
                                    Normal(z_loc, z_scale).mask(m).to_event(1))

                x_loc, x_scale = self.emit(z)
                pyro.sample(
                    f"x_{t+1}",
                    Normal(x_loc, x_scale).mask(m).to_event(1),
                    obs=x[:, t, :] if x is not None else None,
                )
Example No. 15
    def guide(self, x, y=None):
        pyro.module("scanvi", self)
        with pyro.plate("batch",
                        len(x)), poutine.scale(scale=self.scale_factor):
            z2_loc, z2_scale, l_loc, l_scale = self.z2l_encoder(x)
            pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            y_logits = self.classifier(z2)
            y_dist = dist.OneHotCategorical(logits=y_logits)
            if y is None:
                # x is unlabeled so sample y using q(y|z2)
                y = pyro.sample("y", y_dist)
            else:
                # x is labeled so add a classification loss term
                # (this way q(y|z2) learns from both labeled and unlabeled data)
                classification_loss = y_dist.log_prob(y)
                # Note that the negative sign appears because we're adding this term in the guide
                # and the guide log_prob appears in the ELBO as -log q
                pyro.factor(
                    "classification_loss",
                    -self.alpha * classification_loss,
                    has_rsample=False,
                )

            z1_loc, z1_scale = self.z1_encoder(z2, y)
            pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
Example No. 16
def speaker(qudValue, qud):
    alpha = 1.
    utterance = utterance_prior()
    literal_marginal = literal_listener(utterance, qud)
    with poutine.scale(scale=torch.tensor(alpha)):
        pyro.sample("listener", literal_marginal, obs=qudValue)
    return utterance
Example No. 17
    def translationModelOptimization(self, z, z_hid, src, src_mask,
                                     src_lengths, trg, trg_mask, trg_lengths,
                                     y_trg, kl):
        #self.num_layers*2 because encoder is bidirectional
        z_hid = self.resize_z(z_hid, self.num_layers * 2)

        encoder_hidden, encoder_final = self.encoder(self.src_embed(src),
                                                     src_mask,
                                                     src_lengths,
                                                     hidden=z_hid)
        inputs = self.getWordEmbeddingsWithWordDropout(self.trg_embed, trg,
                                                       trg_mask)
        # key thing is HERE: I am directly calling our decoder
        _, _, pre_output = self.decoder(inputs,
                                        encoder_hidden,
                                        encoder_final,
                                        src_mask,
                                        trg_mask,
                                        additional_input=z)
        logits = self.generator(pre_output)
        logits = logits.contiguous().view(-1, logits.size(-1))
        obs = y_trg.contiguous().view(-1)
        mask = trg_mask.contiguous().view(-1)
        try:
            mask = mask.bool()
        except AttributeError:
            # nothing to do; it just means an older PyTorch version without .bool()
            pass
        #My assumption is this will usually just sum the loss so we need to average it ourselves
        with poutine.scale(scale=self.get_reconstruction_const(scale=kl)):
            pyro.sample('preds',
                        Categorical(logits=logits).mask(mask),
                        obs=obs)
Example No. 18
    def model(self,
              src,
              trg,
              src_mask,
              trg_mask,
              src_lengths,
              trg_lengths,
              y_trg,
              kl=1.0):
        #TODO again, maaaaaaybe a good idea to specify which parts I need to update...
        pyro.module("GNMT", self)

        with pyro.plate('data'):

            with poutine.scale(scale=self.get_model_kl_const(scale=kl)):
                use_flows = False
                dist = self.getDistribution(torch.zeros_like(self.mu_x),
                                            torch.ones_like(self.sig_x),
                                            cond_input=None,
                                            extra_cond=use_flows)
                z = pyro.sample('z', dist)

            z_hid = self.project(z)

            #Calculations for translation
            if self.train_mt:
                self.translationModelOptimization(z, z_hid, src, src_mask,
                                                  src_lengths, trg, trg_mask,
                                                  trg_lengths, y_trg, kl)

            #Calculations for language modeling
            #mask not passed in because it's handled differently for the lang modeling
            if self.train_lm:
                self.languageModelOptimization(z, z_hid, src, src_lengths,
                                               src_mask, kl)
Example No. 19
    def model(self,
              src,
              trg,
              src_mask,
              trg_mask,
              src_lengths,
              trg_lengths,
              y_trg,
              kl=1.0):
        pyro.module('VanillaNMT', self)
        self.encoder_hidden_x, self.encoder_final = self.encoder(
            self.src_embed(src), src_mask, src_lengths)
        encoder_hidden, encoder_final = self.encoder_hidden_x, self.encoder_final

        with pyro.plate('data'):
            # for consistency, although word dropout supposedly makes less sense without latent variables
            inputs = self.getWordEmbeddingsWithWordDropout(
                self.trg_embed, trg, trg_mask)
            # key thing is HERE: I am directly calling our decoder
            _, _, pre_output = self.decoder(inputs, encoder_hidden,
                                            encoder_final, src_mask, trg_mask)
            logits = self.generator(pre_output)
            obs = y_trg.contiguous().view(-1)
            mask = trg_mask.contiguous().view(-1)
            try:
                mask = mask.bool()
            except AttributeError:
                # older PyTorch versions lack .bool(); nothing to do
                pass
            #My assumption is this will usually just sum the loss so we need to average it ourselves
            with poutine.scale(scale=self.get_reconstruction_const(scale=kl)):
                pyro.sample('preds',
                            Categorical(logits=logits.contiguous().view(
                                -1, logits.size(-1))).mask(mask),
                            obs=obs)
Example No. 20
    def model(self):
        self.set_mode("model")

        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
        Luu = Kuu.cholesky()

        zero_loc = self.Xu.new_zeros(self.u_loc.shape)
        if self.whiten:
            identity = eye_like(self.Xu, M)
            pyro.sample(self._pyro_get_fullname("u"),
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
        else:
            pyro.sample(self._pyro_get_fullname("u"),
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .to_event(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc, self.u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten, jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            # we would like to load likelihood's parameters outside poutine.scale context
            self.likelihood._load_pyro_samples()
            with poutine.scale(scale=self.num_data / self.X.size(0)):
                return self.likelihood(f_loc, f_var, self.y)
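The factor self.num_data / self.X.size(0) rescales the minibatch likelihood so that, in expectation over minibatches, it matches the full-data likelihood term of the ELBO. A toy sketch of ours (not from the GP module) showing the same pattern in isolation:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def minibatch_likelihood(f_batch, y_batch, num_data):
    # the sum of log p(y_i | f_i) over the batch, scaled by N / batch_size,
    # is an unbiased estimate of the sum over all N data points
    with poutine.scale(scale=num_data / y_batch.size(0)):
        pyro.sample("y", dist.Normal(f_batch, 1.0).to_event(1), obs=y_batch)

f, y = torch.zeros(20), torch.zeros(20)
tr = poutine.trace(minibatch_likelihood).get_trace(f, y, num_data=1000)
tr.compute_log_prob()  # the "y" log-prob carries a scale of 1000 / 20 = 50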
Example No. 21
    def model(self):
        self.set_mode("model")

        Xu = self.get_param("Xu")
        u_loc = self.get_param("u_loc")
        u_scale_tril = self.get_param("u_scale_tril")

        M = Xu.shape[0]
        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)

        zero_loc = Xu.new_zeros(u_loc.shape)
        u_name = param_with_module_name(self.name, "u")
        if self.whiten:
            Id = torch.eye(M, out=Xu.new_empty(M, M))
            pyro.sample(u_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Id)
                            .independent(zero_loc.dim() - 1))
        else:
            pyro.sample(u_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .independent(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, Xu, self.kernel, u_loc, u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten,
                                   jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            with poutine.scale(None, self.num_data / self.X.shape[0]):
                return self.likelihood(f_loc, f_var, self.y)
Example No. 22
    def guide(self, obs):
        batch_size = obs['x'].shape[0]
        with pyro.plate('observations', batch_size):
            hidden = self.encoder(obs['x'])

            ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
                obs['ventricle_volume'])
            brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
                obs['brain_volume'])
            lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
                obs['lesion_volume'])
            slice_number = obs['slice_number']
            ctx = torch.cat([
                ventricle_volume_, brain_volume_, lesion_volume_, slice_number
            ], 1)
            hidden = torch.cat([hidden, ctx], 1)

            z_base_dist = self.latent_encoder.predict(hidden)
            z_dist = TransformedDistribution(
                z_base_dist, self.posterior_flow_transforms
            ) if self.use_posterior_flow else z_base_dist
            _ = self.posterior_affine
            _ = self.posterior_flow_components
            with poutine.scale(scale=self.annealing_factor[-1]):
                z = pyro.sample('z', z_dist)

        return z
Example No. 23
    def model(self):
        obs = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            obs['ventricle_volume'])
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            obs['brain_volume'])
        lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
            obs['lesion_volume'])
        slice_number = obs['slice_number']
        ctx = torch.cat(
            [ventricle_volume_, brain_volume_, lesion_volume_, slice_number],
            1)

        if self.prior_components > 1:
            z_scale = (
                0.5 * self.z_scale).exp() + 1e-5  # z_scale parameter is logvar
            z_base_dist = MixtureOfDiagNormalsSharedCovariance(
                self.z_loc, z_scale, self.z_components).to_event(0)
        else:
            z_base_dist = Normal(self.z_loc, self.z_scale).to_event(1)
        z_dist = TransformedDistribution(
            z_base_dist,
            self.prior_flow_transforms) if self.use_prior_flow else z_base_dist
        _ = self.prior_affine
        _ = self.prior_flow_components
        with poutine.scale(scale=self.annealing_factor[-1]):
            z = pyro.sample('z', z_dist)
        latent = torch.cat([z, ctx], 1)

        x_dist = self._get_transformed_x_dist(latent)  # run decoder
        x = pyro.sample('x', x_dist)

        obs.update(dict(x=x, z=z))
        return obs
Example No. 24
    def guide(
        self,
        x: torch.Tensor,
        x_packed_reversed: nn.utils.rnn.PackedSequence,
        seq_mask: torch.Tensor,
        seq_lengths: torch.Tensor,
        annealing=1.0,
    ) -> Tensor:

        pyro.module("dmm", self)
        batch_dim, time_steps, _ = x.shape
        h0 = self.h0.expand(self.h0.size(0), batch_dim,
                            self.h0.size(-1)).contiguous()
        h_packed_reversed = self.encode(x_packed_reversed, h0)[0]
        h_reversed, _ = pad_packed_sequence(h_packed_reversed,
                                            batch_first=True)
        h = self.reverse_sequences(h_reversed, seq_lengths)
        z = self.qz0.expand(batch_dim, self.qz0.size(-1))
        with pyro.plate("data", batch_dim):
            for t in range(time_steps):
                z_params = self.combine(h[:, t, :], z)
                with poutine.scale(None, annealing):
                    z = pyro.sample(
                        f"z_{t+1}",
                        Normal(*z_params).mask(seq_mask[:,
                                                        t:t + 1]).to_event(1),
                    )
        return z
Example No. 25
def irt_model_3pl(
    ability_dim,
    num_person,
    num_item,
    device,
    response=None,
    mask=None,
    annealing_factor=1,
    nonlinear=False,
):
    ability_prior = dist.Normal(
        torch.zeros((num_person, ability_dim), device=device),
        torch.ones((num_person, ability_dim), device=device),
    )
    with poutine.scale(scale=annealing_factor):
        ability = pyro.sample("ability", ability_prior)

    item_feat_prior = dist.Normal(
        torch.zeros((num_item, ability_dim + 2), device=device),
        torch.ones((num_item, ability_dim + 2), device=device),
    )
    with poutine.scale(scale=annealing_factor):
        item_feat = pyro.sample("item_feat", item_feat_prior)

    discrimination = item_feat[:, :ability_dim]
    difficulty = item_feat[:, ability_dim:ability_dim + 1]
    guess_logit = item_feat[:, ability_dim + 1:ability_dim + 2]
    guess = torch.sigmoid(guess_logit)

    logit = (torch.mm(ability, -discrimination.T) + difficulty.T).unsqueeze(2)

    if nonlinear:
        logit = torch.pow(logit, 2)

    guess = guess.unsqueeze(0)
    response_mu = guess + (1. - guess) * torch.sigmoid(logit)

    if mask is not None:
        response_dist = dist.Bernoulli(response_mu).mask(mask)
    else:
        response_dist = dist.Bernoulli(response_mu)

    if response is not None:
        pyro.sample("response", response_dist, obs=response)
    else:
        response = pyro.sample("response", response_dist)
        return response, ability, item_feat
Example No. 26
    def model(
        self,
        mini_batch,
        mini_batch_reversed,
        mini_batch_mask,
        mini_batch_seq_lengths,
        annealing_factor=1.0,
    ):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # register all PyTorch (sub)modules with pyro
        # this needs to happen in both the model and guide
        pyro.module("dmm", self)

        # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # we enclose all the sample statements in the model in a plate.
        # this marks that each datapoint is conditionally independent of the others
        with pyro.plate("z_minibatch", len(mini_batch)):
            # sample the latents z and observed x's one time step at a time
            # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use multiple samples from the guide at each z
            for t in pyro.markov(range(1, T_max + 1)):
                # the next chunk of code samples z_t ~ p(z_t | z_{t-1})
                # note that (both here and elsewhere) we use poutine.scale to take care
                # of KL annealing. we use the mask() method to deal with raggedness
                # in the observed data (i.e. different sequences in the mini-batch
                # have different lengths)

                # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
                z_loc, z_scale = self.trans(z_prev)

                # then sample z_t according to dist.Normal(z_loc, z_scale)
                # note that we use the reshape method so that the univariate Normal distribution
                # is treated as a multivariate Normal distribution with a diagonal covariance.
                with poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample(
                        "z_%d" % t,
                        dist.Normal(z_loc, z_scale).mask(
                            mini_batch_mask[:, t - 1:t]).to_event(1),
                    )

                # compute the probabilities that parameterize the bernoulli likelihood
                emission_probs_t = self.emitter(z_t)
                # the next statement instructs pyro to observe x_t according to the
                # bernoulli distribution p(x_t|z_t)
                pyro.sample(
                    "obs_x_%d" % t,
                    dist.Bernoulli(emission_probs_t).mask(
                        mini_batch_mask[:, t - 1:t]).to_event(1),
                    obs=mini_batch[:, t - 1, :],
                )
                # the latent sampled at this time step will be conditioned upon
                # in the next time step so keep track of it
                z_prev = z_t
Example No. 27
    def guide(self, x: torch.Tensor):
        pyro.module(self.NAME, self)

        with pyro.plate("data", len(x)), poutine.scale(scale=self.scale_factor):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale, l_loc, l_scale = self.encoder(x)

            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            pyro.sample("library", dist.Normal(l_loc, l_scale).to_event(1))
Example No. 28
def speaker(face, faces, utterance_candidates, depth=0):
    """
    return: index of utterance
    """
    alpha = 1.
    utterance = utterance_prior(utterance_candidates)
    literal_marginal = listener(utterance, utterance_candidates, faces, depth)
    with poutine.scale(scale=torch.tensor(alpha)):
        pyro.sample('listener', literal_marginal, obs=face)
    return utterance
Example No. 29
    def __init__(
        self,
        scale_elbo: Union[float, None] = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if scale_elbo != 1.0:
            self.svi = pyro.infer.SVI(
                model=poutine.scale(self.module.model, scale_elbo),
                guide=poutine.scale(self.module.guide, scale_elbo),
                optim=self.optim,
                loss=self.loss_fn,
            )
        else:
            self.svi = pyro.infer.SVI(
                model=self.module.model,
                guide=self.module.guide,
                optim=self.optim,
                loss=self.loss_fn,
            )
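Wrapping both the model and the guide in poutine.scale with the same constant, as scale_elbo does above, multiplies every term of the ELBO by that constant: the optimum is unchanged but gradient magnitudes shrink, which can help numerically on large datasets. A self-contained toy sketch (our own model/guide, hypothetical values):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(q_loc, 1.))

scale_elbo = 1e-2
svi = SVI(poutine.scale(model, scale=scale_elbo),
          poutine.scale(guide, scale=scale_elbo),
          Adam({"lr": 1e-2}),
          loss=Trace_ELBO())
loss = svi.step(torch.randn(100))  # loss is the ordinary ELBO loss times scale_elbo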
Example No. 30
    def __iter__(self):
        if not am_i_wrapped():
            for i in self.subsample:
                yield i if isinstance(i, numbers.Number) else i.item()
        else:
            indep_context = poutine.indep(name=self.name, size=self.subsample_size)
            with poutine.scale(scale=self.size / self.subsample_size):
                for i in self.subsample:
                    indep_context.next_context()
                    with indep_context:
                        # convert to python numeric type as functions like torch.ones(*args)
                        # do not work with dim 0 torch.Tensor instances.
                        yield i if isinstance(i, numbers.Number) else i.item()
Example No. 31
    def __enter__(self):
        self._wrapped = am_i_wrapped()
        self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
        if self._wrapped:
            try:
                self._scale_messenger = poutine.scale(scale=self.size / self.subsample_size)
                self._indep_messenger = poutine.indep(name=self.name, size=self.subsample_size, dim=self.dim)
                self._scale_messenger.__enter__()
                self._indep_messenger.__enter__()
            except BaseException:
                _DIM_ALLOCATOR.free(self.name, self.dim)
                raise
        return self.subsample
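These two methods are plate/iarange internals: sites inside a subsampled plate get their log-probability multiplied by size / subsample_size via poutine.scale. A quick sketch of ours to see the recorded scale on a trace (assuming the trace API exposes it under the "scale" key, as I believe it does):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    with pyro.plate("data", size=100, subsample_size=10) as ind:
        pyro.sample("x", dist.Normal(0., 1.), obs=torch.zeros(len(ind)))

tr = poutine.trace(model).get_trace()
print(tr.nodes["x"]["scale"])  # should print 10.0 (= 100 / 10)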
Example No. 32
    def model(self, x_sent, y_sent, decode=False, kl_annealing=1.0):
        pyro.module("vnmt", self)
        #Produce our prior parameters
        x_embeds, x_len, x_mask, y_sent = self.x_embed(x_sent, y_sent)
        x_out, s_0 = self.encoder(x_embeds)

        X, x_len = pad_packed_sequence(x_out, batch_first=self.b_f)
        if self.use_cuda:
            x_len = x_len.cuda()

        z_input = torch.sum(X, dim=1) / x_len.unsqueeze(1).float()
        z_mean, z_sig = self.prior(z_input)

        #TODO technically this is done in forward call....
        y_labels = self.y_embed.sentences2IndexesAndLens(y_sent)

        T_max = max([y[1] for y in y_labels])
        y_labels, y_mask = self.y_embed.padAndMask(y_labels,
                                                   batch_first=self.b_f)

        with pyro.plate('z_minibatch'):
            #sample from model prior P(Z | X)
            with poutine.scale(scale=kl_annealing):
                semantics = pyro.sample('z_semantics',
                                        dist.Normal(z_mean, z_sig))
            #Generate sequences
            #TODO probably need to verify this, supposed to be H_e' = g(Wh_z + b_z) in paper eq 11
            semantics = F.relu(self.h_e_p(semantics))

            #TODO verify that this makes sense to do
            #TODO so...based on paper graphic the init hidden state is the last part of the RNN run in reverse on seq
            # TODO I added a "bridge" to take in the final context and convert to a proper hidden state size... need to put it in
            #s_0 = s_0.view(self.num_layers,2 if self.enc_bi else 1, len(y_len), self.hidden_dim)
            s_t = s_0[1].unsqueeze(0)  #s_0[:, 1 if self.enc_bi else 0, :, :]
            for t in range(0, T_max):
                # TODO atm we are teacher forcing the model (i.e. using labeled inputs)
                #TODO in future may want to use generative samples to improve robustness
                #probably need to figure out a more eloquent solution...
                l = y_labels[:, t]
                inputs = torch.cat([
                    F.relu(self.y_embed.getBatchEmbeddings(l).unsqueeze(1)),
                    semantics.unsqueeze(1)
                ],
                                   dim=2)
                output, s_t = self.decoder(inputs, s_t)
                output = self.emitter(output).squeeze(1)
                entry = pyro.sample(
                    'y_{}'.format(t),
                    dist.Categorical(probs=F.softmax(output, dim=1)).mask(
                        y_mask[:, t:t + 1].squeeze()),
                    obs=l)
Example No. 33
    def __iter__(self):
        if not am_i_wrapped():
            for i in self.subsample:
                yield i if isinstance(i, numbers.Number) else i.item()
        else:
            indep_context = poutine.indep(name=self.name,
                                          size=self.subsample_size)
            with poutine.scale(scale=self.size / self.subsample_size):
                for i in self.subsample:
                    indep_context.next_context()
                    with indep_context:
                        # convert to python numeric type as functions like torch.ones(*args)
                        # do not work with dim 0 torch.Tensor instances.
                        yield i if isinstance(i, numbers.Number) else i.item()
Example No. 34
File: dmm.py  Project: lewisKit/pyro
    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # register all PyTorch (sub)modules with pyro
        # this needs to happen in both the model and guide
        pyro.module("dmm", self)

        # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # we enclose all the sample statements in the model in a iarange.
        # this marks that each datapoint is conditionally independent of the others
        with pyro.iarange("z_minibatch", len(mini_batch)):
            # sample the latents z and observed x's one time step at a time
            for t in range(1, T_max + 1):
                # the next chunk of code samples z_t ~ p(z_t | z_{t-1})
                # note that (both here and elsewhere) we use poutine.scale to take care
                # of KL annealing. we use the mask() method to deal with raggedness
                # in the observed data (i.e. different sequences in the mini-batch
                # have different lengths)

                # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
                z_loc, z_scale = self.trans(z_prev)

                # then sample z_t according to dist.Normal(z_loc, z_scale)
                # note that we use the reshape method so that the univariate Normal distribution
                # is treated as a multivariate Normal distribution with a diagonal covariance.
                with poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample("z_%d" % t,
                                      dist.Normal(z_loc, z_scale)
                                          .mask(mini_batch_mask[:, t - 1:t])
                                          .independent(1))

                # compute the probabilities that parameterize the bernoulli likelihood
                emission_probs_t = self.emitter(z_t)
                # the next statement instructs pyro to observe x_t according to the
                # bernoulli distribution p(x_t|z_t)
                pyro.sample("obs_x_%d" % t,
                            dist.Bernoulli(emission_probs_t)
                                .mask(mini_batch_mask[:, t - 1:t])
                                .independent(1),
                            obs=mini_batch[:, t - 1, :])
                # the latent sampled at this time step will be conditioned upon
                # in the next time step so keep track of it
                z_prev = z_t
Example No. 35
def irange(name, size, subsample_size=None, subsample=None, use_cuda=None):
    """
    Non-vectorized version of ``iarange``. See ``iarange`` for details.

    :param str name: A name that will be used for this site in a Trace.
    :param int size: The size of the collection being subsampled (like ``stop``
        in builtin ``range``).
    :param int subsample_size: Size of minibatches used in subsampling.
        Defaults to ``size``.
    :param subsample: Optional custom subsample for user-defined subsampling
        schemes. If specified, then ``subsample_size`` will be set to
        ``len(subsample)``.
    :type subsample: Anything supporting ``len()``.
    :param bool use_cuda: Optional bool specifying whether to use cuda tensors
        for internal ``log_pdf`` computations. Defaults to
        ``torch.Tensor.is_cuda``.
    :return: A generator yielding a sequence of integers.

    Examples::

        >>> for i in irange('data', 100, subsample_size=10):
                if z[i]:  # Prevents vectorization.
                    observe('obs_{}'.format(i), normal, data[i], mu, sigma)

    See `SVI Part II <http://pyro.ai/examples/svi_part_ii.html>`_ for an extended discussion.
    """
    subsample, scale = _subsample(name, size, subsample_size, subsample, use_cuda)
    if isinstance(subsample, Variable):
        subsample = subsample.data
    if len(_PYRO_STACK) == 0:
        for i in subsample:
            yield i
    else:
        indep_context = poutine.indep(None, name, vectorized=False)
        with poutine.scale(None, scale):
            for i in subsample:
                with indep_context:
                    yield i
Example No. 36
def iarange(name, size=None, subsample_size=None, subsample=None, use_cuda=None):
    """
    Context manager for conditionally independent ranges of variables.

    ``iarange`` is similar to ``torch.arange`` in that it yields an array
    of indices by which other tensors can be indexed. ``iarange`` differs from
    ``torch.arange`` in that it also informs inference algorithms that the
    variables being indexed are conditionally independent. To do this,
    ``iarange`` is a provided as context manager rather than a function, and
    users must guarantee that all computation within an ``iarange`` context
    is conditionally independent::

        with iarange("name", size) as ind:
            # ...do conditionally independent stuff with ind...

    Additionally, ``iarange`` can take advantage of the conditional
    independence assumptions by subsampling the indices and informing inference
    algorithms to scale various computed values. This is typically used to
    subsample minibatches of data::

        with iarange("data", len(data), subsample_size=100) as ind:
            batch = data[ind]
            assert len(batch) == 100

    By default ``subsample_size=False`` and this simply yields a
    ``torch.arange(0, size)``. If ``0 < subsample_size <= size`` this yields a
    single random batch of indices of size ``subsample_size`` and scales all
    log likelihood terms by ``size/batch_size``, within this context.

    .. warning::  This is only correct if all computation is conditionally
        independent within the context.

    :param str name: A unique name to help inference algorithms match
        ``iarange`` sites between models and guides.
    :param int size: Optional size of the collection being subsampled
        (like `stop` in builtin `range`).
    :param int subsample_size: Size of minibatches used in subsampling.
        Defaults to `size`.
    :param subsample: Optional custom subsample for user-defined subsampling
        schemes. If specified, then `subsample_size` will be set to
        `len(subsample)`.
    :type subsample: Anything supporting `len()`.
    :param bool use_cuda: Optional bool specifying whether to use cuda tensors
        for `subsample` and `log_pdf`. Defaults to `torch.Tensor.is_cuda`.
    :return: A context manager yielding a single 1-dimensional `torch.Tensor`
        of indices.

    Examples::

        # This version simply declares independence:
        >>> with iarange('data'):
                observe('obs', normal, data, mu, sigma)

        # This version subsamples data in vectorized way:
        >>> with iarange('data', 100, subsample_size=10) as ind:
                observe('obs', normal, data.index_select(0, ind), mu, sigma)

        # This wraps a user-defined subsampling method for use in pyro:
        >>> ind = my_custom_subsample
        >>> with iarange('data', 100, subsample=ind):
                observe('obs', normal, data.index_select(0, ind), mu, sigma)

    See `SVI Part II <http://pyro.ai/examples/svi_part_ii.html>`_ for an
    extended discussion.
    """
    subsample, scale = _subsample(name, size, subsample_size, subsample, use_cuda)
    if len(_PYRO_STACK) == 0:
        yield subsample
    else:
        with poutine.scale(None, scale):
            with poutine.indep(None, name, vectorized=True):
                yield subsample