Example 1
def test_beta_binomial_dependent_sample():
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    posterior = dist.Beta(concentration1 + counts,
                          concentration0 + total - counts)

    def model(counts):
        prob = pyro.sample("prob", prior)
        pyro.sample("counts", dist.Binomial(total, prob), obs=counts)

    reparam_model = poutine.reparam(
        model,
        {
            "prob":
            ConjugateReparam(
                lambda counts: dist.Beta(1 + counts, 1 + total - counts)),
        },
    )

    with poutine.trace() as tr, pyro.plate("particles", 10000):
        reparam_model(counts)
    samples = tr.trace.nodes["prob"]["value"]

    assert_close(samples.mean(), posterior.mean, atol=0.01)
    assert_close(samples.std(), posterior.variance.sqrt(), atol=0.01)
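For reference, the `posterior` used in the assertions follows from Beta-Binomial conjugacy: with a Beta(α, β) prior on the success probability p and k observed successes in n trials,

\[ p \mid k \;\sim\; \mathrm{Beta}(\alpha + k,\; \beta + n - k), \]

which is exactly `dist.Beta(concentration1 + counts, concentration0 + total - counts)` above.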
Example 2
def test_beta_binomial_hmc():
    num_samples = 1000
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    likelihood = dist.Beta(1 + counts, 1 + total - counts)
    posterior = dist.Beta(concentration1 + counts,
                          concentration0 + total - counts)

    def model():
        prob = pyro.sample("prob", prior)
        pyro.sample("counts", dist.Binomial(total, prob), obs=counts)

    reparam_model = poutine.reparam(model,
                                    {"prob": ConjugateReparam(likelihood)})

    kernel = HMC(reparam_model)
    samples = MCMC(kernel, num_samples, warmup_steps=0).run()
    pred = Predictive(reparam_model, samples, num_samples=num_samples)
    trace = pred.get_vectorized_trace()
    samples = trace.nodes["prob"]["value"]

    assert_close(samples.mean(), posterior.mean, atol=0.01)
    assert_close(samples.std(), posterior.variance.sqrt(), atol=0.01)
Example 3
def guide(prior, obs, num_obs):
    a = pyro.param("a", prior["A"], constraint=constraints.positive)
    pyro.sample("p_A", dist.Beta(a[0], a[1]))
    b = pyro.param("b", prior["B"], constraint=constraints.positive)
    pyro.sample("p_B", dist.Beta(b[:, 0], b[:, 1]).to_event(1))
    c = pyro.param("c", prior["C"], constraint=constraints.positive)
    pyro.sample("p_C", dist.Beta(c[:, 0], c[:, 1]).to_event(1))
Example 4
    def _update_the_policy_network(self, state_batch_tensor, advantages,
                                   action_batch_tensor):

        action_alpha_old, action_beta_old = self.policy_network(
            state_batch_tensor)
        old_beta_dist = dist.Beta(action_alpha_old, action_beta_old)
        old_action_prob = old_beta_dist.log_prob(action_batch_tensor)
        # detach so the old log-probabilities are constants in the ratio below
        old_action_prob = old_action_prob.detach()

        if self.use_cuda:
            advantages = advantages.cuda()

        for _ in range(self.args.policy_update_step):
            action_alpha, action_beta = self.policy_network(state_batch_tensor)
            new_beta_dist = dist.Beta(action_alpha, action_beta)
            new_action_prob = new_beta_dist.log_prob(action_batch_tensor)
            # calculate the ratio
            ratio = torch.exp(new_action_prob - old_action_prob)

            # calculate the surr
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.args.epsilon,
                                1 + self.args.epsilon) * advantages

            loss_policy = -torch.min(surr1, surr2).mean()

            self.optimizer_policy.zero_grad()
            loss_policy.backward()
            # update
            self.optimizer_policy.step()

        return loss_policy
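For context, the loop implements the clipped surrogate objective of PPO. Writing r(θ) for the probability ratio computed as `torch.exp(new_action_prob - old_action_prob)`, the loss is the negative of

\[ L^{\mathrm{CLIP}}(\theta) = \mathbb{E}\big[\min\big(r(\theta)\,A,\; \mathrm{clip}(r(\theta),\,1-\epsilon,\,1+\epsilon)\,A\big)\big], \]

where A are the advantage estimates and ε is `self.args.epsilon`.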
Example 5
def guide(prior, obs, num_obs):
    a = pyro.param('a', prior['A'], constraint=constraints.positive)
    pyro.sample('p_A', dist.Beta(a[0], a[1]))
    b = pyro.param('b', prior['B'], constraint=constraints.positive)
    pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1))
    c = pyro.param('c', prior['C'], constraint=constraints.positive)
    pyro.sample('p_C', dist.Beta(c[:, 0], c[:, 1]).to_event(1))
Example 6
    def model(self, *args, **kwargs):
        I = self._data['segments']

        pi = pyro.sample("pi", dist.Dirichlet(self._params['init_probs']))

        probs_z = pyro.sample(
            "cnv_probs",
            dist.Dirichlet((1 - self._params['t']) *
                           torch.eye(self._params['hidden_dim']) +
                           self._params['t']).to_event(1))
        probs_y = torch.tensor([[2., 64., 32., 21.5, 16., 43.],
                                [64., 64., 64., 64., 64., 64.]])

        z = pyro.sample("z_0", dist.Categorical(pi),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_0", dist.Beta(probs_y[0, z], probs_y[1, z]),
                    obs=self._data['data'][0, 0])

        for i in pyro.markov(range(1, I)):
            z = pyro.sample("z_{}".format(i),
                            dist.Categorical(Vindex(probs_z)[z]),
                            infer={"enumerate": "parallel"})
            pyro.sample("y_{}".format(i),
                        dist.Beta(probs_y[0, z], probs_y[1, z]),
                        obs=self._data['data'][i, 0])
Example 7
def model(y, x, N):
    w = pyro.sample('w', dist.Beta(26.072914040168385 * torch.ones(1),
                                   42.3120851154 * torch.ones(1)))
    with pyro.plate('b_range_', N):
        b = pyro.sample('b', dist.Gamma(5.63887222899 * torch.ones(N),
                                        40.1978121928 * torch.ones(N)))
    with pyro.plate('p_range_', N):
        p = pyro.sample('p', dist.Beta(52.1419233118 * torch.ones(N),
                                       83.6618285099 * torch.ones(N)))
    pyro.sample('obs__100', dist.Beta(w * x + b, p), obs=y)
Example 8
 def __init__(self, ode_op, ode_model):
     super(LNAGenModel, self).__init__()
     self._ode_op = ode_op
     self._ode_model = ode_model
     self.ode_params1 = PyroSample(dist.Beta(2, 1))
     self.ode_params2 = PyroSample(dist.HalfNormal(1))
     self.ode_params3 = PyroSample(dist.Beta(1, 2))
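`PyroSample` declares a prior on a `PyroModule` attribute; reading the attribute inside a model invocation triggers a correspondingly named `pyro.sample` call. A minimal self-contained sketch (the `TinyModel` class is an illustration, not part of the source):

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class TinyModel(PyroModule):
    def __init__(self):
        super().__init__()
        # prior on a latent attribute, analogous to ode_params1 above
        self.rate = PyroSample(dist.Beta(2.0, 1.0))

    def forward(self, data=None):
        # attribute access triggers pyro.sample("rate", ...) once per trace
        return pyro.sample("obs", dist.Bernoulli(self.rate), obs=data)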
Example 9
 def _model_(self):
     control_prior = pyro.sample('control_p', dist.Beta(1, 1))
     treatment_prior = pyro.sample('treatment_p', dist.Beta(1, 1))
     return pyro.sample(
         'obs',
         dist.Binomial(self.traffic_size,
                       torch.stack([control_prior, treatment_prior])),
         obs=self.outcome)
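A minimal sketch of sampling this A/B model's posterior with NUTS; `experiment`, standing in for whatever object carries `traffic_size` and `outcome`, is an assumption:

from pyro.infer import MCMC, NUTS

kernel = NUTS(experiment._model_)
mcmc = MCMC(kernel, num_samples=500, warmup_steps=200)
mcmc.run()
posterior = mcmc.get_samples()  # keys: 'control_p', 'treatment_p'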
Example 10
def model(data, params):
    # initialize data
    data = {k: torch.tensor(v).float() for k, v in data.items()}
    n = data["n"]
    k = data["k"]
    # model block
    theta = pyro.sample("theta", dist.Beta(1., 1.))
    thetaprior = pyro.sample("thetaprior", dist.Beta(1., 1.))
    k = pyro.sample("k", dist.Binomial(n, theta), obs=k)
Example 11
 def __init__(self, ode_op, ode_model):
     super(PlantModel, self).__init__()
     self._ode_op = ode_op
     self._ode_model = ode_model
     # TODO: Incorporate appropriate priors (cf. MATALB codes from Daewook)
     self.ode_params1 = PyroSample(dist.Gamma(1, 1000))  # dG
     self.ode_params2 = PyroSample(dist.Gamma(1, 1000))  # dP
     self.ode_params3 = PyroSample(dist.Beta(0.5, 0.5))  # G0
     self.ode_params4 = PyroSample(dist.Beta(0.5, 0.5))  # P0
Example 12
def model(prior, obs, num_obs):
    p_A = pyro.sample('p_A', dist.Beta(1, 1))
    p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    p_C = pyro.sample('p_C', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    with pyro.plate('data_plate', num_obs):
        A = pyro.sample('A', dist.Bernoulli(p_A.expand(num_obs)), obs=obs['A'])
        # Vindex used to ensure proper indexing into the enumerated sample sites
        B = pyro.sample('B', dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]), infer={"enumerate": "parallel"})
        pyro.sample('C', dist.Bernoulli(Vindex(p_C)[B.type(torch.long)]), obs=obs['C'])
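Because `B` is a latent site marked with `infer={"enumerate": "parallel"}`, this model needs an enumeration-aware loss. A minimal training sketch, assuming it is paired with the guide from Example 5:

import pyro
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

svi = SVI(model, guide, Adam({"lr": 0.01}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
for _ in range(1000):
    svi.step(prior, obs, num_obs)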
Example 13
    def kl(self, dist_a, prior=None):
        if prior is None:  # use the standard reparameterizer
            return self._kld_beta_kerman_prior(dist_a['beta']['conc1'],
                                               dist_a['beta']['conc2'])

        # we have two distributions provided (eg: VRNN)
        return torch.sum(
            D.kl_divergence(
                PD.Beta(dist_a['beta']['conc1'], dist_a['beta']['conc2']),
                PD.Beta(prior['beta']['conc1'], prior['beta']['conc2'])), -1)
Example 14
def model(data):
    # define the hyperparameters that control the beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    pyro.sample("some", dist.Beta(torch.tensor(10.0), torch.tensor(5.0)))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
Example 15
def model(data, params):
    # initialize data
    data = {k: torch.tensor(v).float() for k, v in data.items()}
    n1 = data["n1"]
    n2 = data["n2"]
    k1 = data["k1"]
    k2 = data["k2"]
    # init parameters
    theta1 = pyro.sample("theta1", dist.Beta(1., 1.))
    theta2 = pyro.sample("theta2", dist.Beta(1., 1.))
    pyro.sample("k1", dist.Binomial(n1, theta1), obs=k1)
    pyro.sample("k2", dist.Binomial(n2, theta2), obs=k2)
Example 16
    def _kld_beta_kerman_prior(self, conc1, conc2):
        """ Internal function to do a KL-div against the prior.

        :param conc1: concentration 1.
        :param conc2: concentration 2.
        :returns: batch_size tensor of kld against prior.
        :rtype: torch.Tensor

        """
        prior = PD.Beta(zeros_like(conc1) + 1 / 3, zeros_like(conc2) + 1 / 3)
        beta = PD.Beta(conc1, conc2)
        return torch.sum(D.kl_divergence(beta, prior), -1)
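`zeros_like(conc1) + 1 / 3` builds the Beta(1/3, 1/3) neutral prior proposed by Kerman (2011). The closed form that `D.kl_divergence` evaluates for two Beta distributions is

\[ \mathrm{KL}\big(\mathrm{Beta}(\alpha_1,\beta_1)\,\big\|\,\mathrm{Beta}(\alpha_2,\beta_2)\big) = \ln\frac{B(\alpha_2,\beta_2)}{B(\alpha_1,\beta_1)} + (\alpha_1-\alpha_2)\,\psi(\alpha_1) + (\beta_1-\beta_2)\,\psi(\beta_1) + (\alpha_2-\alpha_1+\beta_2-\beta_1)\,\psi(\alpha_1+\beta_1), \]

with B the Beta function and ψ the digamma function.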
Example 17
def test_beta_binomial(sample_shape, batch_shape):
    concentration1 = torch.randn(batch_shape).exp()
    concentration0 = torch.randn(batch_shape).exp()
    total = 10
    obs = dist.Binomial(total, 0.2).sample(sample_shape + batch_shape)

    f = dist.Beta(concentration1, concentration0)
    g = dist.Beta(1 + obs, 1 + total - obs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(f.log_prob(x) + g.log_prob(x), fg.log_prob(x) + log_normalizer)
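The final assertion checks the defining identity of `conjugate_update`: the prior density f times the (possibly unnormalized) likelihood g equals the updated distribution fg up to a constant,

\[ \log f(x) + \log g(x) = \log fg(x) + \log Z \quad \text{for all } x, \]

where log Z is the returned `log_normalizer`.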
Example 18
 def model():
     d = dist.Bernoulli(p)
     context1 = pyro.plate("outer", 3, dim=-1)
     context2 = pyro.plate("inner", 2, dim=-2)
     pyro.sample("w", d)
     pyro.sample("b", dist.Beta(1.1, 1.1))
     with context1:
         pyro.sample("x", d)
     with context2:
         pyro.sample("c", dist.Beta(1.1, 1.1))
         pyro.sample("y", d)
     with context1, context2:
         pyro.sample("z", d)
Example 19
def model_multi_obs_grp(obsmat):
    # some parameters can be directly derived from the data passed
    # K = 2
    nparticipants = data.shape[0]
    nfeatures = data.shape[1]  # number of rows in each person's matrix
    ncol = data.shape[2]

    # Background probability of different groups
    if tm.stickbreak:
        # stick breaking process for assigning weights to groups
        with pyro.plate("beta_plate", K - 1):
            beta_mix = pyro.sample("weights", dist.Beta(1, 10))
        weights = tm.mix_weights(beta_mix)
    else:
        weights = pyro.sample('weights',
                              dist.Dirichlet(0.5 * torch.ones(tm.K)))
    # declare model parameters based on whether the data are row-normalized
    if tm.dtype == 'norm':
        pass
#         with pyro.plate('components', K):
#             # concentration parameters
#             concentration = pyro.sample('concentration',
#                                         dist.Gamma(2 * torch.ones(nfeatures,ncol), 1/3 * torch.ones(nfeatures,ncol)).to_event(2))

#         # implementation for the dirichlet based model is not complete!!!!
#         with pyro.plat('data',obsmat.shape[0]):
#             assignment = pyro.sample('assignment', dist.Categorical(weights))
#             #d = dist.Dirichlet(concentration[assignment,:,:].clone().detach()) # .detach() might interfere with backprop
#             d = dist.Dirichlet(concentration[assignment,i,:])
#             pyro.sample('obs', d.to_event(1), obs=obsmat)

    elif tm.dtype == 'raw':
        with pyro.plate('components', tm.K):
            alphas = pyro.sample(
                'alpha',
                dist.Gamma(2 * torch.ones(nfeatures, ncol),
                           1 / 3 * torch.ones(nfeatures, ncol)).to_event(2))
            betas = pyro.sample(
                'beta',
                dist.Gamma(2 * torch.ones(nfeatures, ncol),
                           1 / 3 * torch.ones(nfeatures, ncol)).to_event(2))

        assignment = pyro.sample('assignment', dist.Categorical(weights))
        # expand assignment to make dimensions match
        for r in np.arange(obsmat.shape[0]):
            rowind = obsmat[r, 1].type(torch.long)
            colind = obsmat[r, 2].type(torch.long)
            d = dist.Beta(alphas[assignment, rowind, colind],
                          betas[assignment, rowind, colind])
            pyro.sample('obs_{}'.format(r), d, obs=obsmat[r, 0])
Example 20
def guide(y, x, N):
    arg_1 = torch.nn.functional.softplus(pyro.param('arg_1', torch.ones(1)))
    arg_2 = torch.nn.functional.softplus(pyro.param('arg_2', torch.ones(1)))
    w = pyro.sample('w', dist.Beta(arg_1, arg_2))
    arg_3 = torch.nn.functional.softplus(pyro.param('arg_3', torch.ones(N)))
    arg_4 = torch.nn.functional.softplus(pyro.param('arg_4', torch.ones(N)))
    with pyro.plate('b_prange', N):
        b = pyro.sample('b', dist.Gamma(arg_3, arg_4))
    arg_5 = torch.nn.functional.softplus(pyro.param('arg_5', torch.ones(N)))
    arg_6 = torch.nn.functional.softplus(pyro.param('arg_6', torch.ones(N)))
    with pyro.plate('p_prange', N):
        p = pyro.sample('p', dist.Beta(arg_5, arg_6))
Example 21
 def setUp(self):
     self.alpha = torch.tensor([2.4])
     self.beta = torch.tensor([3.7])
     self.test_data = torch.tensor([0.4])
     self.dist = dist.Beta(self.alpha, self.beta)
     self.batch_alpha = torch.tensor([[2.4], [3.6]])
     self.batch_beta = torch.tensor([[3.7], [2.5]])
     self.batch_test_data = torch.tensor([[0.4], [0.6]])
     self.batch_dist = dist.Beta(self.batch_alpha, self.batch_beta)
     self.analytic_mean = self.alpha / (self.alpha + self.beta)
     one = torch.ones(1)
     self.analytic_var = torch.pow(self.analytic_mean, 2.0) * self.beta / (
         self.alpha * (self.alpha + self.beta + one))
     self.n_samples = 50000
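The analytic moments here are the standard Beta moments: for X ~ Beta(α, β),

\[ \mathbb{E}[X] = \frac{\alpha}{\alpha+\beta}, \qquad \mathrm{Var}[X] = \frac{\alpha\beta}{(\alpha+\beta)^2(\alpha+\beta+1)} = \mathbb{E}[X]^2\,\frac{\beta}{\alpha\,(\alpha+\beta+1)}, \]

the second equality being the form computed in `analytic_var`.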
Example 22
    def mutual_info(self, params, eps=1e-9):
        """ I(z_d; x) ~ H(z_prior, z_d) + H(z_prior)

        :param params: parameters of distribution
        :param eps: tolerance
        :returns: batch_size mutual information (prop-to) tensor.
        :rtype: torch.Tensor

        """
        z_true = PD.Beta(params['beta']['conc1'], params['beta']['conc2'])
        z_match = PD.Beta(params['q_z_given_xhat']['beta']['conc1'],
                          params['q_z_given_xhat']['beta']['conc2'])
        kl_proxy_to_xent = torch.sum(D.kl_divergence(z_match, z_true), dim=-1)
        return self.config['continuous_mut_info'] * kl_proxy_to_xent
Example 23
def test_init():
    total = 10
    counts = dist.Binomial(total, 0.3).sample()
    concentration1 = torch.tensor(0.5)
    concentration0 = torch.tensor(1.5)

    prior = dist.Beta(concentration1, concentration0)
    likelihood = dist.Beta(1 + counts, 1 + total - counts)

    def model():
        x = pyro.sample("x", prior)
        pyro.sample("counts", dist.Binomial(total, x), obs=counts)
        return x

    check_init_reparam(model, ConjugateReparam(likelihood))
Example 24
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
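Since the discrete states are enumerated in parallel, a model of this shape is typically trained with `TraceEnum_ELBO` and a guide covering only the global parameters. A minimal sketch under that assumption, mirroring the pattern of Pyro's HMM example:

import pyro
import pyro.poutine as poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

guide = AutoDelta(poutine.block(model_0, expose=["probs_x", "probs_y"]))
svi = SVI(model_0, guide, Adam({"lr": 0.05}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
for step in range(100):
    svi.step(sequences, lengths, args)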
Example 25
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
Example 26
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
Example 27
 def __init__(self, ode_op, ode_model):
     super(ProteinGenModel, self).__init__()
     self._ode_op = ode_op
     self._ode_model = ode_model
     self.ode_params = PyroSample(
         dist.Beta(torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
                   2.0).to_event(1))
Example 28
def _beta_bernoulli(data):
    alpha = torch.tensor([1.1, 1.1])
    beta = torch.tensor([1.1, 1.1])
    p_latent = pyro.sample('p_latent', dist.Beta(alpha, beta))
    with pyro.plate('data', data.shape[0], dim=-2):
        pyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
    return p_latent
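As in Examples 1 and 2, the exact posterior for `p_latent` is available by conjugacy: with N Bernoulli observations per column summing to s,

\[ p \mid x_{1:N} \;\sim\; \mathrm{Beta}(\alpha + s,\; \beta + N - s), \]

a handy reference when validating inference over `_beta_bernoulli`.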
Example 29
def negative_binomial_guide(
    num_sites, num_days, num_predictors, predictors, data=None
):
    # Parameters for p
    alpha_0 = pyro.param(
        "alpha_0", 5 * torch.ones(num_sites), constraint=positive
    )
    alpha_1 = pyro.param(
        "alpha_1", 5 * torch.ones(num_sites), constraint=positive
    )

    means_epsilon = pyro.param("means_epsilon", torch.zeros(num_sites))
    means_betas = pyro.param("means_betas", torch.zeros(num_predictors))
    variance_epsilon = pyro.param(
        "variance_epsilon", torch.ones(num_sites), constraint=positive
    )
    variance_betas = pyro.param(
        "variance_beta", torch.ones(num_predictors), constraint=positive
    )

    with plate("beta_plates", num_predictors):
        pyro.sample("betas", dist.Normal(means_betas, variance_betas))

    with plate("sites", size=num_sites):
        pyro.sample("epsilon", dist.Normal(means_epsilon, variance_epsilon))
        pyro.sample("p", dist.Beta(alpha_0, alpha_1))
Example 30
 def model(data):
     y_prob = pyro.sample("y_prob", dist.Beta(1.0, 1.0))
     y = pyro.sample("y", dist.Bernoulli(y_prob))
     with pyro.plate("data", data.shape[0]):
         z = pyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
         pyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=data)
     pyro.sample("nuisance", dist.Bernoulli(0.3))