示例#1
0
def test_batch_log_pdf_mask(dist):
    """``log_pdf_mask`` must scale ``batch_log_pdf`` elementwise.

    For each test datum: a mask of ones leaves the result unchanged, a
    single-element zero mask yields zeros of the data's batch shape, and a
    single-element 0.5 mask halves the unmasked value. Skipped for every
    distribution other than Normal, Bernoulli and Categorical.
    """
    supported = ('Normal', 'Bernoulli', 'Categorical')
    if dist.get_test_distribution_name() not in supported:
        pytest.skip('Batch pdf masking not supported for the distribution.')
    pyro_dist = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        params = dist.get_dist_params(idx)
        data = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            param_batch_shape = pyro_dist.batch_shape(**params) + (1,)
            data_batch_shape = pyro_dist.batch_shape(data, **params) + (1,)
            # masks are expected to be broadcast to the data dims
            zero_mask = ng_zeros(1)
            one_mask = ng_ones(param_batch_shape)
            half_mask = ng_ones(1) * 0.5

            def log_pdf(mask=None):
                if mask is None:
                    return pyro_dist.batch_log_pdf(data, **params)
                return pyro_dist.batch_log_pdf(data, log_pdf_mask=mask, **params)

            unmasked = log_pdf()
            zero_masked = log_pdf(zero_mask)
            one_masked = log_pdf(one_mask)
            half_masked = log_pdf(half_mask)
            assert_equal(one_masked, unmasked)
            assert_equal(zero_masked, ng_zeros(data_batch_shape))
            assert_equal(half_masked, 0.5 * unmasked)
示例#2
0
 def model():
     """Named-object model: a param under a list entry, a sample and an observe under dicts.

     ``latent.list[0].mu`` is a parameter initialised to zeros; ``latent.dict["foo"].foo``
     is sampled with ``dist.normal(mu, 1)``, and ``foo`` is then observed at
     ``latent.object.bar`` under ``dist.normal(mu, 1)``.
     """
     root = named.Object("latent")
     root.list = named.List()
     first_entry = root.list.add()
     mu = first_entry.mu.param_(Variable(torch.zeros(1)))
     root.dict = named.Dict()
     foo_site = root.dict["foo"].foo
     foo = foo_site.sample_(dist.normal, mu, ng_ones(1))
     bar_site = root.object.bar
     bar_site.observe_(dist.normal, foo, mu, ng_ones(1))
示例#3
0
 def model():
     """Normal prior on ``mu_latent`` with two observations through an ``AffineExp`` bijector.

     ``mu_latent`` ~ Normal(mu0, tau0 ** -0.5); ``obs0``/``obs1`` observe
     ``self.data[0]``/``self.data[1]`` under a standard normal pushed through
     ``AffineExp(tau ** -0.5, mu_latent)``. Returns ``mu_latent``.
     """
     prior_sigma = torch.pow(self.tau0, -0.5)
     mu_latent = pyro.sample("mu_latent", dist.normal, self.mu0, prior_sigma)
     transform = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
     x_dist = TransformedDistribution(dist.normal, transform)
     for idx in range(2):
         pyro.observe("obs%d" % idx, x_dist, self.data[idx], ng_zeros(1), ng_ones(1))
     return mu_latent
示例#4
0
 def model():
     """Sample ``mu_latent`` and score two data points through an affine-exp transform.

     Returns the sampled ``mu_latent``.
     """
     latent_scale = torch.pow(self.tau0, -0.5)
     mu_latent = pyro.sample("mu_latent", dist.normal, self.mu0, latent_scale)
     obs_scale = torch.pow(self.tau, -0.5)
     x_dist = TransformedDistribution(dist.normal, AffineExp(obs_scale, mu_latent))
     for position, site in enumerate(("obs0", "obs1")):
         pyro.observe(site, x_dist, self.data[position], ng_zeros(1), ng_ones(1))
     return mu_latent
示例#5
0
 def obs_inner(i, _i, _x):
     """Observe ``_x`` under Normal(mu_latent, lam ** -0.5), surrounded by superfluous latents.

     ``n_superfluous_top`` latents ``z_{i}_{k}`` are sampled before the
     observation and ``n_superfluous_bottom`` after it. Each is a
     non-reparameterized standard Normal of shape ``(4 - i, 1)`` that the
     observation does not use.
     """
     def sample_superfluous(k):
         prior = dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1), reparameterized=False)
         pyro.sample("z_%d_%d" % (i, k), prior)

     for k in range(n_superfluous_top):
         sample_superfluous(k)
     pyro.observe("obs_%d" % i, dist.normal, _x, mu_latent, torch.pow(self.lam, -0.5))
     for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
         sample_superfluous(k)
示例#6
0
        def model(batch_size_outer=2, batch_size_inner=2):
            """Nested ``map_data`` model over a 2x2 grid of ones.

            Samples ``mu_latent`` ~ Normal(0, 1). Then, for element ``_i`` of
            outer row ``i``, samples ``z_{i}_{_i}`` ~ Normal(mu_latent + _x, 1).
            Returns ``mu_latent``.
            """
            mu_latent = pyro.sample("mu_latent", dist.normal, ng_zeros(1), ng_ones(1))

            def inner(i, _i, _x):
                pyro.sample("z_%d_%d" % (i, _i), dist.normal, mu_latent + _x, ng_ones(1))

            def outer(i, x):
                pyro.map_data("map_inner_%d" % i, x,
                              lambda _i, _x: inner(i, _i, _x),
                              batch_size=batch_size_inner)

            # list multiplication: every cell holds the same ones Variable
            grid = [[ng_ones(1)] * 2] * 2
            pyro.map_data("map_outer", grid,
                          lambda i, x: outer(i, x),
                          batch_size=batch_size_outer)
            return mu_latent
