Example #1
def test_batch_log_pdf_mask(dist):
    if dist.get_test_distribution_name() not in ('Normal', 'Bernoulli',
                                                 'Categorical'):
        pytest.skip('Batch pdf masking not supported for the distribution.')
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            batch_pdf_shape = d.batch_shape(**dist_params) + (1,)
            batch_pdf_shape_broadcasted = d.batch_shape(x, **dist_params) + (1,)
            zeros_mask = ng_zeros(1)  # should be broadcasted to data dims
            ones_mask = ng_ones(batch_pdf_shape)  # should be broadcasted to data dims
            half_mask = ng_ones(1) * 0.5
            batch_log_pdf = d.batch_log_pdf(x, **dist_params)
            batch_log_pdf_zeros_mask = d.batch_log_pdf(x,
                                                       log_pdf_mask=zeros_mask,
                                                       **dist_params)
            batch_log_pdf_ones_mask = d.batch_log_pdf(x,
                                                      log_pdf_mask=ones_mask,
                                                      **dist_params)
            batch_log_pdf_half_mask = d.batch_log_pdf(x,
                                                      log_pdf_mask=half_mask,
                                                      **dist_params)
            assert_equal(batch_log_pdf_ones_mask, batch_log_pdf)
            assert_equal(batch_log_pdf_zeros_mask,
                         ng_zeros(batch_pdf_shape_broadcasted))
            assert_equal(batch_log_pdf_half_mask, 0.5 * batch_log_pdf)
Example #2
 def model():
     mu_latent = pyro.sample("mu_latent", dist.normal,
                             self.mu0, torch.pow(self.tau0, -0.5))
     bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
     x_dist = TransformedDistribution(dist.normal, bijector)
     pyro.observe("obs0", x_dist, self.data[0], ng_zeros(1), ng_ones(1))
     pyro.observe("obs1", x_dist, self.data[1], ng_zeros(1), ng_ones(1))
     return mu_latent
Example #3
 def obs_inner(i, _i, _x):
     for k in range(n_superfluous_top):
         pyro.sample("z_%d_%d" % (i, k),
                     dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1), reparameterized=False))
     pyro.observe("obs_%d" % i, dist.normal, _x, mu_latent, torch.pow(self.lam, -0.5))
     for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
         pyro.sample("z_%d_%d" % (i, k),
                     dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1), reparameterized=False))
Example #4
 def model():
     mu_latent = pyro.sample("mu_latent", dist.normal, self.mu0,
                             torch.pow(self.tau0, -0.5))
     bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
     x_dist = TransformedDistribution(dist.normal, bijector)
     pyro.observe("obs0", x_dist, self.data[0], ng_zeros(1), ng_ones(1))
     pyro.observe("obs1", x_dist, self.data[1], ng_zeros(1), ng_ones(1))
     return mu_latent
Example #5
def test_categorical_batch_log_pdf_shape(one_hot):
    ps = ng_ones(3, 2, 4) / 4
    if one_hot:
        x = ng_zeros(3, 2, 4)
        x[:, :, 0] = 1
    else:
        x = ng_zeros(3, 2, 1)
    d = dist.Categorical(ps, one_hot=one_hot)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
Example #6
def test_categorical_batch_log_pdf_shape(one_hot):
    ps = ng_ones(3, 2, 4) / 4
    if one_hot:
        x = ng_zeros(3, 2, 4)
        x[:, :, 0] = 1
    else:
        x = ng_zeros(3, 2, 1)
    d = dist.Categorical(ps, one_hot=one_hot)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
Example #7
    def model(self, x):
        u_mu = ng_zeros(self.U_size)
        u_sigma = ng_ones(self.U_size)

        U = pyro.sample('u', dist.normal, u_mu, u_sigma)

        v_mu = ng_zeros(self.V_size)
        v_sigma = ng_ones(self.V_size)

        V = pyro.sample('v', dist.normal, v_mu, v_sigma)
        pyro.observe('x', dist.bernoulli,
                     x,
                     torch.matmul(U, torch.t(V)),
                     # ng_ones(x.size(), type_as=x.data)
                     )
Example #8
 def model(self, data):
     decoder = pyro.module('decoder', self.vae_decoder)
     z_mean, z_std = ng_zeros([data.size(0), 20]), ng_ones([data.size(0), 20])
     z = pyro.sample('latent', Normal(z_mean, z_std))
     img = decoder.forward(z)
     pyro.sample('obs', Bernoulli(img), obs=data.view(-1, 784))
Example #9
    def model(self):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # Setup hyperparameters for prior p(z)
        z_mu = ng_zeros([self.n_samples, self.n_latent])
        z_sigma = ng_ones([self.n_samples, self.n_latent])
        # sample from prior
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
        # decode the latent code z
        z_adj = self.decoder(z)

        # Subsampling
        if self.subsampling:
            with pyro.iarange("data",
                              self.n_subsample,
                              subsample=self.sample()) as ind:
                pyro.observe('obs', dist.bernoulli,
                             self.adj_labels.view(1, -1)[0][ind],
                             z_adj.view(1, -1)[0][ind])
        # Reweighting
        else:
            with pyro.iarange("data"):
                pyro.observe('obs',
                             weighted_bernoulli,
                             self.adj_labels.view(1, -1),
                             z_adj.view(1, -1),
                             weight=self.pos_weight)
