Example #1
def test_module_nn(nn_module):
    pyro.clear_param_store()
    nn_module = nn_module()
    assert pyro.get_param_store()._params == {}
    pyro.module("module", nn_module)
    for name in pyro.get_param_store().get_all_param_names():
        assert pyro.params.user_param_name(name) in nn_module.state_dict().keys()
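Every example on this page revolves around the same call: pyro.module(name, nn_module) registers all parameters of a torch.nn.Module with Pyro's parameter store (so an SVI optimizer can update them) and returns the module. A minimal, self-contained sketch of that behavior follows; the toy Linear "decoder" and the names used here are illustrative only, not taken from any example below.

import torch
import pyro
import pyro.distributions as dist

toy_decoder = torch.nn.Linear(2, 4)  # illustrative stand-in for a decoder network

def model(x):
    # register the module's parameters with Pyro's param store and get the module back
    decoder = pyro.module("toy_decoder", toy_decoder)
    loc = decoder(torch.zeros(x.shape[0], 2))
    pyro.sample("obs", dist.Normal(loc, 1.0).to_event(1), obs=x)

pyro.clear_param_store()
model(torch.randn(3, 4))
# the Linear layer's weight and bias now appear as named parameters in the store
print(pyro.get_param_store().get_all_param_names())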
Example #2
 def guide(self, x):
     # register PyTorch module `encoder` with Pyro
     pyro.module("encoder", self.encoder)
     # use the encoder to get the parameters used to define q(z|x)
     z_mu, z_sigma = self.encoder.forward(x)
     # sample the latent code z
     pyro.sample("latent", dist.normal, z_mu, z_sigma)
Example #3
 def guide():
     pyro.module("mymodule", pt_guide)
     mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
     sigma = torch.pow(tau_q, -0.5)
     pyro.sample("mu_latent",
                 dist.Normal(mu_q, sigma, reparameterized=reparameterized),
                 baseline=dict(use_decaying_avg_baseline=True))
Example #4
 def guide(self, x):
     # register PyTorch module `encoder` with Pyro
     pyro.module("encoder", self.encoder)
     with pyro.iarange("data", x.size(0)):
         # use the encoder to get the parameters used to define q(z|x)
         z_loc, z_scale = self.encoder.forward(x)
         # sample the latent code z
         pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))
Example #5
 def model(self, data, _, **kwargs):
     pyro.module("decode", self.decode)
     with pyro.iarange('data', data.size(0), use_cuda=self.use_cuda) as ix:
         batch = data[ix]
         n = batch.size(0)
         (z_where, z_pres), x = self.prior(n, **kwargs)
         pyro.sample('obs',
                     dist.Normal(x.view(n, -1),
                                 (self.likelihood_sd * self.prototype.new_ones(n, self.x_size ** 2)))
                         .independent(1),
                     obs=batch.view(n, -1))
Example #6
 def model(self, x):
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     # setup hyperparameters for prior p(z)
     # the type_as ensures we get cuda Tensors if x is on gpu
     z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
     z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
     # sample from prior (value will be sampled by guide when computing the ELBO)
     z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
     # decode the latent code z
     mu_img = self.decoder.forward(z)
     # score against actual images
     pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)
Example #7
 def model(self, x):
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     with pyro.iarange("data", x.size(0)):
         # setup hyperparameters for prior p(z)
         z_loc = x.new_zeros(torch.Size((x.size(0), self.z_dim)))
         z_scale = x.new_ones(torch.Size((x.size(0), self.z_dim)))
         # sample from prior (value will be sampled by guide when computing the ELBO)
         z = pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))
         # decode the latent code z
         loc_img = self.decoder.forward(z)
         # score against actual images
         pyro.sample("obs", dist.Bernoulli(loc_img).independent(1), obs=x.reshape(-1, 784))
         # return the loc so we can visualize it later
         return loc_img
Example #8
    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # register all PyTorch (sub)modules with pyro
        # this needs to happen in both the model and guide
        pyro.module("dmm", self)

        # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # we enclose all the sample statements in the model in an iarange.
        # this marks that each datapoint is conditionally independent of the others
        with pyro.iarange("z_minibatch", len(mini_batch)):
            # sample the latents z and observed x's one time step at a time
            for t in range(1, T_max + 1):
                # the next chunk of code samples z_t ~ p(z_t | z_{t-1})
                # note that (both here and elsewhere) we use poutine.scale to take care
                # of KL annealing. we use the mask() method to deal with raggedness
                # in the observed data (i.e. different sequences in the mini-batch
                # have different lengths)

                # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
                z_loc, z_scale = self.trans(z_prev)

                # then sample z_t according to dist.Normal(z_loc, z_scale)
                # note that we use the reshape method so that the univariate Normal distribution
                # is treated as a multivariate Normal distribution with a diagonal covariance.
                with poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample("z_%d" % t,
                                      dist.Normal(z_loc, z_scale)
                                          .mask(mini_batch_mask[:, t - 1:t])
                                          .independent(1))

                # compute the probabilities that parameterize the bernoulli likelihood
                emission_probs_t = self.emitter(z_t)
                # the next statement instructs pyro to observe x_t according to the
                # bernoulli distribution p(x_t|z_t)
                pyro.sample("obs_x_%d" % t,
                            dist.Bernoulli(emission_probs_t)
                                .mask(mini_batch_mask[:, t - 1:t])
                                .independent(1),
                            obs=mini_batch[:, t - 1, :])
                # the latent sampled at this time step will be conditioned upon
                # in the next time step so keep track of it
                z_prev = z_t
Example #9
        def guide():
            mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.094 * torch.ones(2),
                                               requires_grad=True))
            log_sig_q = pyro.param("log_sig_q", Variable(
                                   self.analytic_log_sig_n.data - 0.11 * torch.ones(2), requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            trivial_baseline = pyro.module("mu_baseline", pt_mu_baseline, tags="baseline")
            baseline_value = trivial_baseline(ng_ones(1))
            mu_latent = pyro.sample("mu_latent",
                                    dist.Normal(mu_q, sig_q, reparameterized=False),
                                    baseline=dict(baseline_value=baseline_value))

            def obs_inner(i, _i, _x):
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    z_baseline = pyro.module("z_baseline_%d_%d" % (i, k),
                                             pt_superfluous_baselines[3 * k + i], tags="baseline")
                    baseline_value = z_baseline(mu_latent.detach()).unsqueeze(-1)
                    mean_i = pyro.param("mean_%d_%d" % (i, k),
                                        Variable(0.5 * torch.ones(4 - i, 1), requires_grad=True))
                    pyro.sample("z_%d_%d" % (i, k),
                                dist.Normal(mean_i, ng_ones(4 - i, 1), reparameterized=False),
                                baseline=dict(baseline_value=baseline_value))

            def obs_outer(i, x):
                pyro.map_data("map_obs_inner_%d" % i, x, lambda _i, _x:
                              obs_inner(i, _i, _x), batch_size=4 - i)

            pyro.map_data("map_obs_outer", [self.data_tensor[0:4, :], self.data_tensor[4:7, :],
                                            self.data_tensor[7:9, :]],
                          lambda i, x: obs_outer(i, x), batch_size=3)

            return mu_latent
Example #10
    def model_classify(self, xs, ys=None):
        """
        this model is used to add an auxiliary (supervised) loss as described in the
        NIPS 2014 paper by Kingma et al titled
        "Semi-Supervised Learning with Deep Generative Models"
        """
        # register all pytorch (sub)modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):
            # this is the extra term that yields an auxiliary loss which we do gradient descent on,
            # similar to the NIPS 14 paper (Kingma et al).
            if ys is not None:
                alpha = self.encoder_y.forward(xs)
                pyro.observe("y_aux", dist.categorical, ys, alpha, log_pdf_mask=self.aux_loss_multiplier)
Example #11
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # we enclose all the sample statements in the guide in an iarange.
        # this marks that each datapoint is conditionally independent of the others.
        with pyro.iarange("z_minibatch", len(mini_batch)):
            # sample the latents z one time step at a time
            for t in range(1, T_max + 1):
                # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                # if we are using normalizing flows, we apply the sequence of transformations
                # parameterized by self.iafs to the base distribution defined in the previous line
                # to yield a transformed distribution that we use for q(z_t|...)
                if len(self.iafs) > 0:
                    z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                else:
                    z_dist = dist.Normal(z_loc, z_scale)
                assert z_dist.event_shape == ()
                assert z_dist.batch_shape == (len(mini_batch), self.z_q_0.size(0))

                # sample z_t from the distribution z_dist
                with pyro.poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1:t])
                                            .independent(1))
                # the latent sampled at this time step will be conditioned upon in the next time step
                # so keep track of it
                z_prev = z_t
Example #12
 def model(self, data):
     decoder = pyro.module('decoder', self.vae_decoder)
     z_mean, z_std = torch.zeros([data.size(0), 20]), torch.ones([data.size(0), 20])
     with pyro.iarange('data', data.size(0)):
         z = pyro.sample('latent', Normal(z_mean, z_std).independent(1))
         img = decoder.forward(z)
         pyro.sample('obs',
                     Bernoulli(img).independent(1),
                     obs=data.reshape(-1, 784))
Example #13
 def obs_inner(i, _i, _x):
     for k in range(n_superfluous_top + n_superfluous_bottom):
         z_baseline = pyro.module("z_baseline_%d_%d" % (i, k),
                                  pt_superfluous_baselines[3 * k + i], tags="baseline")
         baseline_value = z_baseline(mu_latent.detach()).unsqueeze(-1)
         mean_i = pyro.param("mean_%d_%d" % (i, k),
                             Variable(0.5 * torch.ones(4 - i, 1), requires_grad=True))
         pyro.sample("z_%d_%d" % (i, k),
                     dist.Normal(mean_i, ng_ones(4 - i, 1), reparameterized=False),
                     baseline=dict(baseline_value=baseline_value))
Example #14
    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # register all PyTorch (sub)modules with pyro
        # this needs to happen in both the model and guide
        pyro.module("dmm", self)

        # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # sample the latents z and observed x's one time step at a time
        for t in range(1, T_max + 1):
            # the next three lines of code sample z_t ~ p(z_t | z_{t-1})
            # note that (both here and elsewhere) log_pdf_mask takes care of both
            # (i)  KL annealing; and
            # (ii) raggedness in the observed data (i.e. different sequences
            #      in the mini-batch have different lengths)

            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_mu, z_sigma = self.trans(z_prev)
            # then sample z_t according to dist.Normal(z_mu, z_sigma)
            z_t = pyro.sample("z_%d" % t,
                              dist.normal,
                              z_mu,
                              z_sigma,
                              log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = self.emitter(z_t)
            # the next statement instructs pyro to observe x_t according to the
            # bernoulli distribution p(x_t|z_t)
            pyro.observe("obs_x_%d" % t, dist.bernoulli, mini_batch[:, t - 1, :],
                         emission_probs_t,
                         log_pdf_mask=mini_batch_mask[:, t - 1:t])
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
Example #15
    def guide(self, x):
        """
        1. Pass x as the input data to the LSTM encoder
        2. Pass the LSTM encoder's last hidden state into the MLP to obtain f(x), g(x)
        3. Sample z from multivariate normal N(f(x),g(x))
        """
        pyro.module("encoder_lstm", self.encoder_lstm)
        pyro.module("loc_mlp", self.loc_mlp)
        pyro.module("scale_mlp", self.scale_mlp)

        batch_size, x_tensor = self._preprocess_input(x)
        with pyro.plate("data"):
            encoder_hidden_state = self.encoder_lstm.init_hidden(
                batch_size=batch_size)
            for i in range(MAX_STRING_LEN):
                _, encoder_hidden_state = self.encoder_lstm.forward(
                    x_tensor[i].unsqueeze(0), encoder_hidden_state)

            encoder_hidden = encoder_hidden_state[0].view(batch_size, -1)
            encoder_cell = encoder_hidden_state[1].view(batch_size, -1)
            flattened_lstm_output = torch.cat((encoder_hidden, encoder_cell),
                                              dim=1)
            loc = self.loc_mlp.forward(flattened_lstm_output)
            scale = self.scale_mlp.forward(flattened_lstm_output)

            z = pyro.sample("z", dist.Normal(loc, scale).to_event(1))
Example #16
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0 if not self.use_cuda \
            else self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0

        # sample the latents z one time step at a time
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_mu, z_sigma = self.combiner(z_prev, rnn_output[:, t - 1, :])
            z_dist = dist.normal

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if self.iafs.__len__() > 0:
                z_dist = TransformedDistribution(z_dist, self.iafs)
            # sample z_t from the distribution z_dist
            z_t = pyro.sample("z_%d" % t,
                              z_dist,
                              z_mu,
                              z_sigma,
                              log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])
            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t
Example #17
    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # register all PyTorch (sub)modules with pyro
        # this needs to happen in both the model and guide
        pyro.module("dmm", self)

        # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # sample the latents z and observed x's one time step at a time
        for t in range(1, T_max + 1):
            # the next three lines of code sample z_t ~ p(z_t | z_{t-1})
            # note that (both here and elsewhere) log_pdf_mask takes care of both
            # (i)  KL annealing; and
            # (ii) raggedness in the observed data (i.e. different sequences
            #      in the mini-batch have different lengths)

            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_mu, z_sigma = self.trans(z_prev)
            # then sample z_t according to dist.Normal(z_mu, z_sigma)
            z_t = pyro.sample("z_%d" % t,
                              dist.normal,
                              z_mu,
                              z_sigma,
                              log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = self.emitter(z_t)
            # the next statement instructs pyro to observe x_t according to the
            # bernoulli distribution p(x_t|z_t)
            pyro.observe("obs_x_%d" % t, dist.bernoulli, mini_batch[:, t - 1, :],
                         emission_probs_t,
                         log_pdf_mask=mini_batch_mask[:, t - 1:t])
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
Example #18
    def model(self, response, mask, annealing_factor=1):
        if self.generative_model == 'irt':
            irt_model_fn = globals()[f'irt_model_{self.irt_num}pl']
            return irt_model_fn(
                self.ability_dim,
                response.size(0),
                self.num_item,
                response.device,
                response=response,
                mask=mask,
                annealing_factor=annealing_factor,
            )
        else:
            pyro.module("decoder", self.decoder)

            ability_prior = dist.Normal(
                torch.zeros((num_person, ability_dim), device=device),
                torch.ones((num_person, ability_dim), device=device),
            )
            with poutine.scale(scale=annealing_factor):
                ability = pyro.sample("ability", ability_prior)

            item_feat_prior = dist.Normal(
                torch.zeros((num_item, self.item_feat_dim), device=device),
                torch.ones((num_item, self.item_feat_dim), device=device),
            )
            item_feat = pyro.sample("item_feat", item_feat_prior)

            response_mu = self.decoder(ability, item_feat)

            if mask is not None:
                response_dist = dist.Bernoulli(response_mu).mask(mask)
            else:
                response_dist = dist.Bernoulli(response_mu)

            if response is not None:
                pyro.sample("response", response_dist, obs=response)
            else:
                response = pyro.sample("response", response_dist)
                return response, ability, item_feat
Example #19
    def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(mu(y,z))   # an image
        mu is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        batch_size = xs.size(0)
        with pyro.iarange("independent"):

            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)

            # if the label y (which digit to write) is supervised, sample from the
            # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
            alpha_prior = Variable(
                torch.ones([batch_size, self.output_size]) /
                (1.0 * self.output_size))
            if ys is None:
                ys = pyro.sample("y", dist.categorical, alpha_prior)
            else:
                pyro.observe("y", dist.categorical, ys, alpha_prior)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            mu = self.decoder.forward([zs, ys])
            pyro.observe("x", dist.bernoulli, xs, mu)
Example #20
    def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(loc(y,z))   # an image
        loc is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        batch_size = xs.size(0)
        options = dict(dtype=xs.dtype, device=xs.device)
        with pyro.plate("data"):

            # sample the handwriting style from the constant prior distribution
            prior_loc = torch.zeros(batch_size, self.z_dim, **options)
            prior_scale = torch.ones(batch_size, self.z_dim, **options)
            zs = pyro.sample("z",
                             dist.Normal(prior_loc, prior_scale).to_event(1))

            # if the label y (which digit to write) is supervised, sample from the
            # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
            alpha_prior = torch.ones(batch_size, self.output_size, **
                                     options) / (1.0 * self.output_size)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            loc = self.decoder.forward([zs, ys])
            pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
            # return the loc so we can visualize it later
            return loc
Example #21
    def model(self, xs, ys=None):
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("generation_net", self)
        batch_size = xs.shape[0]
        with pyro.plate("data"):

            # Prior network uses the baseline predictions as initial guess.
            # This is the generative process with recurrent connection
            with torch.no_grad():
                # this ensures the training process does not change the
                # baseline network
                y_hat = self.baseline_net(xs).view(xs.shape)

            # sample the handwriting style from the prior distribution, which is
            # modulated by the input xs.
            prior_loc, prior_scale = self.prior_net(xs, y_hat)
            zs = pyro.sample("z",
                             dist.Normal(prior_loc, prior_scale).to_event(1))

            # the output y is generated from the distribution pθ(y|x, z)
            loc = self.generation_net(zs)

            if ys is not None:
                # In training, we will only sample in the masked image
                mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)
                mask_ys = ys[xs == -1].view(batch_size, -1)
                pyro.sample(
                    "y",
                    dist.Bernoulli(mask_loc, validate_args=False).to_event(1),
                    obs=mask_ys,
                )
            else:
                # In testing, no need to sample: the output is already a
                # probability in the [0, 1] range, which better represents pixel
                # values for grayscale images. If we sample, we will force
                # each pixel to be either 0 or 1, killing the grayscale
                pyro.deterministic("y", loc.detach())

            # return the loc so we can visualize it later
            return loc
Example #22
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0 if not self.use_cuda \
            else self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0

        # sample the latents z one time step at a time
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_mu, z_sigma = self.combiner(z_prev, rnn_output[:, t - 1, :])
            z_dist = dist.normal

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if self.iafs.__len__() > 0:
                z_dist = TransformedDistribution(z_dist, self.iafs)
            # sample z_t from the distribution z_dist
            z_t = pyro.sample("z_%d" % t,
                              z_dist,
                              z_mu,
                              z_sigma,
                              log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])
            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t
Example #23
    def model(self, imgs=None):
        """ 1. sample the latent from the prior:
            2. runs the generative model
            3. score the generative model against actual data 
        """
        #-----------------------#
        #--------  Trick -------#
        #-----------------------#
        if (imgs is None):
            observed = False
            imgs = torch.zeros(8, self.ch, self.height, self.width)
            if (self.use_cuda):
                imgs = imgs.cuda()
        else:
            observed = True
        #-----------------------#
        #----- End of Trick ----#
        #-----------------------#
        batch_size, ch, width, height = imgs.shape

        zero = imgs.new_zeros(1)
        one = imgs.new_ones(1)

        pyro.module("decoder", self.decoder)
        sigma_obs = pyro.param("sigma_obs",
                               0.1 * one,
                               constraint=constraints.interval(0.01, 0.5))

        with pyro.plate('batch_size', batch_size, dim=-1):
            with poutine.scale(scale=self.scale):
                z = pyro.sample(
                    'z_latent',
                    dist.Normal(zero, one).expand([self.dim_z]).to_event(1))
            x = self.decoder(z)  #x_mu is between 0 and 1
            #pyro.sample('obs', dist.Normal(x.x_mu,x.x_std).to_event(1), obs=imgs.view(batch_size,-1))
            #pyro.sample('obs', dist.Normal(x.x_mu,sigma_obs).to_event(1), obs=imgs.view(batch_size,-1))
            pyro.sample('obs',
                        dist.Normal(x.x_mu, 0.1 * one).to_event(1),
                        obs=imgs.view(batch_size, -1))
        return x
Example #24
    def model(self, x, y):
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters of prior, z : (loc_z, scale_z)
            loc_z = torch.zeros(x.shape[0],
                                self.z_dim,
                                dtype=x.dtype,
                                device=x.device)
            scale_z = torch.ones(x.shape[0],
                                 self.z_dim,
                                 dtype=x.dtype,
                                 device=x.device)
            # sample from prior
            z = pyro.sample("latent", Normal(loc_z, scale_z).to_event(1))
            # decoder sampled latent code
            loc_img = self.decoder(z, y)
            # sample from Bernoulli with `p=loc_img`
            pyro.sample("obs",
                        Bernoulli(loc_img).to_event(1),
                        obs=x.reshape(-1, 784))

            return loc_img
Example #25
    def model(self, docs=None, doc_categories=None):
        # Register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)

        with pyro.plate("document_categories", doc_categories.shape[0]):
            # Dirichlet prior  𝑝(𝜃|𝛼) is replaced by a log-normal distribution
            theta_loc = docs.new_zeros((docs.shape[0], self.num_topics))
            theta_scale = docs.new_ones((docs.shape[0], self.num_topics))
            theta = pyro.sample(
                "theta",
                dist.LogNormal(theta_loc, theta_scale).to_event(1))
            theta = theta / theta.sum(-1, keepdim=True)

            # conditional distribution of 𝑤𝑛 is defined as
            # 𝑤𝑛|𝛽,𝜃 ~ Categorical(𝜎(𝛽𝜃))
            count_param = self.decoder(theta)

        with pyro.plate("documents", docs.shape[0]):
            pyro.sample('obs',
                        dist.Multinomial(docs.shape[1],
                                         count_param).to_event(1),
                        obs=docs)
Example #26
    def guide(self, x_sent, y_sent, decode=False, kl_annealing=1.0):
        pyro.module("vnmt", self)
        #transform sentences into embeddings
        x_embeds, x_len, x_mask, y_sent = self.x_embed(x_sent, y=y_sent)
        y_embeds, y_len, y_mask, y_indices = self.y_embed(y_sent,
                                                          y=range(len(y_sent)),
                                                          pack_seq=True)

        #run our sequences through rnn
        # 2nd output is just the last hidden state, whereas I want ALL the hidden states in the output

        x_out, _ = self.encoder(x_embeds)
        X, x_len = pad_packed_sequence(x_out, batch_first=self.b_f)
        #TODO one problem with this is that the sequences are varying lengths
        #TODO i.e. even if the rest of the dims are 0 vectors the strength of
        #TODO each of existing entries might be lessened depending on how long longest sequence is
        if self.use_cuda:
            x_len = x_len.cuda()
        X = torch.sum(X, dim=1) / x_len.unsqueeze(1).float()

        y_out, _ = self.encoder(y_embeds)
        Y, y_len = pad_packed_sequence(y_out, batch_first=self.b_f)
        if self.use_cuda:
            y_len = y_len.cuda()
        Y = torch.sum(Y, dim=1) / y_len.unsqueeze(1).float()

        Y = zip([r for r in Y], y_indices)
        Y = sorted(Y, key=lambda x: x[1], reverse=False)
        Y = torch.stack([y[0] for y in Y])

        #Mean pool over the hidden states to
        z_input = torch.cat([X, Y], dim=1)
        z_mean, z_sig = self.posterior(z_input)

        #sample from our variational distribution P(Z | X, Y)
        with pyro.plate('z_minibatch'):
            with poutine.scale(scale=kl_annealing):
                semantics = pyro.sample('z_semantics',
                                        dist.Normal(z_mean, z_sig))
Example #27
 def guide(self,
           x: torch.Tensor,
           **kwargs: float) -> None:
     """
     Defines the guide q(z,c|x)
     """
     # register PyTorch module `encoder_z` with Pyro
     pyro.module("encoder_z", self.encoder_z)
     # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
     beta = kwargs.get("scale_factor", [1., 1.])
     if isinstance(beta, (float, int, list)):
         beta = torch.tensor(beta)
     if beta.ndim == 0:
         beta = torch.tensor([beta, beta])
     with pyro.plate("data"):
         # use the encoder to get the parameters used to define q(z,c|x)
         z_loc, z_scale, alpha = self.encoder_z(x)
         # sample the latent code z
         with pyro.poutine.scale(scale=beta[0]):
             pyro.sample("latent_cont", dist.Normal(z_loc, z_scale).to_event(1))
         with pyro.poutine.scale(scale=beta[1]):
             pyro.sample("latent_disc", dist.OneHotCategorical(alpha))
Example #28
    def model(self, data):
        # set up parameters
        pyro.module('VIN', self)

        q_prev = self.q2
        dq_prev = (self.q2 - self.q1) / self.h
        self.T = data.shape[1]
        with pyro.plate('mini_batch', len(data)):
            pyro.sample('y_0',
                        dist.Normal(loc=self.q1, scale=self.sigma).to_event(1),
                        obs=data[:, 0].unsqueeze(-1))
            pyro.sample('y_1',
                        dist.Normal(loc=self.q2, scale=self.sigma).to_event(1),
                        obs=data[:, 1].unsqueeze(-1))
            for t in range(2, self.T):
                q, dq = self.trans(q_prev, dq_prev, self.h)
                pyro.sample('y_{}'.format(t),
                            dist.Normal(loc=q, scale=self.sigma).to_event(1),
                            obs=data[:, t].unsqueeze(-1))

                q_prev = q
                dq_prev = dq
Example #29
    def guide(self, x, y=None):
        pyro.module("scanvi", self)
        with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
            z2_loc, z2_scale, l_loc, l_scale = self.z2l_encoder(x)
            pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            y_logits = self.classifier(z2)
            y_dist = dist.OneHotCategorical(logits=y_logits)
            if y is None:
                # x is unlabeled so sample y using q(y|z2)
                y = pyro.sample("y", y_dist)
            else:
                # x is labeled so add a classification loss term
                # (this way q(y|z2) learns from both labeled and unlabeled data)
                classification_loss = y_dist.log_prob(y)
                # Note that the negative sign appears because we're adding this term in the guide
                # and the guide log_prob appears in the ELBO as -log q
                pyro.factor("classification_loss", -self.alpha * classification_loss)

            z1_loc, z1_scale = self.z1_encoder(z2, y)
            pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
Example #30
        def model(self, x, y):
            pyro.module(self.name_prefix + ".gp", self)

            # Draw sample from q(f)
            function_dist = self.pyro_model(x, name_prefix=self.name_prefix)

            # Draw samples of cluster assignments
            cluster_assignment_samples = pyro.sample(
                self.name_prefix + ".cluster_logits",
                pyro.distributions.OneHotCategorical(logits=torch.zeros(self.num_tasks, self.num_functions)).to_event(
                    1
                ),
            )

            # Sample from observation distribution
            with pyro.plate(self.name_prefix + ".output_values_plate", function_dist.batch_shape[-1], dim=-1):
                function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
                obs_dist = pyro.distributions.Normal(
                    loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1), scale=self.noise.sqrt()
                ).to_event(1)
                with pyro.poutine.scale(scale=(self.num_data / y.size(-2))):
                    return pyro.sample(self.name_prefix + ".y", obs_dist, obs=y)
Example #31
def model(data):
    rr = pyro.param("l1", torch.Tensor(5, 10))

    p_net_first = pyro.module(pyro_name="p1", Net, my_args)
    p_net_first = pyro.module(pyro_name="p1", Net(my_args))
    p_net_first = pyro.module(pyro_name="p1", net)

    p_net_second = pyro.module(pyro_name="p2", Net, my_args)

    p_net_model.behave_different = flip()

    #pyro.module("dope_name", MNet())

    # else:
    # p_net = pyro.module("dope_other", SuperNet())
    # do something with rr
    # model_forward = torch.mm(rr,data)
    model_forward = p_net.forward(data)

    # net.paramters()

    return model_forward
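Example #31 above is exploratory scratch code. For contrast, here is a minimal sketch of the call form pyro.module actually expects: the first argument is the name string (the keyword is name, not pyro_name) and the second is an already-constructed torch.nn.Module instance, not a class plus constructor arguments. The Net class below is a placeholder for illustration, not the example's own network.

import torch
import pyro

class Net(torch.nn.Module):  # placeholder network for illustration
    def __init__(self, hidden):
        super().__init__()
        self.layer = torch.nn.Linear(10, hidden)

    def forward(self, data):
        return self.layer(data)

net = Net(hidden=5)  # construct the module once, outside the model

def model(data):
    # register the instance; the same (now registered) module is returned
    p_net = pyro.module("p1", net)
    return p_net.forward(data)

model(torch.randn(3, 10))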
Example #32
    def loss_fn(design, num_particles, evaluation=False, **kwargs):

        try:
            pyro.module("h", h)
        except AssertionError:
            pass

        expanded_design = lexpand(design, num_particles)
        model_conditional_trace = poutine.trace(model_conditional).get_trace(expanded_design)

        if not evaluation:
            model_marginal_trace = poutine.trace(model_marginal).get_trace(expanded_design)

            h_joint = h(expanded_design, model_conditional_trace, observation_labels, target_labels)
            h_independent = h(expanded_design, model_marginal_trace, observation_labels, target_labels)

            terms = torch.nn.functional.softplus(-h_joint) + torch.nn.functional.softplus(h_independent)
            return _safe_mean_terms(terms)

        else:
            h_joint = h(expanded_design, model_conditional_trace, observation_labels, target_labels)
            return _safe_mean_terms(h_joint)
Example #33
    def guide(self, input_variable, target_variable, step):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder_dense", self.encoder_dense)
        pyro.module("encoder_rnn", self.encoder_rnn)

        # init vars
        input_length = input_variable.shape[0]
        encoder_outputs = Variable(torch.zeros(input_length, self.encoder_rnn.hidden_size))
        encoder_outputs = encoder_outputs.cuda() if USE_CUDA else encoder_outputs
        encoder_hidden = self.encoder_rnn.init_hidden_gru()

        # loop to encode
        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder_rnn(
                input_variable[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0][0]

        # use the encoder to get the parameters used to define q(z|x)
        z_mu, z_sigma = self.encoder_dense(encoder_hidden)

        # sample the latent code z
        pyro.sample("latent", dist.normal, z_mu, z_sigma)
Example #34
    def model(self, raw_expr, encoded_expr, read_depth):

        pyro.module("decoder", self.decoder)

        dispersion = pyro.param("dispersion",
                                torch.tensor(5.).to(self.device) *
                                torch.ones(self.num_genes).to(self.device),
                                constraint=constraints.positive)

        with pyro.plate("cells", encoded_expr.shape[0]):

            # Dirichlet prior  𝑝(𝜃|𝛼) is replaced by a log-normal distribution
            theta_loc = self.prior_mu * encoded_expr.new_ones(
                (encoded_expr.shape[0], self.num_topics))
            theta_scale = self.prior_std * encoded_expr.new_ones(
                (encoded_expr.shape[0], self.num_topics))
            theta = pyro.sample(
                "theta",
                dist.LogNormal(theta_loc, theta_scale).to_event(1))
            theta = theta / theta.sum(-1, keepdim=True)

            read_scale = pyro.sample(
                'read_depth',
                dist.LogNormal(torch.log(read_depth), 1.).to_event(1))

            #read_scale = torch.minimum(read_scale, self.max_scale)
            # conditional distribution of 𝑤𝑛 is defined as
            # 𝑤𝑛|𝛽,𝜃 ~ Categorical(𝜎(𝛽𝜃))
            expr_rate, dropout = self.decoder(theta)

            mu = torch.multiply(read_scale, expr_rate)
            p = torch.minimum(mu / (mu + dispersion), self.max_prob)

            pyro.sample('obs',
                        dist.ZeroInflatedNegativeBinomial(
                            total_count=dispersion,
                            probs=p,
                            gate_logits=dropout).to_event(1),
                        obs=raw_expr)
Example #35
 def model(self,
           x: torch.Tensor,
           y: Optional[torch.Tensor] = None,
           **kwargs: float) -> None:
     """
     Defines the model p(x|z)p(z)
     """
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
     beta = kwargs.get("scale_factor", 1.)
     reshape_ = torch.prod(torch.tensor(x.shape[1:])).item()
     with pyro.plate("data", x.shape[0]):
         # setup hyperparameters for prior p(z)
         z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
         z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
         # sample from prior (value will be sampled by guide when computing the ELBO)
         with pyro.poutine.scale(scale=beta):
             z = pyro.sample("latent",
                             dist.Normal(z_loc, z_scale).to_event(1))
         if self.coord > 0:  # rotationally- and/or translationaly-invariant mode
             # Split latent variable into parts for rotation
             # and/or translation and image content
             phi, dx, sc, z = self.split_latent(z)
             if 't' in self.invariances:
                 dx = (dx * self.t_prior).unsqueeze(1)
             # transform coordinate grid
             grid = self.grid.expand(x.shape[0], *self.grid.shape)
             x_coord_prime = transform_coordinates(grid, phi, dx, sc)
         # Add class label (if any)
         if y is not None:
             z = torch.cat([z, y], dim=-1)
         # decode the latent code z together with the transformed coordinates (if any)
         dec_args = (x_coord_prime, z) if self.coord else (z, )
         loc = self.decoder(*dec_args)
         # score against actual images ("binary cross-entropy loss")
         pyro.sample("obs",
                     self.sampler_d(loc.view(-1, reshape_)).to_event(1),
                     obs=x.view(-1, reshape_))
Example #36
    def model_ncls(self, xs, ys=None):
        pyro.module("ss_vae", self)
        batch_size = xs.size(0)
        options = dict(out=None,
                       dtype=xs.dtype,
                       layout=torch.strided,
                       device=xs.device,
                       requires_grad=False)

        with pyro.plate("data"):
            prior_loc = torch.zeros(batch_size, 50, **options)
            prior_scale = torch.ones(batch_size, 50, **options)
            zs2 = pyro.sample("z2",
                              dist.Normal(prior_loc, prior_scale).to_event(1))

            alpha_prior = torch.ones(batch_size, 4, **options) / 4.0
            ys_ = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            z1_mean, z1_std = self.decoder_z1(zs2, ys_)
            zs1 = pyro.sample("z1", dist.Normal(z1_mean, z1_std).to_event(1))
            x_mean, x_std = self.decoder_x(zs1)
            pyro.sample('x', dist.Normal(x_mean, x_std).to_event(3), obs=xs)
Example #37
def guide(xs, ys=None):
    """
    The guide corresponds to the following:
    q(y|x) = categorical(alpha(x))              # infer digit from an image
    q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer handwriting style from an image and the digit
    loc, scale are given by a neural network `encoder_z`
    alpha is given by a neural network `encoder_y`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: None
    """
    pyro.module("encoder_y_fst", encoder_y_fst)
    pyro.module("encoder_y_snd", encoder_y_snd)
    pyro.module("encoder_z_fst", encoder_z_fst)
    pyro.module("encoder_z_out1", encoder_z_out1)
    pyro.module("encoder_z_out2", encoder_z_out2)
    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):

        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            hidden = softplus(encoder_y_fst(xs))
            alpha = softmax(encoder_y_snd(hidden))
            ys = pyro.sample("y", OneHotCategorical(alpha))

        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
        # shape = broadcast_shape(torch.Size([200]), ys.shape[:-1]) + (-1,)
        shape = ys.shape[:-1] + (-1, )
        hidden_z = softplus(
            encoder_z_fst(torch.cat(
                [xs.expand(shape), ys.expand(shape)], -1)))
        loc = encoder_z_out1(hidden_z)
        scale = torch.exp(encoder_z_out2(hidden_z))
        pyro.sample("z", Normal(loc, scale).to_event(1))
Example #38
    def model(self, input, output):
        '''
    Likelihood model: sample from prior, then decode to video.
    param input: video of size (batch_size, self.n_frames_input, C, H, W)
    param output: video of size (batch_size, self.n_frames_output, C, H, W)
    '''
        # Register networks
        for name, net in self.model_modules.items():
            pyro.module(name, net)

        observation = output
        corrupted_observation = input

        # Sample from prior
        latent = self.sample_latent_prior(input)
        # Decode
        decoded_output, masked_decoded_output, components = self.decode(latent)

        if self.use_crop_size and self.gamma_steps > self.gamma_switch_step and self.is_train:
            mask_out = Variable(torch.ones_like(decoded_output).cuda())
            mask_out[:, :, :, :(self.image_size - self.crop_size[1])] = 0
            decoded_output = decoded_output * mask_out
            if self.gamma_steps == self.gamma_switch_step + 1:
                print("Loss evaluated in visible frame.")

        # Observe - prediction
        decoded_output = decoded_output[:, self.n_frames_input:].view(
            *observation.size())
        sd = Variable(0.3 * torch.ones(*decoded_output.size()).cuda())
        pyro.sample('obs', dist.Normal(decoded_output, sd), obs=observation)

        # Observe - corrupted reconstruction
        masked_decoded_output = masked_decoded_output.view(
            *corrupted_observation.size())
        sd = Variable(0.3 * torch.ones(*masked_decoded_output.size()).cuda())
        pyro.sample('masked_obs',
                    dist.Normal(masked_decoded_output, sd),
                    obs=corrupted_observation)
Example #39
    def _parametrized_guide(self, predictor, data, labels, batch_size=None):
        args = self.args
        # Use a conjugate guide for global variables.
        topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics),
            constraint=constraints.positive)
        topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(args.num_topics, args.num_words),
            constraint=constraints.greater_than(0.5))

        # Is this needed?
        label_posterior = pyro.param("label_posterior",
                                     lambda: torch.ones(2, args.num_topics),
                                     constraint=constraints.positive)

        with pyro.plate("topics", args.num_topics):
            pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior,
                                                    1.))
            pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
            pyro.sample("label_prior", dist.Beta(*label_posterior))

        # Use an amortized guide for local variables.
        pyro.module("predictor", predictor)
        with pyro.plate("documents", args.num_docs, batch_size) as ind:
            data = data[:, ind]
            labels = labels[ind]
            counts = (torch.zeros(args.num_words, ind.size(0)).scatter_add(
                0, data, torch.ones(data.shape)))

            augmented_input = torch.zeros(batch_size,
                                          args.num_words + args.num_topics)
            augmented_input[:, :args.num_words] = counts.transpose(0, 1)
            augmented_input[:, args.num_words:] = labels

            doc_topics = predictor(augmented_input)
            pyro.sample("doc_topics", dist.Delta(doc_topics).to_event(1))
Example #40
    def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(mu(y,z))   # an image
        mu is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        batch_size = xs.size(0)
        with pyro.iarange("independent"):

            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)

            # if the label y (which digit to write) is supervised, sample from the
            # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
            alpha_prior = Variable(torch.ones([batch_size, self.output_size]) / (1.0 * self.output_size))
            if ys is None:
                ys = pyro.sample("y", dist.categorical, alpha_prior)
            else:
                pyro.observe("y", dist.categorical, ys, alpha_prior)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            mu = self.decoder.forward([zs, ys])
            pyro.observe("x", dist.bernoulli, xs, mu)
Example #41
 def model(self,
           xs: torch.Tensor,
           ys: Optional[torch.Tensor] = None,
           **kwargs: float) -> None:
     """
     Model of the generative process p(x|z,y)p(y)p(z)
     """
     pyro.module("ss_vae", self)
     batch_dim = xs.size(0)
     specs = dict(dtype=xs.dtype, device=xs.device)
     beta = kwargs.get("scale_factor", 1.)
     # pyro.plate enforces independence between variables in batches xs, ys
     with pyro.plate("data"):
         # sample the latent vector from the constant prior distribution
         prior_loc = torch.zeros(batch_dim, self.z_dim, **specs)
         prior_scale = torch.ones(batch_dim, self.z_dim, **specs)
         with pyro.poutine.scale(scale=beta):
             zs = pyro.sample(
                 "z", dist.Normal(prior_loc, prior_scale).to_event(1))
         # split latent variable into parts for rotation and/or translation
         # and image content
         if self.coord > 0:
             phi, dx, sc, zs = self.split_latent(zs)
             if 't' in self.invariances:
                 dx = (dx * self.t_prior).unsqueeze(1)
             # transform coordinate grid
             grid = self.grid.expand(zs.shape[0], *self.grid.shape)
             x_coord_prime = transform_coordinates(grid, phi, dx, sc)
         # sample label from the constant prior or observe the value
         c_prior = (torch.zeros(batch_dim, self.reg_dim, **specs))
         ys = pyro.sample(
             "y", dist.Normal(c_prior, self.reg_sig).to_event(1), obs=ys)
         # Score against the parametrized distribution
         # p(x|y,z) = bernoulli(decoder(y,z))
         d_args = (x_coord_prime, [zs, ys]) if self.coord else ([zs, ys],)
         loc = self.decoder(*d_args)
         loc = loc.view(*ys.shape[:-1], -1)
         pyro.sample("x", self.sampler_d(loc).to_event(1), obs=xs)
Example #42
    def model(self, x):
        pyro.module("decoder_c", self.decoder_c)
        pyro.module("decoder_y", self.decoder_y)

        with pyro.plate("data", x.shape[0]):
            # x is (Outcome, Class, Age, Sex)

            # prior on U_c
            mean_c = x.new_zeros(torch.Size((x.shape[0], self.U_c_dim)))
            std_c = x.new_ones(torch.Size((x.shape[0], self.U_c_dim)))
            U_c = pyro.sample("U_c", dist.Normal(mean_c, std_c).to_event(1))

            # prior on U_y
            mean_y = x.new_zeros(torch.Size((x.shape[0], self.U_y_dim)))
            std_y = x.new_ones(torch.Size((x.shape[0], self.U_y_dim)))
            U_y = pyro.sample("U_y", dist.Normal(mean_y, std_y).to_event(1))

            # prior on Age
            mean_a = 29.7 * x.new_ones(torch.Size((x.shape[0], 1)))
            std_a = 14.5 * x.new_ones(torch.Size((x.shape[0], 1)))
            A = pyro.sample("Age", dist.Normal(mean_a, std_a).to_event(1))

            # prior on Sex
            prob_s = 0.6476 * x.new_ones(torch.Size((x.shape[0], 1)))
            S = pyro.sample("Sex", dist.Bernoulli(prob_s).to_event(1))

            # decode the latent code z
            C_probs = self.decoder_c(U_c, A, S)
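            # `to_one_hot` and `one_hot_to_idx` below are project-specific
            # helpers (not Pyro/PyTorch API) defined elsewhere in the source.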
            C = pyro.sample("Class",
                            dist.Multinomial(probs=C_probs).to_event(1),
                            obs=to_one_hot(x[:, 1], self.num_classes))
            C = one_hot_to_idx(C)

            # score against actual outcome
            Y = self.decoder_y(U_y, A, S, C)
            pyro.sample("Outcome",
                        dist.Bernoulli(probs=Y).to_event(1),
                        obs=x[:, 0].reshape(-1, 1))
Example #43
0
    def guide(self, batch, subsample, full_size):
        num_time_steps, batch_size = batch.shape
        self.map_estimate("drift")

        group = self.group(match="state_[0-9]*")
        cov_diag = pyro.param("state_cov_diag",
                              lambda: torch.full(group.event_shape, 0.01),
                              constraint=constraints.positive)
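        # NOTE: `rank` (the low-rank dimension of the covariance factor) is
        # defined in the enclosing scope and does not appear in this snippet.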
        cov_factor = pyro.param(
            "state_cov_factor", lambda: torch.randn(group.event_shape +
                                                    (rank, )) * 0.01)

        if not hasattr(self, "nn"):
            self.nn = torch.nn.Linear(group.event_shape.numel(),
                                      group.event_shape.numel())
            self.nn.weight.data.fill_(1.0 / num_time_steps)
            self.nn.bias.data.fill_(-0.5)
        pyro.module("state_nn", self.nn)
        with self.plate("data", full_size, subsample=subsample):
            loc = self.nn(batch.t())
            group.sample(
                "states",
                dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
Example #44
0
def main(observations={"x1": 0, "x2": 0}):
    pyro.module("first", first)
    pyro.module("second", second)
    pyro.module("third", third)
    pyro.module("fourth", fourth)
    pyro.module("fifth", fifth)

    obs = torch.tensor([float(observations["x1"]),
                        float(observations["x2"])])
    x1 = obs[0]
    x2 = obs[1]
    v = torch.cat((x1.view(1, 1), x2.view(1, 1)), 1)

    h1  = relu(first(v))
    h2  = relu(second(h1))
    h3  = relu(third(h2))
    h4  = relu(fourth(h3))
    out = fifth(h4)

    mean = out[0, 0]
    std = torch.exp(out[0, 1])
    pyro.sample("z", Normal(mean, std))
Example #45
0
 def model(self, x):
     # register PyTorch module `decoder` with Pyro
     pyro.module("decoder", self.decoder)
     with pyro.plate("data", x.shape[0]):
         # setup hyperparameters for prior p(z)
         z_loc = torch.zeros(x.shape[0],
                             self.z_dim,
                             dtype=x.dtype,
                             device=x.device)
         z_scale = torch.ones(x.shape[0],
                              self.z_dim,
                              dtype=x.dtype,
                              device=x.device)
         # sample from prior (value will be sampled by guide when computing the ELBO)
         z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
         # decode the latent code z
         loc_img = self.decoder.forward(z)
         # score against actual images
         pyro.sample("obs",
                     dist.Bernoulli(loc_img).to_event(1),
                     obs=x.reshape(-1, 784))
         # return the loc so we can visualize it later
         return loc_img
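
Example #45 shows only the model side of a VAE. Below is a self-contained sketch of the usual counterpart: an amortized guide plus an SVI training loop. The `Encoder`/`Decoder` networks and the random stand-in batch are illustrative assumptions, not part of the original snippet.

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

Z_DIM, HIDDEN = 20, 200

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, HIDDEN)
        self.fc_loc = nn.Linear(HIDDEN, Z_DIM)
        self.fc_scale = nn.Linear(HIDDEN, Z_DIM)

    def forward(self, x):
        h = torch.relu(self.fc(x))
        # return loc and a strictly positive scale for q(z|x)
        return self.fc_loc(h), torch.exp(self.fc_scale(h))

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(Z_DIM, HIDDEN)
        self.fc2 = nn.Linear(HIDDEN, 784)

    def forward(self, z):
        # Bernoulli probabilities for each pixel
        return torch.sigmoid(self.fc2(torch.relu(self.fc1(z))))

encoder, decoder = Encoder(), Decoder()

def model(x):
    pyro.module("decoder", decoder)
    with pyro.plate("data", x.shape[0]):
        z_loc = x.new_zeros(x.shape[0], Z_DIM)
        z_scale = x.new_ones(x.shape[0], Z_DIM)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        pyro.sample("obs", dist.Bernoulli(decoder(z)).to_event(1), obs=x)

def guide(x):
    pyro.module("encoder", encoder)
    with pyro.plate("data", x.shape[0]):
        z_loc, z_scale = encoder(x)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
x = torch.rand(32, 784).round()   # stand-in batch of binarized images
for _ in range(3):
    loss = svi.step(x)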
Example #46
0
    def guide(self, max_time_step):
        pyro.module("agentmodel", self)

        observation = reset_env()
        for t in range(MAXTIME):
            state = observation
            observation = torch.from_numpy(observation).float()
            prob_action = self.policy(observation)
            action = pyro.sample("action_{}".format(t),
                                 dist.Bernoulli(prob_action))
            action = round(action.item())
            observation, reward, done, _ = env.step(action)

            if done and self.echo:
                print("guide exit at", t)
                return t

            if done:
                break
        if self.echo:
            print("guide solve the problem at t:", max_time_step)
            return max_time_step
Example #47
0
    def test_save_and_load(self):
        lin = pyro.module("mymodule", self.linear_module)
        pyro.module("mymodule2", self.linear_module2)
        x = torch.randn(1, 3)
        myparam = pyro.param("myparam", torch.tensor(1.234 * torch.ones(1), requires_grad=True))

        cost = torch.sum(torch.pow(lin(x), 2.0)) * torch.pow(myparam, 4.0)
        cost.backward()
        params = list(self.linear_module.parameters()) + [myparam]
        optim = torch.optim.Adam(params, lr=.01)
        myparam_copy_stale = copy(pyro.param("myparam").detach().cpu().numpy())

        optim.step()

        myparam_copy = copy(pyro.param("myparam").detach().cpu().numpy())
        param_store_params = copy(pyro.get_param_store()._params)
        param_store_param_to_name = copy(pyro.get_param_store()._param_to_name)
        assert len(list(param_store_params.keys())) == 5
        assert len(list(param_store_param_to_name.values())) == 5

        pyro.get_param_store().save('paramstore.unittest.out')
        pyro.clear_param_store()
        assert len(list(pyro.get_param_store()._params)) == 0
        assert len(list(pyro.get_param_store()._param_to_name)) == 0
        pyro.get_param_store().load('paramstore.unittest.out')

        def modules_are_equal():
            weights_equal = np.sum(np.fabs(self.linear_module3.weight.detach().cpu().numpy() -
                                   self.linear_module.weight.detach().cpu().numpy())) == 0.0
            bias_equal = np.sum(np.fabs(self.linear_module3.bias.detach().cpu().numpy() -
                                self.linear_module.bias.detach().cpu().numpy())) == 0.0
            return (weights_equal and bias_equal)

        assert not modules_are_equal()
        pyro.module("mymodule", self.linear_module3, update_module_params=False)
        assert id(self.linear_module3.weight) != id(pyro.param('mymodule$$$weight'))
        assert not modules_are_equal()
        pyro.module("mymodule", self.linear_module3, update_module_params=True)
        assert id(self.linear_module3.weight) == id(pyro.param('mymodule$$$weight'))
        assert modules_are_equal()

        myparam = pyro.param("myparam")
        store = pyro.get_param_store()
        assert myparam_copy_stale != myparam.detach().cpu().numpy()
        assert myparam_copy == myparam.detach().cpu().numpy()
        assert sorted(param_store_params.keys()) == sorted(store._params.keys())
        assert sorted(param_store_param_to_name.values()) == sorted(store._param_to_name.values())
        assert sorted(store._params.keys()) == sorted(store._param_to_name.values())
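
The test above exercises `ParamStore.save`/`load` and the `update_module_params` flag of `pyro.module`. A minimal stand-alone sketch of that round trip (hypothetical file name, toy linear module) looks roughly like this:

import torch
import torch.nn as nn
import pyro

pyro.clear_param_store()
net = nn.Linear(3, 1)
pyro.module("mymodule", net)                       # registers mymodule$$$weight / $$$bias
pyro.get_param_store().save("paramstore.demo.out") # persist all params to disk

pyro.clear_param_store()
pyro.get_param_store().load("paramstore.demo.out") # restore params

# push the restored values into a fresh module instance
fresh = nn.Linear(3, 1)
pyro.module("mymodule", fresh, update_module_params=True)
assert torch.allclose(fresh.weight, net.weight)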
Example #48
0
 def prior(self, n, **kwargs):
     pyro.module("decode", self.decode)
     return self.local_model(n, **kwargs)
Example #49
0
 def guide(self, data):
     encoder = pyro.module('encoder', self.vae_encoder)
     with pyro.iarange('data', data.size(0)):
         z_mean, z_var = encoder.forward(data)
         pyro.sample('latent', Normal(z_mean, z_var.sqrt()).independent(1))
Example #50
0
 def model(self, data, batch_size, **kwargs):
     pyro.module("decode", self.decode)
     with pyro.iarange('data', data.size(0), use_cuda=self.use_cuda) as ix:
         return self.local_model(batch_size, data[ix], **kwargs)
Example #51
0
    def guide(self, data, batch_size, **kwargs):
        pyro.module('rnn', self.rnn)
        pyro.module('predict', self.predict)
        pyro.module('encode', self.encode)
        pyro.module('embed', self.embed)
        pyro.module('bl_rnn', self.bl_rnn, tags='baseline')
        pyro.module('bl_predict', self.bl_predict, tags='baseline')
        pyro.module('bl_embed', self.bl_embed, tags='baseline')

        pyro.param('h_init', self.h_init)
        pyro.param('c_init', self.c_init)
        pyro.param('z_where_init', self.z_where_init)
        pyro.param('z_what_init', self.z_what_init)
        pyro.param('bl_h_init', self.bl_h_init, tags='baseline')
        pyro.param('bl_c_init', self.bl_c_init, tags='baseline')

        with pyro.iarange('data', data.size(0), subsample_size=batch_size, use_cuda=self.use_cuda) as ix:
            batch = data[ix]
            return self.local_guide(batch.size(0), batch)
Example #52
0
 def cnn_fn(x):
     return pyro.module("CNN", cnn)(x)
Example #53
0
 def guide():
     pyro.module("mymodule", pt_guide)
     mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
     sigma = torch.pow(tau_q, -0.5)
     pyro.sample("mu_latent", dist.Normal(mu_q, sigma, reparameterized=reparameterized))
Example #54
0
    def do_elbo_test(self, repa1, repa2, n_steps, prec, lr, use_nn_baseline, use_decaying_avg_baseline):
        if self.verbose:
            print(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
            print("[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]" %
                  (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
        pyro.clear_param_store()

        if use_nn_baseline:

            class VanillaBaselineNN(nn.Module):
                def __init__(self, dim_input, dim_h):
                    super(VanillaBaselineNN, self).__init__()
                    self.lin1 = nn.Linear(dim_input, dim_h)
                    self.lin2 = nn.Linear(dim_h, 1)
                    self.sigmoid = nn.Sigmoid()

                def forward(self, x):
                    h = self.sigmoid(self.lin1(x))
                    return self.lin2(h)

            mu_prime_baseline = pyro.module("mu_prime_baseline", VanillaBaselineNN(2, 5), tags="baseline")
        else:
            mu_prime_baseline = None

        def model():
            mu_latent_prime = pyro.sample(
                    "mu_latent_prime",
                    dist.Normal(self.mu0, torch.pow(self.lam0, -0.5), reparameterized=repa1))
            mu_latent = pyro.sample(
                    "mu_latent",
                    dist.Normal(mu_latent_prime, torch.pow(self.lam0, -0.5), reparameterized=repa2))
            for i, x in enumerate(self.data):
                pyro.observe("obs_%d" % i, dist.normal, x, mu_latent,
                             torch.pow(self.lam, -0.5))
            return mu_latent

        # note that the exact posterior is not mean field!
        def guide():
            mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.334 * torch.ones(2),
                                               requires_grad=True))
            log_sig_q = pyro.param("log_sig_q", Variable(
                                   self.analytic_log_sig_n.data - 0.29 * torch.ones(2),
                                   requires_grad=True))
            mu_q_prime = pyro.param("mu_q_prime", Variable(torch.Tensor([-0.34, 0.52]),
                                    requires_grad=True))
            kappa_q = pyro.param("kappa_q", Variable(torch.Tensor([0.74]),
                                 requires_grad=True))
            log_sig_q_prime = pyro.param("log_sig_q_prime",
                                         Variable(-0.5 * torch.log(1.2 * self.lam0.data),
                                                  requires_grad=True))
            sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
            mu_latent_dist = dist.Normal(mu_q, sig_q, reparameterized=repa2)
            mu_latent = pyro.sample("mu_latent", mu_latent_dist,
                                    baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))
            mu_latent_prime_dist = dist.Normal(kappa_q.expand_as(mu_latent) * mu_latent + mu_q_prime,
                                               sig_q_prime,
                                               reparameterized=repa1)
            pyro.sample("mu_latent_prime",
                        mu_latent_prime_dist,
                        baseline=dict(nn_baseline=mu_prime_baseline,
                                      nn_baseline_input=mu_latent,
                                      use_decaying_avg_baseline=use_decaying_avg_baseline))

            return mu_latent

        # optim = Optimize(model, guide,
        #                 torch.optim.Adam, {"lr": lr, "betas": (0.97, 0.999)},
        #                 loss="ELBO", trace_graph=True,
        #                 auxiliary_optim_constructor=torch.optim.Adam,
        #                 auxiliary_optim_args={"lr": 5.0 * lr, "betas": (0.90, 0.999)})

        adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

        for k in range(n_steps):
            svi.step()

            mu_error = param_mse("mu_q", self.analytic_mu_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
            mu_prime_error = param_mse("mu_q_prime", 0.5 * self.mu0)
            kappa_error = param_mse("kappa_q", 0.5 * ng_ones(1))
            log_sig_prime_error = param_mse("log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0))

            if k % 500 == 0 and self.verbose:
                print("errors:  %.4f, %.4f" % (mu_error, log_sig_error), end='')
                print(", %.4f, %.4f" % (mu_prime_error, log_sig_prime_error), end='')
                print(", %.4f" % kappa_error)

        self.assertEqual(0.0, mu_error, prec=prec)
        self.assertEqual(0.0, log_sig_error, prec=prec)
        self.assertEqual(0.0, mu_prime_error, prec=prec)
        self.assertEqual(0.0, log_sig_prime_error, prec=prec)
        self.assertEqual(0.0, kappa_error, prec=prec)
Example #55
0
 def guide(self, data):
     encoder = pyro.module('encoder', self.vae_encoder)
     z_mean, z_var = encoder.forward(data)
     pyro.sample('latent', Normal(z_mean, z_var.sqrt()))
Example #56
0
    def guide(self, data, batch_size, **kwargs):
        pyro.module('rnn', self.rnn)
        pyro.module('predict', self.predict)
        pyro.module('encode', self.encode)
        pyro.module('embed', self.embed)
        pyro.module('bl_rnn', self.bl_rnn)
        pyro.module('bl_predict', self.bl_predict)
        pyro.module('bl_embed', self.bl_embed)

        pyro.param('h_init', self.h_init)
        pyro.param('c_init', self.c_init)
        pyro.param('z_where_init', self.z_where_init)
        pyro.param('z_what_init', self.z_what_init)
        pyro.param('bl_h_init', self.bl_h_init)
        pyro.param('bl_c_init', self.bl_c_init)

        with pyro.iarange('data', data.size(0), subsample_size=batch_size, use_cuda=self.use_cuda) as ix:
            batch = data[ix]
            n = batch.size(0)

            # Embed inputs.
            flattened_batch = batch.view(n, -1)
            inputs = {
                'raw': batch,
                'embed': self.embed(flattened_batch),
                'bl_embed': self.bl_embed(flattened_batch)
            }

            # Initial state.
            state = GuideState(
                h=batch_expand(self.h_init, n),
                c=batch_expand(self.c_init, n),
                bl_h=batch_expand(self.bl_h_init, n),
                bl_c=batch_expand(self.bl_c_init, n),
                z_pres=self.prototype.new_ones(n, self.z_pres_size),
                z_where=batch_expand(self.z_where_init, n),
                z_what=batch_expand(self.z_what_init, n))

            z_pres = []
            z_where = []

            for t in range(self.num_steps):
                state = self.guide_step(t, n, state, inputs)
                z_where.append(state.z_where)
                z_pres.append(state.z_pres)

            return z_where, z_pres
Example #57
0
 def model(self, data):
     decoder = pyro.module('decoder', self.vae_decoder)
     z_mean, z_std = ng_zeros([data.size(0), 20]), ng_ones([data.size(0), 20])
     z = pyro.sample('latent', Normal(z_mean, z_std))
     img = decoder.forward(z)
     pyro.observe('obs', Bernoulli(img), data.view(-1, 784))
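
Several of the later snippets (e.g. Examples #49–#51 and #55–#57) use an older Pyro API (`pyro.observe`, `ng_zeros`/`ng_ones`, `pyro.iarange`, `.independent(1)`) that current Pyro has replaced with `pyro.sample(..., obs=...)`, plain tensor constructors, `pyro.plate`, and `.to_event(...)`. As a rough sketch only, not the original author's code, Example #57 could be written against the current API as follows, with `decoder` standing in for the snippet's `vae_decoder`:

import torch
import pyro
import pyro.distributions as dist

def model(data, decoder):
    # register the decoder network with Pyro
    pyro.module("decoder", decoder)
    with pyro.plate("data", data.size(0)):
        # standard normal prior over a 20-dim latent
        z_loc = data.new_zeros(data.size(0), 20)
        z_scale = data.new_ones(data.size(0), 20)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # score the decoded image against the observed data
        img = decoder(z)
        pyro.sample("obs", dist.Bernoulli(img).to_event(1),
                    obs=data.view(-1, 784))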