示例#7
0
        def model(batch_size_outer=2, batch_size_inner=2):
            """Two-level ``map_data`` model.

            ``mu_latent`` ~ Normal(0, 1). Then one ``z_{row}_{col}`` ~
            Normal(mu_latent + value, 1) is sampled per cell of a 2x2 grid of
            ones. Returns ``mu_latent``.
            """
            mu_latent = pyro.sample("mu_latent", dist.normal, ng_zeros(1), ng_ones(1))

            def sample_cell(row, col, value):
                site = "z_%d_%d" % (row, col)
                pyro.sample(site, dist.normal, mu_latent + value, ng_ones(1))

            def map_row(row, cells):
                pyro.map_data("map_inner_%d" % row, cells,
                              lambda col, value: sample_cell(row, col, value),
                              batch_size=batch_size_inner)

            rows = [[ng_ones(1)] * 2] * 2
            pyro.map_data("map_outer", rows,
                          lambda row, cells: map_row(row, cells),
                          batch_size=batch_size_outer)

            return mu_latent
示例#8
0
文件: mf.py 项目: makora9143/fmvae
    def model(self, x):
        """Bayesian matrix-factorisation model for binary data ``x``.

        Places standard-normal priors on the factor matrices ``U`` (shape
        ``self.U_size``) and ``V`` (shape ``self.V_size``). ``x`` is then
        observed as Bernoulli with probabilities ``sigmoid(U @ V.T)``.
        """
        u_mu = ng_zeros(self.U_size)
        u_sigma = ng_ones(self.U_size)
        U = pyro.sample('u', dist.normal, u_mu, u_sigma)

        v_mu = ng_zeros(self.V_size)
        v_sigma = ng_ones(self.V_size)
        V = pyro.sample('v', dist.normal, v_mu, v_sigma)

        # U @ V.T is an unbounded real score (it can be negative or > 1), so it
        # is squashed into (0, 1) to be a valid Bernoulli probability.
        probs = torch.sigmoid(torch.matmul(U, torch.t(V)))
        pyro.observe('x', dist.bernoulli, x, probs)
示例#9
0
 def model(self, data):
     """VAE generative model: a 20-dim N(0, I) latent per row, decoded and scored
     as Bernoulli against ``data`` viewed as 784-wide rows."""
     decoder = pyro.module('decoder', self.vae_decoder)
     latent_shape = [data.size(0), 20]
     z_mean = ng_zeros(latent_shape)
     z_std = ng_ones(latent_shape)
     z = pyro.sample('latent', Normal(z_mean, z_std))
     reconstruction = decoder.forward(z)
     pyro.sample('obs', Bernoulli(reconstruction), obs=data.view(-1, 784))