Example #10
 def setUp(self):
     # normal-normal; known covariance
     self.lam0 = Variable(torch.Tensor([0.1, 0.1]))  # precision of prior
     self.mu0 = Variable(torch.Tensor([0.0, 0.5]))  # prior mean
     # known precision of observation noise
     self.lam = Variable(torch.Tensor([6.0, 4.0]))
     self.n_outer = 3
     self.n_inner = 3
     self.n_data = Variable(torch.Tensor([self.n_outer * self.n_inner]))
     self.data = []
     self.sum_data = ng_zeros(2)
     for _out in range(self.n_outer):
         data_in = []
         for _in in range(self.n_inner):
             data_in.append(
                 Variable(
                     torch.Tensor([-0.1, 0.3]) +
                     torch.randn(2) / torch.sqrt(self.lam.data)))
             self.sum_data += data_in[-1]
         self.data.append(data_in)
     self.analytic_lam_n = self.lam0 + self.n_data.expand_as(
         self.lam) * self.lam
     self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n)
     self.analytic_mu_n = self.sum_data * (self.lam / self.analytic_lam_n) +\
         self.mu0 * (self.lam0 / self.analytic_lam_n)
Example #11
    def model(self, data):
        decoder = pyro.module('decoder', self.decoder)

        # Normal prior
        if self.prior == 0:
            z_mu, z_sigma = ng_zeros([data.size(0), self.z_dim]), ng_ones([data.size(0), self.z_dim])

            if self.cuda_mode:
                z_mu, z_sigma = z_mu.cuda(), z_sigma.cuda()

            z = pyro.sample("latent", dist.normal, z_mu, z_sigma)

        elif self.prior == 1:
            z_mu, z_sigma = self.vampprior()

            z_mu_avg = torch.mean(z_mu, 0)

            z_sigma_square = z_sigma * z_sigma
            z_sigma_square_avg = torch.mean(z_sigma_square, 0)
            z_sigma_avg = torch.sqrt(z_sigma_square_avg)

            z_mu_avg = z_mu_avg.expand(data.size(0), z_mu_avg.size(0))
            z_sigma_avg = z_sigma_avg.expand(data.size(0), z_sigma_avg.size(0))

            z = pyro.sample("latent", dist.normal, z_mu_avg, z_sigma_avg)

        x1, x2, t, y = decoder.forward(
            z,
            data[:, :len(self.dataset.binfeats)],
            data[:, len(self.dataset.binfeats):(len(self.dataset.binfeats) + len(self.dataset.contfeats))],
            data[:, -2], data[:, -1])
Example #12
def test_normal_shape():
    mu = ng_zeros(3, 2)
    sigma = ng_ones(3, 2)
    d = dist.Normal(mu, sigma)
    assert d.batch_shape() == (3, )
    assert d.event_shape() == (2, )
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
Example #13
def test_normal_shape():
    mu = ng_zeros(3, 2)
    sigma = ng_ones(3, 2)
    d = dist.Normal(mu, sigma)
    assert d.batch_shape() == (3,)
    assert d.event_shape() == (2,)
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
Example #14
    def model(self, input_variable, target_variable, step):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder_dense", self.decoder_dense)
        pyro.module("decoder_rnn", self.decoder_rnn)

        # setup hyperparameters for prior p(z)
        # the type_as ensures we get CUDA Tensors if x is on gpu
        z_mu = ng_zeros([self.num_layers, self.z_dim], type_as=target_variable.data)
        z_sigma = ng_ones([self.num_layers, self.z_dim], type_as=target_variable.data)

        # sample from prior
        # (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)

        # init vars
        target_length = target_variable.shape[0]

        decoder_input = dataset.to_onehot([[self.dataset.SOS_index]])
        decoder_input = decoder_input.cuda() if USE_CUDA else decoder_input

        decoder_outputs = np.ones((target_length))
        decoder_hidden = self.decoder_dense(z)

        # # Teacher forcing
        for di in range(target_length):
            decoder_output, decoder_hidden = self.decoder_rnn(
                decoder_input, decoder_hidden)
            decoder_input = target_variable[di]

            if self.use_cuda:
                decoder_outputs[di] = np.argmax(decoder_output.cpu().data.numpy())
            else:
                decoder_outputs[di] = np.argmax(decoder_output.data.numpy())

            pyro.observe("obs_{}".format(di), dist.bernoulli, target_variable[di], decoder_output[0])

        # ----------------------------------------------------------------
        # prepare offer
        if self.use_cuda:
            offer = np.argmax(input_variable.cpu().data.numpy(), axis=1).astype(int)
        else:
            offer = np.argmax(input_variable.data.numpy(), axis=1).astype(int)

        # prepare answer
        if self.use_cuda:
            answer = np.argmax(target_variable.cpu().data.numpy(), axis=1).astype(int)
        else:
            answer = np.argmax(target_variable.data.numpy(), axis=1).astype(int)

        # prepare rnn
        rnn_response = list(map(int, decoder_outputs))
        
        # print output
        if step % 10 == 0:
            print("---------------------------")
            print("Offer: ", dataset.to_phrase(offer))
            print("Answer:", self.dataset.to_phrase(answer))
            print("RNN:", self.dataset.to_phrase(rnn_response))
Example #15
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = ng_zeros(mini_batch.size(), type_as=mini_batch.data)
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = Variable(torch.cuda.LongTensor(time_slice)) if 'cuda' in mini_batch.data.type() \
            else Variable(torch.LongTensor(time_slice))
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch
Example #16
    def model(self, x, y):
        pyro.module('decoder', self.decoder)

        i, y = y

        U_mu = ng_zeros([self.u_dim, self.z_dim])
        U_sigma = ng_ones([self.u_dim, self.z_dim])
        U = pyro.sample('U', dist.normal, U_mu, U_sigma)

        V_mu = ng_zeros([self.v_dim, self.z_dim])
        V_sigma = ng_ones([self.v_dim, self.z_dim])
        V = pyro.sample('V', dist.normal, V_mu, V_sigma)

        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)

        z = pyro.sample('latent', dist.normal, U[i, :] * V[y, :], z_sigma)

        img_mu = self.decoder(z)
        pyro.sample('obs', dist.bernoulli, img_mu, obs=x.view(-1, 784))
Example #17
    def model(self, x):
        pyro.module('decoder', self.decoder)

        z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)

        z = pyro.sample('latent', dist.normal, z_mu, z_sigma)

        img_mu = self.decoder(z)
        pyro.sample('obs', dist.bernoulli, img_mu, obs=x.view(-1, 784))
Example #18
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = ng_zeros(mini_batch.size(), type_as=mini_batch.data)
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = Variable(torch.cuda.LongTensor(time_slice)) if 'cuda' in mini_batch.data.type() \
            else Variable(torch.LongTensor(time_slice))
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch
Example #19
def test_batch_log_pdf_mask(dist):
    if dist.get_test_distribution_name() not in ('Normal', 'Bernoulli', 'Categorical'):
        pytest.skip('Batch pdf masking not supported for the distribution.')
    d = dist.pyro_dist
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        x = dist.get_test_data(idx)
        with xfail_if_not_implemented():
            batch_pdf_shape = d.batch_shape(**dist_params) + (1,)
            batch_pdf_shape_broadcasted = d.batch_shape(x, **dist_params) + (1,)
            zeros_mask = ng_zeros(1)  # should be broadcasted to data dims
            ones_mask = ng_ones(batch_pdf_shape)  # should be broadcasted to data dims
            half_mask = ng_ones(1) * 0.5
            batch_log_pdf = d.batch_log_pdf(x, **dist_params)
            batch_log_pdf_zeros_mask = d.batch_log_pdf(x, log_pdf_mask=zeros_mask, **dist_params)
            batch_log_pdf_ones_mask = d.batch_log_pdf(x, log_pdf_mask=ones_mask, **dist_params)
            batch_log_pdf_half_mask = d.batch_log_pdf(x, log_pdf_mask=half_mask, **dist_params)
            assert_equal(batch_log_pdf_ones_mask, batch_log_pdf)
            assert_equal(batch_log_pdf_zeros_mask, ng_zeros(batch_pdf_shape_broadcasted))
            assert_equal(batch_log_pdf_half_mask, 0.5 * batch_log_pdf)
Example #20
def expand_z_where(z_where):
    # Take a batch of three-vectors, and massages them into a batch of
    # 2x3 matrices with elements like so:
    # [s,x,y] -> [[s,0,x],
    #             [0,s,y]]
    n = z_where.size(0)
    out = torch.cat((ng_zeros([1, 1]).type_as(z_where).expand(n, 1), z_where), 1)
    ix = Variable(expansion_indices)
    if z_where.is_cuda:
        ix = ix.cuda()
    out = torch.index_select(out, 1, ix)
    out = out.view(n, 2, 3)
    return out
Example #21
 def model(self, x):
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     # setup hyperparameters for prior p(z)
     # the type_as ensures we get cuda Tensors if x is on gpu
     z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
     z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
     # sample from prior (value will be sampled by guide when computing the ELBO)
     z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
     # decode the latent code z
     mu_img = self.decoder.forward(z)
     # score against actual images
     pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)
Example #22
 def model(self, x):
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     # setup hyperparameters for prior p(z)
     # the type_as ensures we get cuda Tensors if x is on gpu
     z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
     z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
     # sample from prior (value will be sampled by guide when computing the ELBO)
     z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
     # decode the latent code z
     mu_img = self.decoder.forward(z)
     # score against actual images
     pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)
Example #23
def expand_z_where(z_where):
    # Take a batch of three-vectors, and massages them into a batch of
    # 2x3 matrices with elements like so:
    # [s,x,y] -> [[s,0,x],
    #             [0,s,y]]
    n = z_where.size(0)
    out = torch.cat((ng_zeros([1, 1]).type_as(z_where).expand(n, 1), z_where),
                    1)
    ix = Variable(expansion_indices)
    if z_where.is_cuda:
        ix = ix.cuda()
    out = torch.index_select(out, 1, ix)
    out = out.view(n, 2, 3)
    return out
Example #24
    def model(self, x):
        # register decoder with pyro
        pyro.module("decoder", self.decoder)

        # setup hyper-parameters for prior p(z)
        z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)

        # sample from prior (value will be sampled by guide when computing the ELBO),
        # decode the latent code z,
        # and score against actual frame
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
        mu_frame = self.decoder.forward(z)
        pyro.observe("obs", dist.bernoulli, x.view(-1, self.frame), mu_frame)