示例#10
0
    def model(self):
        """Generative model: latent ``z`` ~ N(0, I) decoded into an adjacency score map.

        With ``self.subsampling`` the flattened labels are scored with
        ``dist.bernoulli`` on the indices yielded by ``pyro.iarange``
        (subsample from ``self.sample()``). Otherwise all entries are scored
        with ``weighted_bernoulli`` using ``weight=self.pos_weight``.
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        latent_shape = [self.n_samples, self.n_latent]
        z = pyro.sample("latent", dist.normal,
                        ng_zeros(latent_shape), ng_ones(latent_shape))
        flat_adj = self.decoder(z).view(1, -1)
        flat_labels = self.adj_labels.view(1, -1)

        if not self.subsampling:
            # Reweighting: score every entry.
            with pyro.iarange("data"):
                pyro.observe('obs', weighted_bernoulli, flat_labels, flat_adj,
                             weight=self.pos_weight)
            return

        # Subsampling: score only the selected entries.
        with pyro.iarange("data", self.n_subsample,
                          subsample=self.sample()) as ind:
            pyro.observe('obs', dist.bernoulli, flat_labels[0][ind], flat_adj[0][ind])
示例#11
0
    def test_do_propagation(self):
        """An intervention on ``z`` via ``poutine.do`` must change ``flip`` downstream."""
        pyro.clear_param_store()

        def model():
            z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
            exp_z = torch.exp(z)
            latent_prob = exp_z / (exp_z + ng_ones(1))
            return pyro.sample("flip", Bernoulli(latent_prob))

        sample_from_model = model()
        # z ~ 10 makes flip = 1 with high probability; forcing z = -10 through
        # an indirect DO intervention should make flip = 0 instead.
        intervention = {"z": -10.0 * ng_ones(1)}
        do_model = poutine.do(model, data=intervention)
        sample_from_do_model = poutine.trace(do_model)()

        assert eq(sample_from_model, ng_ones(1))
        assert eq(sample_from_do_model, ng_zeros(1))
示例#12
0
def test_dirichlet_shape():
    """Dirichlet over a (3, 2) parameter: batch (3,), event (2,), samples (3, 2)."""
    d = dist.Dirichlet(ng_ones(3, 2) / 2)
    full_shape = (3, 2)
    assert d.batch_shape() == full_shape[:1]
    assert d.event_shape() == full_shape[1:]
    assert d.shape() == full_shape
    assert d.sample().size() == d.shape()
示例#13
0
        def guide():
            """Guide with non-reparameterized Normals and NN baselines on every site.

            ``mu_latent`` uses learnable ``mu_q``/``log_sig_q`` initialised
            slightly off the analytic posterior, with a baseline from
            ``pt_mu_baseline``. Each superfluous ``z_{i}_{k}`` gets a learnable
            mean and its own baseline module, which is fed ``mu_latent.detach()``.
            Returns ``mu_latent``.
            """
            mu_init = self.analytic_mu_n.data + 0.094 * torch.ones(2)
            mu_q = pyro.param("mu_q", Variable(mu_init, requires_grad=True))
            log_sig_init = self.analytic_log_sig_n.data - 0.11 * torch.ones(2)
            log_sig_q = pyro.param("log_sig_q", Variable(log_sig_init, requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            mu_baseline = pyro.module("mu_baseline", pt_mu_baseline, tags="baseline")
            mu_baseline_value = mu_baseline(ng_ones(1))
            mu_latent = pyro.sample("mu_latent",
                                    dist.Normal(mu_q, sig_q, reparameterized=False),
                                    baseline=dict(baseline_value=mu_baseline_value))

            def obs_inner(i, _i, _x):
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    suffix = "%d_%d" % (i, k)
                    z_baseline = pyro.module("z_baseline_" + suffix,
                                             pt_superfluous_baselines[3 * k + i], tags="baseline")
                    # detached so the baseline does not backpropagate into mu_latent
                    z_baseline_value = z_baseline(mu_latent.detach()).unsqueeze(-1)
                    mean_init = Variable(0.5 * torch.ones(4 - i, 1), requires_grad=True)
                    mean_i = pyro.param("mean_" + suffix, mean_init)
                    z_dist = dist.Normal(mean_i, ng_ones(4 - i, 1), reparameterized=False)
                    pyro.sample("z_" + suffix, z_dist,
                                baseline=dict(baseline_value=z_baseline_value))

            def obs_outer(i, x):
                pyro.map_data("map_obs_inner_%d" % i, x,
                              lambda _i, _x: obs_inner(i, _i, _x), batch_size=4 - i)

            chunks = [self.data_tensor[0:4, :],
                      self.data_tensor[4:7, :],
                      self.data_tensor[7:9, :]]
            pyro.map_data("map_obs_outer", chunks,
                          lambda i, x: obs_outer(i, x), batch_size=3)

            return mu_latent
示例#14
0
def test_categorical_shape():
    """Categorical over (3, 2) probabilities: batch (3,), event (1,), shape (3, 1)."""
    probs = ng_ones(3, 2) / 2
    categorical = dist.Categorical(probs)
    batch, event = (3,), (1,)
    assert categorical.batch_shape() == batch
    assert categorical.event_shape() == event
    assert categorical.shape() == batch + event
    assert categorical.sample().size() == categorical.shape()
示例#15
0
def test_dirichlet_shape():
    """Shape bookkeeping of a batch of three 2-dim Dirichlets."""
    concentration = ng_ones(3, 2) / 2
    dirichlet = dist.Dirichlet(concentration)
    batch_shape, event_shape = (3,), (2,)
    assert dirichlet.batch_shape() == batch_shape
    assert dirichlet.event_shape() == event_shape
    assert dirichlet.shape() == batch_shape + event_shape
    assert dirichlet.sample().size() == dirichlet.shape()
示例#16
0
    def test_do_propagation(self):
        """Check that a ``poutine.do`` intervention propagates to dependent sites.

        ``z`` is pinned near 10, so ``flip`` ~ Bernoulli(sigmoid(z)) is almost
        surely 1. Fixing ``z = -10`` should drive ``flip`` to 0.
        """
        pyro.clear_param_store()

        def model():
            mean, scale = 10.0 * ng_ones(1), 0.0001 * ng_ones(1)
            z = pyro.sample("z", Normal(mean, scale))
            prob = torch.exp(z) / (torch.exp(z) + ng_ones(1))
            flip = pyro.sample("flip", Bernoulli(prob))
            return flip

        observed_flip = model()
        forced = {"z": -10.0 * ng_ones(1)}
        do_flip = poutine.trace(poutine.do(model, data=forced))()

        assert eq(observed_flip, ng_ones(1))
        assert eq(do_flip, ng_zeros(1))
示例#17
0
    def model(self, data):
        """Generative model: sample ``z`` from the selected prior and decode it.

        ``self.prior == 0`` uses a standard normal prior (moved to CUDA when
        ``self.cuda_mode``). ``self.prior == 1`` averages the VampPrior
        component means and uses the root-mean-square of their sigmas. The
        decoder receives ``z`` plus the binary-feature columns, the
        continuous-feature columns and the last two columns of ``data``.

        Raises:
            ValueError: if ``self.prior`` is neither 0 nor 1.
        """
        decoder = pyro.module('decoder', self.decoder)
        batch_size = data.size(0)

        if self.prior == 0:
            # Normal prior
            z_mu = ng_zeros([batch_size, self.z_dim])
            z_sigma = ng_ones([batch_size, self.z_dim])
            if self.cuda_mode:
                z_mu, z_sigma = z_mu.cuda(), z_sigma.cuda()
        elif self.prior == 1:
            vamp_mu, vamp_sigma = self.vampprior()
            z_mu_avg = torch.mean(vamp_mu, 0)
            z_sigma_avg = torch.sqrt(torch.mean(vamp_sigma * vamp_sigma, 0))
            z_mu = z_mu_avg.expand(batch_size, z_mu_avg.size(0))
            z_sigma = z_sigma_avg.expand(batch_size, z_sigma_avg.size(0))
        else:
            # Without this branch ``z`` would be unbound and the call below
            # would fail with an opaque UnboundLocalError.
            raise ValueError(
                "unsupported prior {!r}; expected 0 or 1".format(self.prior))

        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)

        n_bin = len(self.dataset.binfeats)
        n_cont = len(self.dataset.contfeats)
        # NOTE(review): the decoded outputs are not observed within this block;
        # confirm the likelihood terms are not missing.
        x1, x2, t, y = decoder.forward(
            z, data[:, :n_bin], data[:, n_bin:n_bin + n_cont],
            data[:, -2], data[:, -1])
示例#18
0
def test_one_hot_categorical_shape():
    """One-hot Categorical over (3, 2) probs: batch (3,), event (2,), shape (3, 2)."""
    d = dist.OneHotCategorical(ng_ones(3, 2) / 2)
    batch_shape, event_shape = (3,), (2,)
    assert d.batch_shape() == batch_shape
    assert d.event_shape() == event_shape
    assert d.shape() == batch_shape + event_shape
    assert d.sample().size() == d.shape()
示例#19
0
def test_score_errors_event_dim_mismatch(dist):
    """batch_log_pdf must raise ValueError for data with an extra trailing dimension."""
    pyro_dist = dist.pyro_dist
    for idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(idx)
        bad_shape = pyro_dist.shape(**params) + (1,)
        bad_data = ng_ones(bad_shape)
        with pytest.raises(ValueError):
            pyro_dist.batch_log_pdf(bad_data, **params)
示例#20
0
def test_score_errors_event_dim_mismatch(dist):
    """Data carrying one more dimension than ``shape()`` must be rejected."""
    d = dist.pyro_dist
    indices = dist.get_batch_data_indices()
    for idx in indices:
        kwargs = dist.get_dist_params(idx)
        extra_dim_data = ng_ones(d.shape(**kwargs) + (1,))
        with pytest.raises(ValueError):
            d.batch_log_pdf(extra_dim_data, **kwargs)
示例#21
0
def main(args):
    """Fit ``model``/``guide`` to a small nested JSON-like dataset with SVI.

    Takes one SVI step per epoch for ``args.num_epochs`` epochs, printing the
    loss every 100 steps, then prints every learned parameter.
    """
    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss="ELBO")

    # Data is an arbitrary json-like structure with tensors at leaves.
    one = ng_ones(1)
    data = {
        "foo": one,
        "bar": [0 * one, 1 * one, 2 * one],
        "baz": {
            "noun": {
                "concrete": 4 * one,
                "abstract": 6 * one,
            },
            "verb": 2 * one,
        },
    }

    print('Step\tLoss')
    for step in range(args.num_epochs):
        # Train on every epoch; only the logging is throttled to every 100th
        # step (the step used to be inside the check, so 99% of epochs did nothing).
        loss = inference.step(data)
        if step % 100 == 0:
            print('{}\t{:0.5g}'.format(step, loss))

    print('Parameters:')
    for name in sorted(pyro.get_param_store().get_all_param_names()):
        print('{} = {}'.format(name, pyro.param(name).data.cpu().numpy()))
示例#22
0
def test_normal_shape():
    """Normal with (3, 2) parameters: batch (3,), event (2,), samples (3, 2)."""
    loc, scale = ng_zeros(3, 2), ng_ones(3, 2)
    normal = dist.Normal(loc, scale)
    expected = (3, 2)
    assert normal.batch_shape() == expected[:1]
    assert normal.event_shape() == expected[1:]
    assert normal.shape() == expected
    assert normal.sample().size() == normal.shape()
示例#23
0
def test_normal_shape():
    """Shape bookkeeping of a (3, 2)-parameterised Normal."""
    d = dist.Normal(ng_zeros(3, 2), ng_ones(3, 2))
    batch_shape, event_shape = (3,), (2,)
    assert d.batch_shape() == batch_shape
    assert d.event_shape() == event_shape
    assert d.shape() == batch_shape + event_shape
    assert d.sample().size() == d.shape()
示例#24
0
    def model(self, input_variable, target_variable, step):
        """Seq2seq generative model driven by a Gaussian latent ``z``.

        ``z`` (shape ``[self.num_layers, self.z_dim]``) initialises the RNN
        hidden state through ``self.decoder_dense``. ``self.decoder_rnn`` is then
        run with teacher forcing over ``target_variable``; each target step is
        observed under ``dist.bernoulli`` parameterised by the decoder output,
        and its argmax token is recorded. Every 10th ``step`` the decoded offer,
        answer and RNN output are printed.
        """
        # register PyTorch modules with Pyro
        pyro.module("decoder_dense", self.decoder_dense)
        pyro.module("decoder_rnn", self.decoder_rnn)

        def to_numpy(var):
            # copy to host first when running on the GPU
            if self.use_cuda:
                return var.cpu().data.numpy()
            return var.data.numpy()

        # prior p(z); type_as gives CUDA tensors if target_variable is on gpu.
        # (The value will be sampled by the guide when computing the ELBO.)
        prior_shape = [self.num_layers, self.z_dim]
        z_mu = ng_zeros(prior_shape, type_as=target_variable.data)
        z_sigma = ng_ones(prior_shape, type_as=target_variable.data)
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)

        target_length = target_variable.shape[0]

        # NOTE(review): module-level ``dataset`` and ``USE_CUDA`` are used here,
        # while other lines use ``self.dataset`` / ``self.use_cuda``; confirm
        # this is intentional.
        decoder_input = dataset.to_onehot([[self.dataset.SOS_index]])
        if USE_CUDA:
            decoder_input = decoder_input.cuda()

        decoder_outputs = np.ones(target_length)
        decoder_hidden = self.decoder_dense(z)

        # Teacher forcing: the next input is always the ground-truth token.
        for di in range(target_length):
            decoder_output, decoder_hidden = self.decoder_rnn(decoder_input, decoder_hidden)
            decoder_input = target_variable[di]
            decoder_outputs[di] = np.argmax(to_numpy(decoder_output))
            pyro.observe("obs_{}".format(di), dist.bernoulli, target_variable[di], decoder_output[0])

        offer = np.argmax(to_numpy(input_variable), axis=1).astype(int)
        answer = np.argmax(to_numpy(target_variable), axis=1).astype(int)
        rnn_response = [int(token) for token in decoder_outputs]

        if step % 10 == 0:
            print("---------------------------")
            print("Offer: ", dataset.to_phrase(offer))
            print("Answer:", self.dataset.to_phrase(answer))
            print("RNN:", self.dataset.to_phrase(rnn_response))
示例#25
0
def test_categorical_batch_log_pdf_shape(one_hot):
    """batch_log_pdf of a (3, 2)-batched 4-way Categorical has shape (3, 2, 1)."""
    ps = ng_ones(3, 2, 4) / 4
    event_size = 4 if one_hot else 1
    x = ng_zeros(3, 2, event_size)
    if one_hot:
        # every datum is the one-hot vector for category 0
        x[:, :, 0] = 1
    d = dist.Categorical(ps, one_hot=one_hot)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
示例#26
0
def test_score_errors_non_broadcastable_data_shape(dist):
    """batch_log_pdf must raise ValueError when the leading dim cannot broadcast."""
    pyro_dist = dist.pyro_dist
    for idx in dist.get_batch_data_indices():
        params = dist.get_dist_params(idx)
        expected_shape = pyro_dist.shape(**params)
        # grow the leading dimension by one so it no longer matches
        bad_shape = (expected_shape[0] + 1,) + expected_shape[1:]
        bad_data = ng_ones(bad_shape)
        with pytest.raises(ValueError):
            pyro_dist.batch_log_pdf(bad_data, **params)
示例#27
0
def test_categorical_batch_log_pdf_shape(one_hot):
    """Categorical batch_log_pdf collapses the event dim to 1 in both encodings."""
    probs = ng_ones(3, 2, 4) / 4
    if not one_hot:
        values = ng_zeros(3, 2, 1)
    else:
        values = ng_zeros(3, 2, 4)
        values[:, :, 0] = 1
    categorical = dist.Categorical(probs, one_hot=one_hot)
    log_pdf = categorical.batch_log_pdf(values)
    assert log_pdf.size() == (3, 2, 1)
示例#28
0
def test_score_errors_non_broadcastable_data_shape(dist):
    """Data whose first dimension is one larger than ``shape()`` must be rejected."""
    d = dist.pyro_dist
    indices = dist.get_batch_data_indices()
    for idx in indices:
        kwargs = dist.get_dist_params(idx)
        shape = d.shape(**kwargs)
        leading = shape[0] + 1
        mismatched = ng_ones((leading,) + shape[1:])
        with pytest.raises(ValueError):
            d.batch_log_pdf(mismatched, **kwargs)
示例#29
0
 def obs_inner(i, _i, _x):
     """Sample every superfluous latent ``z_{i}_{k}`` with a learnable mean and an NN baseline."""
     total = n_superfluous_top + n_superfluous_bottom
     for k in range(total):
         baseline_net = pyro.module("z_baseline_%d_%d" % (i, k),
                                    pt_superfluous_baselines[3 * k + i], tags="baseline")
         # detached so the baseline does not backpropagate into mu_latent
         b_value = baseline_net(mu_latent.detach()).unsqueeze(-1)
         init_mean = Variable(0.5 * torch.ones(4 - i, 1), requires_grad=True)
         mean_i = pyro.param("mean_%d_%d" % (i, k), init_mean)
         z_dist = dist.Normal(mean_i, ng_ones(4 - i, 1), reparameterized=False)
         pyro.sample("z_%d_%d" % (i, k), z_dist, baseline=dict(baseline_value=b_value))
示例#30
0
文件: vae.py 项目: makora9143/fmvae
    def model(self, x):
        """VAE model: ``z`` ~ N(0, I) of width ``self.z_dim``, decoded and scored
        with ``dist.bernoulli`` against ``x`` viewed as 784-wide rows."""
        pyro.module('decoder', self.decoder)
        latent_shape = [x.size(0), self.z_dim]
        z = pyro.sample('latent', dist.normal,
                        ng_zeros(latent_shape, type_as=x.data),
                        ng_ones(latent_shape, type_as=x.data))
        pyro.sample('obs', dist.bernoulli, self.decoder(z), obs=x.view(-1, 784))
示例#31
0
    def model(self, x, y):
        """Generative model whose latent mean is a product of factor rows.

        ``y`` is unpacked into ``(i, y)``, which index rows of the factor
        matrices ``U`` (``[self.u_dim, self.z_dim]``) and ``V``
        (``[self.v_dim, self.z_dim]``). Both matrices have standard-normal
        priors. The latent ``z`` ~ N(U[i] * V[y], 1) is decoded and scored
        against ``x`` viewed as 784-wide rows.
        """
        pyro.module('decoder', self.decoder)

        i, y = y

        # Every prior tensor follows x's tensor type (e.g. CUDA) so that U, V and
        # z_sigma share a device. Previously only z_sigma did, which mixes
        # devices when x is on the GPU.
        U_mu = ng_zeros([self.u_dim, self.z_dim], type_as=x.data)
        U_sigma = ng_ones([self.u_dim, self.z_dim], type_as=x.data)
        U = pyro.sample('U', dist.normal, U_mu, U_sigma)

        V_mu = ng_zeros([self.v_dim, self.z_dim], type_as=x.data)
        V_sigma = ng_ones([self.v_dim, self.z_dim], type_as=x.data)
        V = pyro.sample('V', dist.normal, V_mu, V_sigma)

        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)

        # elementwise (not matrix) product of the selected factor rows
        z = pyro.sample('latent', dist.normal, U[i, :] * V[y, :], z_sigma)

        img_mu = self.decoder(z)
        pyro.sample('obs', dist.bernoulli, img_mu, obs=x.view(-1, 784))
示例#32
0
def test_batch_log_pdf_mask(dist):
    """Masking ``batch_log_pdf`` by 1, 0 and 0.5 must scale it accordingly.

    Only Normal, Bernoulli and Categorical support masks; other
    distributions are skipped.
    """
    if dist.get_test_distribution_name() not in ('Normal', 'Bernoulli', 'Categorical'):
        pytest.skip('Batch pdf masking not supported for the distribution.')
    d = dist.pyro_dist
    n_data = dist.get_num_test_data()
    for idx in range(n_data):
        kwargs = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            mask_shape = d.batch_shape(**kwargs) + (1,)
            result_shape = d.batch_shape(x, **kwargs) + (1,)
            # masks should be broadcast to the data dims
            masks = (ng_zeros(1), ng_ones(mask_shape), ng_ones(1) * 0.5)
            log_pdf = d.batch_log_pdf(x, **kwargs)
            zeros_result, ones_result, half_result = [
                d.batch_log_pdf(x, log_pdf_mask=mask, **kwargs) for mask in masks
            ]
            assert_equal(ones_result, log_pdf)
            assert_equal(zeros_result, ng_zeros(result_shape))
            assert_equal(half_result, 0.5 * log_pdf)
示例#33
0
def test_categorical_shape(one_hot):
    """Categorical over (3, 2) probs in one-hot and index encodings.

    Batch and event shapes are asserted identical in both modes; only
    ``shape()`` differs: (3, 2) one-hot vs (3, 1) index.
    """
    probs = ng_ones(3, 2) / 2
    categorical = dist.Categorical(probs, one_hot=one_hot)
    assert categorical.batch_shape() == (3, )
    assert categorical.event_shape() == (2, )
    if one_hot:
        assert categorical.shape() == (3, 2)
    else:
        assert categorical.shape() == (3, 1)
    assert categorical.sample().size() == categorical.shape()
示例#34
0
def test_categorical_shape(one_hot):
    """Shape bookkeeping of a (3, 2) Categorical, parametrised by encoding."""
    d = dist.Categorical(ng_ones(3, 2) / 2, one_hot=one_hot)
    expected_full_shape = (3, 2) if one_hot else (3, 1)
    assert d.batch_shape() == (3,)
    assert d.event_shape() == (2,)
    assert d.shape() == expected_full_shape
    assert d.sample().size() == d.shape()
示例#35
0
文件: air.py 项目: Magica-Chen/pyro
def z_where_inv(z_where):
    """Invert a batch of ``[s, x, y]`` spatial-transform parameters.

    Each row ``[s, x, y]`` maps to ``[1/s, -x/s, -y/s]``. These are the
    parameters of the inverse of the spatial transform performed in the
    generative model.
    """
    batch_size = z_where.size(0)
    ones_col = ng_ones([1, 1]).type_as(z_where).expand(batch_size, 1)
    unscaled = torch.cat((ones_col, -z_where[:, 1:]), 1)
    scale = z_where[:, 0:1]
    return unscaled / scale
示例#36
0
def z_where_inv(z_where):
    """Return per-row inverse transform parameters ``[1/s, -x/s, -y/s]`` for ``[s, x, y]``."""
    n_rows = z_where.size(0)
    leading_ones = ng_ones([1, 1]).type_as(z_where).expand(n_rows, 1)
    negated_offsets = -z_where[:, 1:]
    stacked = torch.cat((leading_ones, negated_offsets), 1)
    # every entry is divided by that row's scale s
    return stacked / z_where[:, 0:1]
示例#37
0
 def model(self, x):
     """VAE generative model: N(0, I) latent, decoded image means scored as Bernoulli."""
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     # prior p(z); type_as gives cuda Tensors if x is on gpu
     prior_shape = [x.size(0), self.z_dim]
     z_mu = ng_zeros(prior_shape, type_as=x.data)
     z_sigma = ng_ones(prior_shape, type_as=x.data)
     # the guide supplies this value when computing the ELBO
     z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
     mu_img = self.decoder.forward(z)
     # score against the actual images, flattened to 784 pixels
     pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)
示例#38
0
文件: vae.py 项目: Magica-Chen/pyro
 def model(self, x):
     """Score flattened 784-pixel images ``x`` under a decoded N(0, I) latent."""
     pyro.module("decoder", self.decoder)
     batch_size = x.size(0)
     # type_as keeps the prior on x's tensor type (cuda if x is on gpu)
     z = pyro.sample("latent", dist.normal,
                     ng_zeros([batch_size, self.z_dim], type_as=x.data),
                     ng_ones([batch_size, self.z_dim], type_as=x.data))
     decoded = self.decoder.forward(z)
     flat_images = x.view(-1, 784)
     pyro.observe("obs", dist.bernoulli, flat_images, decoded)
示例#39
0
    def model(self, x):
        """VAE generative model over frames of ``self.frame`` Bernoulli pixels."""
        # register decoder with pyro
        pyro.module("decoder", self.decoder)

        prior_shape = [x.size(0), self.z_dim]
        z_mu = ng_zeros(prior_shape, type_as=x.data)
        z_sigma = ng_ones(prior_shape, type_as=x.data)

        # The guide provides z during ELBO computation; here it comes from the prior.
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
        frame_probs = self.decoder.forward(z)
        observed = x.view(-1, self.frame)
        pyro.observe("obs", dist.bernoulli, observed, frame_probs)
示例#40
0
 def model(self, x):
     """Generative model stub that is intentionally disabled.

     Always raises ``NotImplementedError``. The code after the ``raise`` is an
     unreachable sketch of the intended model. It is kept, with its latent name
     errors fixed, for whoever re-enables it.

     Raises:
         NotImplementedError: always.
     """
     raise NotImplementedError("don't use it plzz")
     # NOTE(review): registers ``self.decoders`` but calls ``self.decoder``
     # below; confirm which attribute is intended before re-enabling.
     pyro.module("decoder", self.decoders)
     # prior parameters
     z_mu = ng_zeros([x.size(0), self.platent["dim"]], type_as=x.data)
     z_sigma = ng_ones([x.size(0), self.platent["dim"]], type_as=x.data)
     z = pyro.sample("latent", self.platent["dist"], z_mu, z_sigma)
     decoder_out = self.decoder.forward(z)
     if not isinstance(decoder_out, tuple):
         decoder_out = (decoder_out,)
     # was the undefined bare name ``pinput``
     if self.pinput["dist"] == dist.normal:
         # exponentiate the second decoder output (presumably a log-scale) for the normal
         decoder_out = list(decoder_out)
         decoder_out[1] = torch.exp(decoder_out[1])
         decoder_out = tuple(decoder_out)

     pyro.sample("obs", self.pinput["dist"], *decoder_out, obs=x.view(-1, self.pinput["dim"]))
示例#41
0
def model_recurse(data, latent):
    """Recursively mirror a JSON-like ``data`` tree with named latent variables.

    ``Variable`` leaves are observed under Normal(latent.z, 1). Lists and
    dicts register a learnable ``prior_sigma`` and sample one child ``z`` per
    element from Normal(latent.z, prior_sigma) before recursing into it.

    Raises:
        TypeError: for any other node type.
    """
    if isinstance(data, Variable):
        latent.x.observe_(dist.normal, data, latent.z, ng_ones(1))
        return
    if isinstance(data, list):
        latent.prior_sigma.param_(Variable(torch.ones(1), requires_grad=True))
        latent.list = named.List()
        for child_data in data:
            child = latent.list.add()
            child.z.sample_(dist.normal, latent.z, latent.prior_sigma)
            model_recurse(child_data, child)
        return
    if isinstance(data, dict):
        latent.prior_sigma.param_(Variable(torch.ones(1), requires_grad=True))
        latent.dict = named.Dict()
        for key, child_data in data.items():
            child = latent.dict[key]
            child.z.sample_(dist.normal, latent.z, latent.prior_sigma)
            model_recurse(child_data, child)
        return
    raise TypeError("Unsupported type {}".format(type(data)))
示例#42
0
def test_subsample_gradient(trace_graph, reparameterized):
    """Subsampled ELBO gradients must match full-batch gradients (Monte Carlo tolerance).

    Gradients are computed once with ``subsample_size == data_size`` (the
    reference), the grads are zeroed, and then they are recomputed with
    ``subsample_size == 1``. The two sets are compared with ``prec=0.333``.
    """
    pyro.clear_param_store()
    data_size = 2
    subsample_size = 1
    num_particles = 1000
    precision = 0.333
    data = dist.normal(ng_zeros(data_size), ng_ones(data_size))

    def model(subsample_size):
        with pyro.iarange("data", len(data), subsample_size) as ind:
            x = data[ind]
            n = len(x)
            z_prior = dist.Normal(ng_zeros(n), ng_ones(n), reparameterized=reparameterized)
            z = pyro.sample("z", z_prior)
            pyro.observe("x", dist.Normal(z, ng_ones(n), reparameterized=reparameterized), x)

    def guide(subsample_size):
        mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
        sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
        with pyro.iarange("data", len(data), subsample_size) as ind:
            local_mu = mu[ind]
            local_sigma = sigma.expand(subsample_size)
            pyro.sample("z", dist.Normal(local_mu, local_sigma, reparameterized=reparameterized))

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss="ELBO",
                    trace_graph=trace_graph, num_particles=num_particles)

    # Reference: gradients without subsampling.
    inference.loss_and_grads(model, guide, subsample_size=data_size)
    params = dict(pyro.get_param_store().named_parameters())

    def clone_grads():
        return {name: p.grad.data.clone() for name, p in params.items()}

    expected_grads = clone_grads()
    zero_grads(params.values())

    # Gradients with subsampling.
    inference.loss_and_grads(model, guide, subsample_size=subsample_size)
    actual_grads = clone_grads()

    for name in sorted(params):
        print('\nexpected {} = {}'.format(name, expected_grads[name].cpu().numpy()))
        print('actual   {} = {}'.format(name, actual_grads[name].cpu().numpy()))
    assert_equal(actual_grads, expected_grads, prec=precision)
示例#43
0
文件: dmm.py 项目: Magica-Chen/pyro
 def forward(self, z_t_1):
     """Return ``(mu, sigma)`` parameterising the diagonal Gaussian ``p(z_t | z_{t-1})``.

     The mean blends a linear map of ``z_{t-1}`` with a nonlinear "proposed
     mean" through a learned gate. Sigma is a softplus of a linear map of the
     ReLU'd proposed mean, which keeps it positive.
     """
     # gating function g and its complement 1 - g
     gate_hidden = self.relu(self.lin_gate_z_to_hidden(z_t_1))
     gate = self.sigmoid(self.lin_gate_hidden_to_z(gate_hidden))
     complement = ng_ones(gate.size()).type_as(gate) - gate
     # nonlinear proposal for the mean
     proposal_hidden = self.relu(self.lin_proposed_mean_z_to_hidden(z_t_1))
     proposed_mean = self.lin_proposed_mean_hidden_to_z(proposal_hidden)
     # gate-weighted mix of the linear map of z_{t-1} and the proposed mean
     linear_mean = self.lin_z_to_mu(z_t_1)
     mu = complement * linear_mean + gate * proposed_mean
     sigma = self.softplus(self.lin_sig(self.relu(proposed_mean)))
     return mu, sigma
示例#44
0
 def model_dup():
     """Use the name ``mu_q`` for both a param and a sample site (presumably to hit a duplicate-name check)."""
     initial = Variable(torch.ones(1), requires_grad=True)
     pyro.param("mu_q", initial)
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
示例#45
0
def test_bernoulli_batch_log_pdf_shape():
    """batch_log_pdf of a (3, 2) Bernoulli reduces the event dim to size 1."""
    probs = ng_ones(3, 2)
    data = ng_ones(3, 2)
    bernoulli = dist.Bernoulli(probs)
    log_pdf = bernoulli.batch_log_pdf(data)
    assert log_pdf.size() == (3, 1)
示例#46
0
def test_normal_batch_log_pdf_shape():
    """A 3x2 Normal batch must report a (3, 1) batch_log_pdf."""
    loc = ng_zeros(3, 2)
    scale = ng_ones(3, 2)
    value = ng_zeros(3, 2)
    normal = dist.Normal(loc, scale)
    expected_shape = (3, 1)
    assert normal.batch_log_pdf(value).size() == expected_shape
示例#47
0
 def model(subsample_size):
     """Within a subsampled "data" iarange, draw one latent per datum and observe x around it.

     ``data`` and ``reparameterized`` come from the enclosing scope.
     """
     with pyro.iarange("data", len(data), subsample_size) as ind:
         x = data[ind]
         size = len(x)
         prior = dist.Normal(ng_zeros(size), ng_ones(size), reparameterized=reparameterized)
         z = pyro.sample("z", prior)
         likelihood = dist.Normal(z, ng_ones(size), reparameterized=reparameterized)
         pyro.observe("x", likelihood, x)
示例#48
0
 def inner(i, _i, _x):
     """Sample site ``z_<i>_<_i>`` from ``dist.normal`` centred at ``mu_latent + _x``."""
     site_name = "z_%d_%d" % (i, _i)
     center = mu_latent + _x
     pyro.sample(site_name, dist.normal, center, ng_ones(1))
示例#49
0
文件: air.py 项目: Magica-Chen/pyro
 def ng_ones(self, *args, **kwargs):
     """Build ``ng_ones(*args, **kwargs)`` and move it to the GPU when ``self.use_cuda`` is set."""
     # Inside the method body, ``ng_ones`` resolves to the module-level helper, not this method.
     ones = ng_ones(*args, **kwargs)
     return ones.cuda() if self.use_cuda else ones
示例#50
0
 def model_obs_dup():
     """Use the site name "mu_q" for both a sample and an observe statement.

     The name collision is presumably intended to trigger pyro's duplicate-site check.
     """
     site = "mu_q"
     pyro.sample(site, dist.normal, ng_zeros(1), ng_ones(1))
     pyro.observe(site, dist.normal, ng_zeros(1), ng_ones(1), ng_zeros(1))
示例#51
0
 def model():
     """Draw a single sample at site "mu_q" from ``dist.normal`` with zero mean and unit sigma."""
     loc, sigma = ng_zeros(1), ng_ones(1)
     pyro.sample("mu_q", dist.normal, loc, sigma)
示例#52
0
    def do_elbo_test(self, repa1, repa2, n_steps, prec, lr, use_nn_baseline, use_decaying_avg_baseline):
        """Fit a non-mean-field guide to the Normal-Normal-Normal model with SVI and check convergence.

        The model draws ``mu_latent_prime`` around ``self.mu0``, then ``mu_latent`` around
        ``mu_latent_prime``, and observes each entry of ``self.data`` around ``mu_latent``.
        After ``n_steps`` SVI steps (ELBO with ``trace_graph=True``), the ``param_mse`` of each
        variational parameter against its target must be within ``prec`` of zero.

        :param repa1: whether ``mu_latent_prime`` is sampled from a reparameterized Normal
        :param repa2: whether ``mu_latent`` is sampled from a reparameterized Normal
        :param n_steps: number of ``svi.step()`` calls
        :param prec: tolerance passed to ``self.assertEqual`` for every error metric
        :param lr: NOTE(review): currently unused -- the Adam learning rate is fixed at 0.0015 below
        :param use_nn_baseline: attach a small neural-network baseline to the ``mu_latent_prime`` site
        :param use_decaying_avg_baseline: forwarded to the baseline options of both guide sample sites
        """
        if self.verbose:
            print(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
            print("[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]" %
                  (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
        pyro.clear_param_store()

        if use_nn_baseline:

            class VanillaBaselineNN(nn.Module):
                """Two-layer MLP with a sigmoid hidden layer and a single output unit."""

                def __init__(self, dim_input, dim_h):
                    super(VanillaBaselineNN, self).__init__()
                    self.lin1 = nn.Linear(dim_input, dim_h)
                    self.lin2 = nn.Linear(dim_h, 1)
                    self.sigmoid = nn.Sigmoid()

                def forward(self, x):
                    h = self.sigmoid(self.lin1(x))
                    return self.lin2(h)

            # Input size 2 matches mu_latent, whose guide parameters are 2-dimensional.
            mu_prime_baseline = pyro.module("mu_prime_baseline", VanillaBaselineNN(2, 5), tags="baseline")
        else:
            mu_prime_baseline = None

        def model():
            mu_latent_prime = pyro.sample(
                    "mu_latent_prime",
                    dist.Normal(self.mu0, torch.pow(self.lam0, -0.5), reparameterized=repa1))
            mu_latent = pyro.sample(
                    "mu_latent",
                    dist.Normal(mu_latent_prime, torch.pow(self.lam0, -0.5), reparameterized=repa2))
            for i, x in enumerate(self.data):
                pyro.observe("obs_%d" % i, dist.normal, x, mu_latent,
                             torch.pow(self.lam, -0.5))
            return mu_latent

        # note that the exact posterior is not mean field!
        def guide():
            # Initial values are deliberately offset from the analytic targets.
            mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.334 * torch.ones(2),
                                               requires_grad=True))
            log_sig_q = pyro.param("log_sig_q", Variable(
                                   self.analytic_log_sig_n.data - 0.29 * torch.ones(2),
                                   requires_grad=True))
            mu_q_prime = pyro.param("mu_q_prime", Variable(torch.Tensor([-0.34, 0.52]),
                                    requires_grad=True))
            kappa_q = pyro.param("kappa_q", Variable(torch.Tensor([0.74]),
                                 requires_grad=True))
            log_sig_q_prime = pyro.param("log_sig_q_prime",
                                         Variable(-0.5 * torch.log(1.2 * self.lam0.data),
                                                  requires_grad=True))
            sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
            mu_latent_dist = dist.Normal(mu_q, sig_q, reparameterized=repa2)
            mu_latent = pyro.sample("mu_latent", mu_latent_dist,
                                    baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))
            # mu_latent_prime depends on mu_latent in the guide (kappa_q * mu_latent + mu_q_prime).
            mu_latent_prime_dist = dist.Normal(kappa_q.expand_as(mu_latent) * mu_latent + mu_q_prime,
                                               sig_q_prime,
                                               reparameterized=repa1)
            pyro.sample("mu_latent_prime",
                        mu_latent_prime_dist,
                        baseline=dict(nn_baseline=mu_prime_baseline,
                                      nn_baseline_input=mu_latent,
                                      use_decaying_avg_baseline=use_decaying_avg_baseline))

            return mu_latent

        adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

        def compute_errors():
            """Return ``param_mse`` of each variational parameter against its target value."""
            return (param_mse("mu_q", self.analytic_mu_n),
                    param_mse("log_sig_q", self.analytic_log_sig_n),
                    param_mse("mu_q_prime", 0.5 * self.mu0),
                    param_mse("kappa_q", 0.5 * ng_ones(1)),
                    param_mse("log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0)))

        for k in range(n_steps):
            svi.step()
            # The errors are only needed for periodic logging and for the final check,
            # so compute them on demand rather than after every step.
            if k % 500 == 0 and self.verbose:
                mu_error, log_sig_error, mu_prime_error, kappa_error, log_sig_prime_error = compute_errors()
                print("errors:  %.4f, %.4f" % (mu_error, log_sig_error), end='')
                print(", %.4f, %.4f" % (mu_prime_error, log_sig_prime_error), end='')
                print(", %.4f" % kappa_error)

        mu_error, log_sig_error, mu_prime_error, kappa_error, log_sig_prime_error = compute_errors()
        self.assertEqual(0.0, mu_error, prec=prec)
        self.assertEqual(0.0, log_sig_error, prec=prec)
        self.assertEqual(0.0, mu_prime_error, prec=prec)
        self.assertEqual(0.0, log_sig_prime_error, prec=prec)
        self.assertEqual(0.0, kappa_error, prec=prec)
示例#53
0
 def model():
     """Sample z from a very narrow Normal centred at 10, then flip a coin with p = exp(z) / (exp(z) + 1)."""
     z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
     exp_z = torch.exp(z)
     # logistic sigmoid of z, written out explicitly
     latent_prob = exp_z / (exp_z + ng_ones(1))
     return pyro.sample("flip", Bernoulli(latent_prob))
示例#54
0
 def model(self, data):
     """Generative model: a 20-dim latent with zero mean and unit std, decoded into Bernoulli parameters.

     ``data`` is flattened with ``view(-1, 784)`` before being observed -- presumably a batch of
     28x28 images (TODO confirm against the caller).
     """
     decoder = pyro.module('decoder', self.vae_decoder)
     batch_size = data.size(0)
     z_mean, z_std = ng_zeros([batch_size, 20]), ng_ones([batch_size, 20])
     z = pyro.sample('latent', Normal(z_mean, z_std))
     # Call the module itself rather than .forward() so that registered nn.Module hooks run.
     img = decoder(z)
     pyro.observe('obs', Bernoulli(img), data.view(-1, 784))