Example #25
        def model(batch_size_outer=2, batch_size_inner=2):
            mu_latent = pyro.sample("mu_latent", dist.normal, ng_zeros(1), ng_ones(1))

            def outer(i, x):
                pyro.map_data("map_inner_%d" % i, x, lambda _i, _x:
                              inner(i, _i, _x), batch_size=batch_size_inner)

            def inner(i, _i, _x):
                pyro.sample("z_%d_%d" % (i, _i), dist.normal, mu_latent + _x, ng_ones(1))

            pyro.map_data("map_outer", [[ng_ones(1)] * 2] * 2, lambda i, x:
                          outer(i, x), batch_size=batch_size_outer)

            return mu_latent
Example #26
        def model(batch_size_outer=2, batch_size_inner=2):
            mu_latent = pyro.sample("mu_latent", dist.normal, ng_zeros(1), ng_ones(1))

            def outer(i, x):
                pyro.map_data("map_inner_%d" % i, x, lambda _i, _x:
                              inner(i, _i, _x), batch_size=batch_size_inner)

            def inner(i, _i, _x):
                pyro.sample("z_%d_%d" % (i, _i), dist.normal, mu_latent + _x, ng_ones(1))

            pyro.map_data("map_outer", [[ng_ones(1)] * 2] * 2, lambda i, x:
                          outer(i, x), batch_size=batch_size_outer)

            return mu_latent
Example #27
    def test_do_propagation(self):
        pyro.clear_param_store()

        def model():
            z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
            latent_prob = torch.exp(z) / (torch.exp(z) + ng_ones(1))
            flip = pyro.sample("flip", Bernoulli(latent_prob))
            return flip

        sample_from_model = model()
        z_data = {"z": -10.0 * ng_ones(1)}
        # under model flip = 1 with high probability; so do indirect DO surgery to make flip = 0
        sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()

        assert eq(sample_from_model, ng_ones(1))
        assert eq(sample_from_do_model, ng_zeros(1))
Example #28
 def model(self, x):
     raise NotImplementedError("don't use it plzz")
     pyro.module("decoder", self.decoders)
     # handle prior params
     z_mu = ng_zeros([x.size(0), self.platent["dim"]], type_as=x.data)
     z_sigma = ng_ones([x.size(0), self.platent["dim"]], type_as=x.data)
     z = pyro.sample("latent",self.platent["dist"], z_mu, z_sigma)
     decoder_out = self.decoder.forward(z)
     if type(decoder_out)!=tuple:
         decoder_out = (decoder_out, )
     if pinput["dist"] == dist.normal:
         decoder_out = list(decoder_out)
         decoder_out[1] = torch.exp(decoder_out[1])
         decoder_out = tuple(decoder_out)
         
     pyro.sample("obs", self.pinput["dist"], obs=x.view(-1, self.pinput["dim"]), *decoder_out)
Example #29
    def test_do_propagation(self):
        pyro.clear_param_store()

        def model():
            z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
            latent_prob = torch.exp(z) / (torch.exp(z) + ng_ones(1))
            flip = pyro.sample("flip", Bernoulli(latent_prob))
            return flip

        sample_from_model = model()
        z_data = {"z": -10.0 * ng_ones(1)}
        # under model flip = 1 with high probability; so do indirect DO surgery to make flip = 0
        sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()

        assert eq(sample_from_model, ng_ones(1))
        assert eq(sample_from_do_model, ng_zeros(1))
Example #30
def _compute_elbo_non_reparam(guide_trace, guide_vec_md_nodes,  #
                              non_reparam_nodes, downstream_costs):
    # construct all the reinforce-like terms.
    # we include only downstream costs to reduce variance
    # optionally include baselines to further reduce variance
    # XXX should the average baseline be in the param store as below?
    surrogate_elbo = 0.0
    baseline_loss = 0.0
    for node in non_reparam_nodes:
        guide_site = guide_trace.nodes[node]
        log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
        downstream_cost = downstream_costs[node]
        baseline = 0.0
        (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
            baseline_value) = _get_baseline_options(guide_site)
        use_nn_baseline = nn_baseline is not None
        use_baseline_value = baseline_value is not None
        assert(not (use_nn_baseline and use_baseline_value)), \
            "cannot use baseline_value and nn_baseline simultaneously"
        if use_decaying_avg_baseline:
            avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
                                                 ng_zeros(1), tags="__tracegraph_elbo_internal_tag")
            avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                baseline_beta * avg_downstream_cost_old
            avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
            baseline += avg_downstream_cost_old
        if use_nn_baseline:
            # block nn_baseline_input gradients except in baseline loss
            baseline += nn_baseline(detach_iterable(nn_baseline_input))
        elif use_baseline_value:
            # it's on the user to make sure baseline_value tape only points to baseline params
            baseline += baseline_value
        if use_nn_baseline or use_baseline_value:
            # accumulate baseline loss
            baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

        guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"]  # not scaled by subsampling
        if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
            if downstream_cost.size() != baseline.size():
                raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                    node, downstream_cost.size(), baseline.size()))
            downstream_cost = downstream_cost - baseline
        surrogate_elbo += (guide_log_pdf * downstream_cost.detach()).sum()

    return surrogate_elbo, baseline_loss
Example #31
def test_subsample_gradient(trace_graph, reparameterized):
    pyro.clear_param_store()
    data_size = 2
    subsample_size = 1
    num_particles = 1000
    precision = 0.333
    data = dist.normal(ng_zeros(data_size), ng_ones(data_size))

    def model(subsample_size):
        with pyro.iarange("data", len(data), subsample_size) as ind:
            x = data[ind]
            z = pyro.sample("z", dist.Normal(ng_zeros(len(x)), ng_ones(len(x)),
                                             reparameterized=reparameterized))
            pyro.observe("x", dist.Normal(z, ng_ones(len(x)), reparameterized=reparameterized), x)

    def guide(subsample_size):
        mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
        sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
        with pyro.iarange("data", len(data), subsample_size) as ind:
            mu = mu[ind]
            sigma = sigma.expand(subsample_size)
            pyro.sample("z", dist.Normal(mu, sigma, reparameterized=reparameterized))

    optim = Adam({"lr": 0.1})
    inference = SVI(model, guide, optim, loss="ELBO",
                    trace_graph=trace_graph, num_particles=num_particles)

    # Compute gradients without subsampling.
    inference.loss_and_grads(model, guide, subsample_size=data_size)
    params = dict(pyro.get_param_store().named_parameters())
    expected_grads = {name: param.grad.data.clone() for name, param in params.items()}
    zero_grads(params.values())

    # Compute gradients with subsampling.
    inference.loss_and_grads(model, guide, subsample_size=subsample_size)
    actual_grads = {name: param.grad.data.clone() for name, param in params.items()}

    for name in sorted(params):
        print('\nexpected {} = {}'.format(name, expected_grads[name].cpu().numpy()))
        print('actual   {} = {}'.format(name, actual_grads[name].cpu().numpy()))
    assert_equal(actual_grads, expected_grads, prec=precision)
Example #32
 def setUp(self):
     # normal-normal; known covariance
     self.lam0 = Variable(torch.Tensor([0.1, 0.1]))   # precision of prior
     self.mu0 = Variable(torch.Tensor([0.0, 0.5]))   # prior mean
     # known precision of observation noise
     self.lam = Variable(torch.Tensor([6.0, 4.0]))
     self.n_outer = 3
     self.n_inner = 3
     self.n_data = Variable(torch.Tensor([self.n_outer * self.n_inner]))
     self.data = []
     self.sum_data = ng_zeros(2)
     for _out in range(self.n_outer):
         data_in = []
         for _in in range(self.n_inner):
             data_in.append(Variable(torch.Tensor([-0.1, 0.3]) + torch.randn(2) / torch.sqrt(self.lam.data)))
             self.sum_data += data_in[-1]
         self.data.append(data_in)
     self.analytic_lam_n = self.lam0 + self.n_data.expand_as(self.lam) * self.lam
     self.analytic_log_sig_n = -0.5 * torch.log(self.analytic_lam_n)
     self.analytic_mu_n = self.sum_data * (self.lam / self.analytic_lam_n) +\
         self.mu0 * (self.lam0 / self.analytic_lam_n)
     self.verbose = True
Example #33
def test_one_hot_categorical_batch_log_pdf_shape():
    ps = ng_ones(3, 2, 4) / 4
    x = ng_zeros(3, 2, 4)
    x[:, :, 0] = 1
    d = dist.OneHotCategorical(ps)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
Example #34
def test_normal_batch_log_pdf_shape():
    mu = ng_zeros(3, 2)
    sigma = ng_ones(3, 2)
    x = ng_zeros(3, 2)
    d = dist.Normal(mu, sigma)
    assert d.batch_log_pdf(x).size() == (3, 1)
Example #35
def test_normal_batch_log_pdf_shape():
    mu = ng_zeros(3, 2)
    sigma = ng_ones(3, 2)
    x = ng_zeros(3, 2)
    d = dist.Normal(mu, sigma)
    assert d.batch_log_pdf(x).size() == (3, 1)
Example #36
def test_categorical_batch_log_pdf_shape():
    ps = ng_ones(3, 2, 4) / 4
    x = ng_zeros(3, 2, 1)
    d = dist.Categorical(ps)
    assert d.batch_log_pdf(x).size() == (3, 2, 1)
Example #37
 def model(subsample_size):
     with pyro.iarange("data", len(data), subsample_size) as ind:
         x = data[ind]
         z = pyro.sample("z", dist.Normal(ng_zeros(len(x)), ng_ones(len(x)),
                                          reparameterized=reparameterized))
         pyro.observe("x", dist.Normal(z, ng_ones(len(x)), reparameterized=reparameterized), x)
Example #38
 def model_dup():
     pyro.param("mu_q", Variable(torch.ones(1), requires_grad=True))
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
Example #39
 def guide():
     p = pyro.param("p", Variable(torch.ones(1), requires_grad=True))
     pyro.sample("mu_q", dist.normal, ng_zeros(1), p)
     pyro.sample("mu_q_2", dist.normal, ng_zeros(1), p)
Example #40
 def model():
     mu_latent = pyro.sample("mu_latent", dist.Normal(ng_zeros(1), ng_ones(1), reparameterized=reparameterized))
     pyro.observe('obs', dist.normal, Variable(torch.Tensor([0.23])), mu_latent, ng_ones(1))
     return mu_latent
Example #41
 def model_obs_dup():
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
     pyro.observe("mu_q", dist.normal, ng_zeros(1), ng_ones(1), ng_zeros(1))
Example #42
    def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
        # get info regarding rao-blackwellization of vectorized map_data
        guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
        model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
        guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
        model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
        do_vec_rb = guide_vec_md_condition and model_vec_md_condition
        if not do_vec_rb:
            warnings.warn(
                "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
                "Falling back to higher-variance gradient estimator. "
                "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
                    guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
        guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
        model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

        # have the trace compute all the individual (batch) log pdf terms
        # so that they are available below
        guide_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in guide_vec_md_nodes)
        guide_trace.log_pdf()
        model_trace.compute_batch_log_pdf(site_filter=lambda name, site: name in model_vec_md_nodes)
        model_trace.log_pdf()

        # prepare a list of all the cost nodes, each of which is +- log_pdf
        cost_nodes = []
        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
        for name, model_site in model_trace.nodes.items():
            if model_site["type"] == "sample":
                if model_site["is_observed"]:
                    cost_nodes.append(CostNode(model_site["log_pdf"], True))
                else:
                    # cost node from model sample
                    cost_nodes.append(CostNode(model_site["log_pdf"], True))
                    # cost node from guide sample
                    guide_site = guide_trace.nodes[name]
                    zero_expectation = name in non_reparam_nodes
                    cost_nodes.append(CostNode(-guide_site["log_pdf"], not zero_expectation))

        # compute the elbo; if all stochastic nodes are reparameterizable, we're done
        # this bit is never differentiated: it's here for getting an estimate of the elbo itself
        elbo = torch_data_sum(sum(c.cost for c in cost_nodes))

        # compute the surrogate elbo, removing terms whose gradient is zero
        # this is the bit that's actually differentiated
        # XXX should the user be able to control if these terms are included?
        surrogate_elbo = sum(c.cost for c in cost_nodes if c.nonzero_expectation)

        # the following computations are only necessary if we have non-reparameterizable nodes
        baseline_loss = 0.0
        if non_reparam_nodes:

            # recursively compute downstream cost nodes for all sample sites in model and guide
            # (even though ultimately just need for non-reparameterizable sample sites)
            # 1. downstream costs used for rao-blackwellization
            # 2. model observe sites (as well as terms that arise from the model and guide having different
            # dependency structures) are taken care of via 'children_in_model' below
            topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
            topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
                                     if guide_trace.nodes[x]["type"] == "sample"]
            downstream_guide_cost_nodes = {}
            downstream_costs = {}

            for node in topo_sort_guide_nodes:
                node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                    guide_trace.nodes[node][node_log_pdf_key]
                nodes_included_in_sum = set([node])
                downstream_guide_cost_nodes[node] = set([node])
                for child in guide_trace.successors(node):
                    child_cost_nodes = downstream_guide_cost_nodes[child]
                    downstream_guide_cost_nodes[node].update(child_cost_nodes)
                    if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                        if node_log_pdf_key == 'log_pdf':
                            downstream_costs[node] += downstream_costs[child].sum()
                        else:
                            downstream_costs[node] += downstream_costs[child]
                        nodes_included_in_sum.update(child_cost_nodes)
                missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
                # include terms we missed because we had to avoid duplicates
                for missing_node in missing_downstream_costs:
                    mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                    if node_log_pdf_key == 'log_pdf':
                        downstream_costs[node] += (model_trace.nodes[missing_node][mn_log_pdf_key] -
                                                   guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
                    else:
                        downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                                                  guide_trace.nodes[missing_node][mn_log_pdf_key]

            # finish assembling complete downstream costs
            # (the above computation may be missing terms from model)
            # XXX can we cache some of the sums over children_in_model to make things more efficient?
            for site in non_reparam_nodes:
                children_in_model = set()
                for node in downstream_guide_cost_nodes[site]:
                    children_in_model.update(model_trace.successors(node))
                # remove terms accounted for above
                children_in_model.difference_update(downstream_guide_cost_nodes[site])
                for child in children_in_model:
                    child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                    site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                    assert (model_trace.nodes[child]["type"] == "sample")
                    if site_log_pdf_key == 'log_pdf':
                        downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
                    else:
                        downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]

            # construct all the reinforce-like terms.
            # we include only downstream costs to reduce variance
            # optionally include baselines to further reduce variance
            # XXX should the average baseline be in the param store as below?
            elbo_reinforce_terms = 0.0
            for node in non_reparam_nodes:
                guide_site = guide_trace.nodes[node]
                log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                downstream_cost = downstream_costs[node]
                baseline = 0.0
                (nn_baseline, nn_baseline_input, use_decaying_avg_baseline, baseline_beta,
                    baseline_value) = _get_baseline_options(guide_site)
                use_nn_baseline = nn_baseline is not None
                use_baseline_value = baseline_value is not None
                assert(not (use_nn_baseline and use_baseline_value)), \
                    "cannot use baseline_value and nn_baseline simultaneously"
                if use_decaying_avg_baseline:
                    avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
                                                         ng_zeros(1), tags="__tracegraph_elbo_internal_tag")
                    avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                        baseline_beta * avg_downstream_cost_old
                    avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                    baseline += avg_downstream_cost_old
                if use_nn_baseline:
                    # block nn_baseline_input gradients except in baseline loss
                    baseline += nn_baseline(detach_iterable(nn_baseline_input))
                elif use_baseline_value:
                    # it's on the user to make sure baseline_value tape only points to baseline params
                    baseline += baseline_value
                if use_nn_baseline or use_baseline_value:
                    # accumulate baseline loss
                    baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

                guide_log_pdf = guide_site[log_pdf_key] / guide_site["scale"]  # not scaled by subsampling
                if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                    if downstream_cost.size() != baseline.size():
                        raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                            node, downstream_cost.size(), baseline.size()))
                    downstream_cost = downstream_cost - baseline
                elbo_reinforce_terms += (guide_log_pdf * downstream_cost.detach()).sum()

            surrogate_elbo += elbo_reinforce_terms

        # collect parameters to train from model and guide
        trainable_params = set(site["value"]
                               for trace in (model_trace, guide_trace)
                               for site in trace.nodes.values()
                               if site["type"] == "param")

        if trainable_params:
            surrogate_loss = -surrogate_elbo
            torch_backward(weight * (surrogate_loss + baseline_loss))
            pyro.get_param_store().mark_params_active(trainable_params)

        loss = -elbo
        return weight * loss
Example #43
 def ng_zeros(self, *args, **kwargs):
     t = ng_zeros(*args, **kwargs)
     if self.use_cuda:
         t = t.cuda()
     return t
Example #44
def model(data):
    latent = named.Object("latent")
    latent.z.sample_(dist.normal, ng_zeros(1), ng_ones(1))
    model_recurse(data, latent)
Example #45
 def guide():
     p = pyro.param("p", Variable(torch.ones(1), requires_grad=True))
     pyro.sample("mu_q", dist.normal, ng_zeros(1), p)
     pyro.sample("mu_q_2", dist.normal, ng_zeros(1), p)
Example #46
    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
        Performs backward on the latter. Num_particle many samples are used to form the estimators.
        If baselines are present, a baseline loss is also constructed and differentiated.
        """
        elbo = 0.0

        for weight, model_trace, guide_trace in self._get_traces(
                model, guide, *args, **kwargs):

            # get info regarding rao-blackwellization of vectorized map_data
            guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
            model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
            guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
            model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
            do_vec_rb = guide_vec_md_condition and model_vec_md_condition
            if not do_vec_rb:
                warnings.warn(
                    "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
                    "Falling back to higher-variance gradient estimator. "
                    "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
                        guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
            guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
            model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()

            # have the trace compute all the individual (batch) log pdf terms
            # so that they are available below
            guide_trace.compute_batch_log_pdf(
                site_filter=lambda name, site: name in guide_vec_md_nodes)
            guide_trace.log_pdf()
            model_trace.compute_batch_log_pdf(
                site_filter=lambda name, site: name in model_vec_md_nodes)
            model_trace.log_pdf()

            # prepare a list of all the cost nodes, each of which is +- log_pdf
            cost_nodes = []
            non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
            for site in model_trace.nodes.keys():
                model_trace_site = model_trace.nodes[site]
                log_pdf_key = 'batch_log_pdf' if site in model_vec_md_nodes else 'log_pdf'
                if model_trace_site["type"] == "sample":
                    if model_trace_site["is_observed"]:
                        cost_node = (model_trace_site[log_pdf_key], True)
                        cost_nodes.append(cost_node)
                    else:
                        # cost node from model sample
                        cost_node1 = (model_trace_site[log_pdf_key], True)
                        # cost node from guide sample
                        zero_expectation = site in non_reparam_nodes
                        cost_node2 = (-guide_trace.nodes[site][log_pdf_key],
                                      not zero_expectation)
                        cost_nodes.extend([cost_node1, cost_node2])

            elbo_particle = 0.0
            surrogate_elbo_particle = 0.0
            baseline_loss_particle = 0.0
            elbo_reinforce_terms_particle = 0.0
            elbo_no_zero_expectation_terms_particle = 0.0

            # compute the elbo; if all stochastic nodes are reparameterizable, we're done
            # this bit is never differentiated: it's here for getting an estimate of the elbo itself
            for cost_node in cost_nodes:
                elbo_particle += cost_node[0].sum()
            elbo += weight * torch_data_sum(elbo_particle)

            # compute the elbo, removing terms whose gradient is zero
            # this is the bit that's actually differentiated
            # XXX should the user be able to control if these terms are included?
            for cost_node in cost_nodes:
                if cost_node[1]:
                    elbo_no_zero_expectation_terms_particle += cost_node[0].sum()
            surrogate_elbo_particle += weight * elbo_no_zero_expectation_terms_particle

            # the following computations are only necessary if we have non-reparameterizable nodes
            if len(non_reparam_nodes) > 0:

                # recursively compute downstream cost nodes for all sample sites in model and guide
                # (even though ultimately just need for non-reparameterizable sample sites)
                # 1. downstream costs used for rao-blackwellization
                # 2. model observe sites (as well as terms that arise from the model and guide having different
                # dependency structures) are taken care of via 'children_in_model' below
                topo_sort_guide_nodes = list(
                    reversed(list(networkx.topological_sort(guide_trace))))
                topo_sort_guide_nodes = [
                    x for x in topo_sort_guide_nodes
                    if guide_trace.nodes[x]["type"] == "sample"
                ]
                downstream_guide_cost_nodes = {}
                downstream_costs = {}

                for node in topo_sort_guide_nodes:
                    node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                    downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
                        guide_trace.nodes[node][node_log_pdf_key]
                    nodes_included_in_sum = set([node])
                    downstream_guide_cost_nodes[node] = set([node])
                    for child in guide_trace.successors(node):
                        child_cost_nodes = downstream_guide_cost_nodes[child]
                        downstream_guide_cost_nodes[node].update(child_cost_nodes)
                        if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
                            if node_log_pdf_key == 'log_pdf':
                                downstream_costs[node] += downstream_costs[child].sum()
                            else:
                                downstream_costs[node] += downstream_costs[child]
                            nodes_included_in_sum.update(child_cost_nodes)
                    missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
                    # include terms we missed because we had to avoid duplicates
                    for missing_node in missing_downstream_costs:
                        mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
                        if node_log_pdf_key == 'log_pdf':
                            downstream_costs[node] += (model_trace.nodes[missing_node][mn_log_pdf_key] -
                                                       guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
                        else:
                            downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
                                                      guide_trace.nodes[missing_node][mn_log_pdf_key]

                # finish assembling complete downstream costs
                # (the above computation may be missing terms from model)
                # XXX can we cache some of the sums over children_in_model to make things more efficient?
                for site in non_reparam_nodes:
                    children_in_model = set()
                    for node in downstream_guide_cost_nodes[site]:
                        children_in_model.update(model_trace.successors(node))
                    # remove terms accounted for above
                    children_in_model.difference_update(
                        downstream_guide_cost_nodes[site])
                    for child in children_in_model:
                        child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
                        site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
                        assert (model_trace.nodes[child]["type"] == "sample")
                        if site_log_pdf_key == 'log_pdf':
                            downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
                        else:
                            downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]

                # construct all the reinforce-like terms.
                # we include only downstream costs to reduce variance
                # optionally include baselines to further reduce variance
                # XXX should the average baseline be in the param store as below?

                # for extracting baseline options from site["baseline"]
                # XXX default for baseline_beta currently set here
                def get_baseline_options(site_baseline):
                    options_dict = site_baseline.copy()
                    options_tuple = (options_dict.pop('nn_baseline', None),
                                     options_dict.pop('nn_baseline_input', None),
                                     options_dict.pop('use_decaying_avg_baseline', False),
                                     options_dict.pop('baseline_beta', 0.90),
                                     options_dict.pop('baseline_value', None))
                    if options_dict:
                        raise ValueError("Unrecognized baseline options: {}".format(options_dict.keys()))
                    return options_tuple

                baseline_loss_particle = 0.0
                for node in non_reparam_nodes:
                    log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
                    downstream_cost = downstream_costs[node]
                    baseline = 0.0
                    (nn_baseline, nn_baseline_input, use_decaying_avg_baseline,
                     baseline_beta, baseline_value) = get_baseline_options(
                         guide_trace.nodes[node]["baseline"])
                    use_nn_baseline = nn_baseline is not None
                    use_baseline_value = baseline_value is not None
                    assert(not (use_nn_baseline and use_baseline_value)), \
                        "cannot use baseline_value and nn_baseline simultaneously"
                    if use_decaying_avg_baseline:
                        avg_downstream_cost_old = pyro.param(
                            "__baseline_avg_downstream_cost_" + node,
                            ng_zeros(1),
                            tags="__tracegraph_elbo_internal_tag")
                        avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
                            baseline_beta * avg_downstream_cost_old
                        avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
                        baseline += avg_downstream_cost_old
                    if use_nn_baseline:
                        # block nn_baseline_input gradients except in baseline loss
                        baseline += nn_baseline(
                            detach_iterable(nn_baseline_input))
                    elif use_baseline_value:
                        # it's on the user to make sure baseline_value tape only points to baseline params
                        baseline += baseline_value
                    if use_nn_baseline or use_baseline_value:
                        # construct baseline loss
                        baseline_loss = torch.pow(
                            downstream_cost.detach() - baseline, 2.0).sum()
                        baseline_loss_particle += weight * baseline_loss
                    if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
                        if downstream_cost.size() != baseline.size():
                            raise ValueError(
                                "Expected baseline at site {} to be {} instead got {}"
                                .format(node, downstream_cost.size(),
                                        baseline.size()))
                        elbo_reinforce_terms_particle += (
                            guide_trace.nodes[node][log_pdf_key] *
                            (downstream_cost - baseline).detach()).sum()
                    else:
                        elbo_reinforce_terms_particle += (
                            guide_trace.nodes[node][log_pdf_key] *
                            downstream_cost.detach()).sum()

                surrogate_elbo_particle += weight * elbo_reinforce_terms_particle
                torch_backward(baseline_loss_particle)

            # collect parameters to train from model and guide
            trainable_params = set(site["value"]
                                   for trace in (model_trace, guide_trace)
                                   for site in trace.nodes.values()
                                   if site["type"] == "param")

            surrogate_loss_particle = -surrogate_elbo_particle
            if trainable_params:
                torch_backward(surrogate_loss_particle)

                # mark all params seen in trace as active so that gradient steps are taken downstream
                pyro.get_param_store().mark_params_active(trainable_params)

        loss = -elbo

        return loss
Example #47
 def model():
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
Example #48
 def model():
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
Example #49
 def model(self, data):
     decoder = pyro.module('decoder', self.vae_decoder)
     z_mean, z_std = ng_zeros([data.size(0), 20]), ng_ones([data.size(0), 20])
     z = pyro.sample('latent', Normal(z_mean, z_std))
     img = decoder.forward(z)
     pyro.observe('obs', Bernoulli(img), data.view(-1, 784))