Example #1
 def model():
     p2 = torch.ones(2) / 2
     p3 = torch.ones(3) / 3
     x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
     x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
     # `iarange_shape` comes from the enclosing test; the extra leading dims
     # ([2] and [3, 1]) are added when the two discrete sites are enumerated in parallel.
     assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
     assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
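For readers checking these shape assertions, here is a minimal standalone sketch of OneHotCategorical's batch/event shape semantics, independent of the iarange/enumeration machinery used above (assumes only torch and pyro.distributions):

import torch
import pyro.distributions as dist

p = torch.ones(5, 3) / 3                 # five independent 3-way categoricals
d = dist.OneHotCategorical(p)
assert d.batch_shape == torch.Size([5])  # probs.shape[:-1]
assert d.event_shape == torch.Size([3])  # last dim of probs
x = d.sample()                           # one-hot samples of shape (5, 3)
assert x.shape == torch.Size([5, 3])
assert bool((x.sum(-1) == 1).all())      # exactly one hot entry per row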
Example #2
    def model(self, xs, ys=None, aux_scale=46):
        pyro.module("ss_vae", self)
        batch_size = xs.size(0)
        options = dict(out=None,
                       dtype=xs.dtype,
                       layout=torch.strided,
                       device=xs.device,
                       requires_grad=False)

        with pyro.plate("data"):
            prior_loc = torch.zeros(batch_size, 50, **options)
            prior_scale = torch.ones(batch_size, 50, **options)
            zs2 = pyro.sample("z2",
                              dist.Normal(prior_loc, prior_scale).to_event(1))

            alpha_prior = torch.ones(batch_size, 4, **options) / 4.0
            ys_ = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            z1_mean, z1_std = self.decoder_z1(zs2, ys_)
            zs1 = pyro.sample("z1", dist.Normal(z1_mean, z1_std).to_event(1))
            x_mean, x_std = self.decoder_x(zs1)
            pyro.sample('x', dist.Normal(x_mean, x_std).to_event(3), obs=xs)

            if ys is not None:
                alpha = self.encoder_y(zs1)
                # with pyro.poutine.scale(scale=self.aux_scale):
                with pyro.poutine.scale(scale=aux_scale):
                    pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
Example #3
def test_batch_log_dims(dim, probs):
    batch_pdf_shape = (3,) + (1,) * (dim-1)
    expected_log_prob_sum = np.array(wrap_nested(list(np.log(probs)), dim-1)).reshape(*batch_pdf_shape)
    probs = modify_params_using_dims(probs, dim)
    support = dist.OneHotCategorical(probs).enumerate_support()
    log_prob = dist.OneHotCategorical(probs).log_prob(support)
    assert_equal(log_prob.detach().cpu().numpy(), expected_log_prob_sum)
Example #4
    def generate_name(self,
                      lstm: Decoder,
                      address: str,
                      batch_size: int,
                      hidd_cell_states: tuple = None,
                      sample: bool = True):
        """
        lstm: Decoder associated with name being generated
        address: The address to correlate pyro distribution with latent variables
        hidd_cell_states: Previous LSTM hidden state or empty hidden state
        max_name_length: The max name length allowed
        """
        # If no hidden state is provided, initialize it with all 0s
        if hidd_cell_states is None:
            hidd_cell_states = lstm.init_hidden(batch_size=batch_size)

        input_tensor = strings_to_tensor([SOS] * batch_size, 1,
                                         letter_to_index)
        names = [''] * batch_size

        for index in range(MAX_NAME_LENGTH):
            char_dist, hidd_cell_states = lstm.forward(input_tensor,
                                                       hidd_cell_states)

            if sample:
                # Next LSTM input is the sampled character
                input_tensor = pyro.sample(f"unsup_{address}_{index}",
                                           dist.OneHotCategorical(char_dist))
                chars_at_indexes = list(
                    map(lambda idx: MODEL_CHARS[int(idx.item())],
                        torch.argmax(input_tensor, dim=2).squeeze(0)))
            else:
                # Next LSTM input is the character with the highest probability of occurring
                pyro.sample(f"unsup_{address}_{index}",
                            dist.OneHotCategorical(char_dist))
                chars_at_indexes = list(
                    map(lambda idx: MODEL_CHARS[int(idx.item())],
                        torch.argmax(char_dist, dim=2).squeeze(0)))
                input_tensor = strings_to_tensor(chars_at_indexes, 1,
                                                 letter_to_index)

            # Add sampled characters to names
            for i, char in enumerate(chars_at_indexes):
                names[i] += char

        # Discard everything after EOS character
        # names = list(map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name, names))
        return hidd_cell_states, names
Example #5
    def model_classify(self, xs, lengths, ys=None):
        """
        this model is used to add an auxiliary (supervised) loss as described in the
        Kingma et al., "Semi-Supervised Learning with Deep Generative Models".
        """
        xs = pad_sequence(xs, batch_first=True, padding_value=self.padding_idx)
        if ys is not None:
            ys = torch.stack(ys)

        if self.use_cuda:
            xs = xs.cuda()
            lengths = lengths.cuda()
            if ys is not None:
                ys = ys.cuda()
        # Get embeddings
        xs = self.embeddings(xs)

        # register all pytorch (sub)modules with pyro
        pyro.module("text_ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.plate("data"):
            # this here is the extra term to yield an auxiliary loss that we do gradient descent on
            if ys is not None:
                alpha = self.encoder_y.forward(xs, lengths)
                with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                    pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
Example #6
def one_hot_model(pseudocounts, classes=None):
    probs_prior = dist.Dirichlet(pseudocounts)
    probs = pyro.sample("probs", probs_prior)
    with pyro.plate("classes",
                    classes.size(0) if classes is not None else 1,
                    dim=-1):
        return pyro.sample("obs", dist.OneHotCategorical(probs), obs=classes)
Example #7
    def model(self, data, label=None):

        pyro.module("decoder", self)

        batch_size = data.size(0)

        with pyro.plate("data"):

            deformation_loc = torch.zeros([batch_size, 1])
            deformation_scale = torch.ones([batch_size, 1])

            label_prior = data.new_ones([batch_size, self.num_classes
                                         ]) / (1.0 * self.num_classes)

            deformation = pyro.sample(
                "deformation",
                dist.Normal(deformation_loc, deformation_scale).to_event(1))
            label = pyro.sample("label",
                                dist.OneHotCategorical(label_prior),
                                obs=label)

            final_image = self.decoder(deformation, label, data)

            pyro.sample("image",
                        dist.Bernoulli(final_image).to_event(1),
                        obs=data)
Example #8
def test_one_hot_categorical_shape():
    ps = ng_ones(3, 2) / 2
    d = dist.OneHotCategorical(ps)
    assert d.batch_shape() == (3, )
    assert d.event_shape() == (2, )
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
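The test above targets an older Pyro API (ng_ones and callable batch_shape()/event_shape()/shape()). A rough modern equivalent, assuming a recent Pyro/torch where the shapes are properties and shape() remains a helper on Pyro distributions, would be:

import torch
import pyro.distributions as dist

ps = torch.ones(3, 2) / 2
d = dist.OneHotCategorical(ps)
assert d.batch_shape == torch.Size([3])
assert d.event_shape == torch.Size([2])
assert d.shape() == torch.Size([3, 2])   # Pyro helper: batch_shape + event_shape
assert d.sample().shape == d.shape()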
Example #9
    def model(self, x, y=None):
        # Register various nn.Modules with Pyro
        pyro.module("scanvi", self)

        # This gene-level parameter modulates the variance of the observation distribution
        theta = pyro.param("inverse_dispersion", 10.0 * x.new_ones(self.num_genes),
                           constraint=constraints.positive)

        # We scale all sample statements by scale_factor so that the ELBO is normalized
        # wrt the number of datapoints and genes
        with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
            z1 = pyro.sample("z1", dist.Normal(0, x.new_ones(self.latent_dim)).to_event(1))
            # Note that if y is None (i.e. y is unobserved) then y will be sampled;
            # otherwise y will be treated as observed.
            y = pyro.sample("y", dist.OneHotCategorical(logits=x.new_zeros(self.num_labels)),
                            obs=y)

            z2_loc, z2_scale = self.z2_decoder(z1, y)
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            l_scale = self.l_scale * x.new_ones(1)
            l = pyro.sample("l", dist.LogNormal(self.l_loc, l_scale).to_event(1))

            # Note that by construction mu is normalized (i.e. mu.sum(-1) == 1) and the
            # total scale of counts for each cell is determined by `l`
            gate_logits, mu = self.x_decoder(z2)
            # TODO revisit this parameterization if torch.distributions.NegativeBinomial changes
            # from failure to success parametrization;
            # see https://github.com/pytorch/pytorch/issues/42449
            nb_logits = (l * mu + self.epsilon).log() - (theta + self.epsilon).log()
            x_dist = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits, total_count=theta,
                                                       logits=nb_logits)
            # Observe the datapoint x using the observation distribution x_dist
            pyro.sample("x", x_dist.to_event(1), obs=x)
Example #10
    def guide(self, x, y=None):
        pyro.module("scanvi", self)
        with pyro.plate("batch",
                        len(x)), poutine.scale(scale=self.scale_factor):
            z2_loc, z2_scale, l_loc, l_scale = self.z2l_encoder(x)
            pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            y_logits = self.classifier(z2)
            y_dist = dist.OneHotCategorical(logits=y_logits)
            if y is None:
                # x is unlabeled so sample y using q(y|z2)
                y = pyro.sample("y", y_dist)
            else:
                # x is labeled so add a classification loss term
                # (this way q(y|z2) learns from both labeled and unlabeled data)
                classification_loss = y_dist.log_prob(y)
                # Note that the negative sign appears because we're adding this term in the guide
                # and the guide log_prob appears in the ELBO as -log q
                pyro.factor(
                    "classification_loss",
                    -self.alpha * classification_loss,
                    has_rsample=False,
                )

            z1_loc, z1_scale = self.z1_encoder(z2, y)
            pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
Example #11
    def model(self, xs, ys):
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)
        batch_size = xs.size(0)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.plate("data"):

            # sample the handwriting style from the constant prior distribution
            prior_loc = xs.new_zeros([batch_size, self.z_dim])
            prior_scale = xs.new_ones([batch_size, self.z_dim])
            zs = pyro.sample("z",
                             dist.Normal(prior_loc, prior_scale).to_event(1))

            # if the label y (which digit to write) is unsupervised, sample it from the
            # constant prior; otherwise, observe the given value (i.e. score it against the constant prior)
            alpha_prior = xs.new_ones([batch_size, self.output_size
                                       ]) / (1.0 * self.output_size)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            loc = self.decoder.forward([zs, ys])
            pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
Example #12
    def guide(self, xs, ys=None):
        """
        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer handwriting style from an image and the digit
        loc, scale are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`
        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.plate("data"):

            # if the class label (the digit) is not supervised, sample
            # (and score) the digit with the variational distribution
            # q(y|x) = categorical(alpha(x))
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
            loc, scale = self.encoder_z.forward([xs, ys])
            pyro.sample("z", dist.Normal(loc, scale).to_event(1))
Example #13
 def model_classify(self, xs, ys=None):
     pyro.module("ss_vae", self)
     with pyro.plate("data"):
         if ys is not None:
             z1, _ = self.encoder_z1(xs)
             alpha = self.encoder_y(z1)
             pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
Example #14
def gmm_batch_guide(data):
    with pyro.iarange("data", len(data)) as batch:
        n = len(batch)
        probs = pyro.param("probs", torch.full((n, 1), 0.6, requires_grad=True))
        probs = torch.cat([probs, 1 - probs], dim=1)
        z = pyro.sample("z", dist.OneHotCategorical(probs))
        assert z.shape[-2:] == (n, 2)
Example #15
    def model(self, xs, ys=None):
        # can pass in ys as labels or not pass in anything for unlabelled
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # with pyro.plate("data"):
        with pyro.iarange("data", xs.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = xs.new_zeros(torch.Size((xs.shape[0], self.z_dim)))
            z_scale = xs.new_ones(torch.Size((xs.shape[0], self.z_dim)))

            # sample from prior (value will be sampled by guide when computing the ELBO)
            zs = pyro.sample("z", dist.Normal(z_loc, z_scale).independent(1))

            # if there is no label y, sample it from the constant prior;
            # otherwise, observe the value (i.e. score it against the constant prior)
            # alpha_prior = xs.new_ones(torch.Size((xs.shape[0], self.output_size))) / (1.0*self.output_size)

            # prior is the actual train data label distribution (not uniform)
            alpha_prior = self.prior_probs.repeat(xs.shape[0], 1)

            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # decoder outputs the mean and the square root of the covariance; sample from a Normal
            recon_loc, recon_scale = self.decoder.forward([zs, ys])
            pyro.sample("x",
                        dist.Normal(recon_loc, recon_scale).independent(1),
                        obs=xs.reshape(-1, 1335))
Example #16
    def model(self, X_u: list, X_s: list, Z_s: dict, observations=None):
        """
        Model for generating names representing p(x,z)
        x: Training data (name string)
        z: Optionally supervised latent values (dictionary of name/format values)
        """
        pyro.module("model_fn_lstm", self.model_fn_lstm)

        formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH,
                                          printable_to_index)
        formatted_X_s = strings_to_tensor(X_s, MAX_NAME_LENGTH,
                                          printable_to_index)

        with pyro.plate("sup_batch", len(X_s)):
            _, first_names = self.generate_name_supervised(
                self.model_fn_lstm,
                FIRST_NAME_ADD,
                len(X_s),
                observed=Z_s[FIRST_NAME_ADD])
            full_names = list(
                map(lambda name: pad_string(name, MAX_NAME_LENGTH),
                    first_names))
            probs = strings_to_probs(full_names,
                                     MAX_NAME_LENGTH,
                                     printable_to_index,
                                     true_index_prob=self.peak_prob)
            pyro.sample("sup_output",
                        dist.OneHotCategorical(probs.transpose(0,
                                                               1)).to_event(1),
                        obs=formatted_X_s.transpose(0, 1))

        with pyro.plate("unsup_batch", len(X_u)):
            _, first_names = self.generate_name(self.model_fn_lstm,
                                                FIRST_NAME_ADD, len(X_u))
            full_names = list(
                map(lambda name: pad_string(name, MAX_NAME_LENGTH),
                    first_names))
            probs = strings_to_probs(full_names,
                                     MAX_NAME_LENGTH,
                                     printable_to_index,
                                     true_index_prob=self.peak_prob)
            pyro.sample("unsup_output",
                        dist.OneHotCategorical(probs.transpose(0,
                                                               1)).to_event(1),
                        obs=formatted_X_u.transpose(0, 1))

        return full_names
Example #17
	def remap_y(self, ys):
		new_ys = []
		options = dict(dtype=ys.dtype, device=ys.device)
		for i, label_length in enumerate(self.label_shape):
		    prior = torch.ones(ys.size(0), label_length, **options) / (1.0 * label_length)
		    new_ys.append(pyro.sample("y_%s" % self.label_names[i], dist.OneHotCategorical(prior), 
		                           obs=torch.nn.functional.one_hot(ys[:,i].to(torch.int64), int(label_length))))
		new_ys = torch.cat(new_ys, -1)
		return new_ys.to(torch.float32)
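For reference, a quick sketch (illustrative values only) of what the torch.nn.functional.one_hot call above produces for a single label column:

import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])
print(F.one_hot(labels, num_classes=3))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])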
Example #18
 def model(self, images, labels=None, kl_factor=1.0):
     images = images.view(-1, 784)
     n_images = images.size(0)
     # Set-up parameters for the distribution of weights for each layer `a<n>`
     a1_mean = torch.zeros(784, self.n_hidden)
     a1_scale = torch.ones(784, self.n_hidden)
     a1_dropout = torch.tensor(0.25)
     a2_mean = torch.zeros(self.n_hidden + 1, self.n_hidden)
     a2_scale = torch.ones(self.n_hidden + 1, self.n_hidden)
     a2_dropout = torch.tensor(1.0)
     a3_mean = torch.zeros(self.n_hidden + 1, self.n_hidden)
     a3_scale = torch.ones(self.n_hidden + 1, self.n_hidden)
     a3_dropout = torch.tensor(1.0)
     a4_mean = torch.zeros(self.n_hidden + 1, self.n_classes)
     a4_scale = torch.ones(self.n_hidden + 1, self.n_classes)
     # Mark batched calculations to be conditionally independent given parameters using `plate`
     with pyro.plate('data', size=n_images):
         # Sample first hidden layer
         h1 = pyro.sample(
             'h1',
             bnn.HiddenLayer(images,
                             a1_mean,
                             a1_dropout * a1_scale,
                             non_linearity=nnf.leaky_relu,
                             KL_factor=kl_factor))
         # Sample second hidden layer
         h2 = pyro.sample(
             'h2',
             bnn.HiddenLayer(h1,
                             a2_mean,
                             a2_dropout * a2_scale,
                             non_linearity=nnf.leaky_relu,
                             KL_factor=kl_factor))
         # Sample third hidden layer
         h3 = pyro.sample(
             'h3',
             bnn.HiddenLayer(h2,
                             a3_mean,
                             a3_dropout * a3_scale,
                             non_linearity=nnf.leaky_relu,
                             KL_factor=kl_factor))
         # Sample output logits
         logits = pyro.sample(
             'logits',
             bnn.HiddenLayer(
                 h3,
                 a4_mean,
                 a4_scale,
                 non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
                 KL_factor=kl_factor,
                 include_hidden_bias=False))
         # One-hot encode labels
         labels = nnf.one_hot(labels) if labels is not None else None
         # Condition on observed labels, so it calculates the log-likelihood loss when training using VI
         return pyro.sample('label',
                            dist.OneHotCategorical(logits=logits),
                            obs=labels)
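Note that the model feeds log-softmax outputs into OneHotCategorical(logits=...). A small sketch of why this is safe: the logits parameterization is shift-invariant, so normalized log-probabilities and raw scores define the same distribution.

import torch
import pyro.distributions as dist

raw = torch.randn(4, 10)
log_probs = torch.log_softmax(raw, dim=-1)
d_raw = dist.OneHotCategorical(logits=raw)
d_norm = dist.OneHotCategorical(logits=log_probs)
assert torch.allclose(d_raw.probs, d_norm.probs, atol=1e-6)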
Example #19
 def model_classify(self, xs, ys=None):
     # register PyTorch module `encoder` with Pyro
     pyro.module("encoder_y", self.encoder_y)
     # with pyro.plate("data")
     with pyro.iarange("data", xs.shape[0]):
         # this here is the extra term to yield an auxiliary loss that we do gradient descent on
         if ys is not None:
             alpha = self.encoder_y.forward(xs)
             with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                 pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
Example #20
def gmm_batch_model(data):
    p = pyro.param("p", torch.tensor([0.3], requires_grad=True))
    p = torch.cat([p, 1 - p])
    scale = pyro.param("scale", torch.tensor([1.0], requires_grad=True))
    mus = torch.tensor([-1.0, 1.0])
    with pyro.iarange("data", len(data)) as batch:
        n = len(batch)
        z = pyro.sample("z", dist.OneHotCategorical(p).expand_by([n]))
        assert z.shape[-2:] == (n, 2)
        loc = (z * mus).sum(-1)
        pyro.sample("x", dist.Normal(loc, scale.expand(n)), obs=data[batch])
Example #21
def test_basevae_encode_xy():
    data_dim = (2, 64)
    x = torch.randn(*data_dim)
    alpha = torch.ones(data_dim[0], 3) / 3
    y = dist.OneHotCategorical(alpha).sample()
    vae = models.base.baseVAE(data_dim[1:], None)
    encoder_net = nets.fcEncoderNet(data_dim[1:], 2, 3)
    vae.set_encoder(encoder_net)
    encoded = vae._encode(x, y)
    assert_equal(encoded[:, :2].shape, (data_dim[0], 2))
    assert_equal(encoded[:, 2:].shape, (data_dim[0], 2))
Example #22
def test_posterior_predictive_svi_one_hot():
    pseudocounts = torch.ones(3) * 0.1
    true_probs = torch.tensor([0.15, 0.6, 0.25])
    classes = dist.OneHotCategorical(true_probs).sample((10000,))
    guide = AutoDelta(one_hot_model)
    svi = SVI(one_hot_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(pseudocounts, classes=classes)
    posterior_samples = Predictive(guide, num_samples=10000).get_samples(pseudocounts)
    posterior_predictive = Predictive(one_hot_model, posterior_samples)
    marginal_return_vals = posterior_predictive.get_samples(pseudocounts)["obs"]
    assert_close(marginal_return_vals.mean(dim=0), true_probs.unsqueeze(0), rtol=0.1)
Example #23
    def guide_ncls(self, xs, ys=None):
        with pyro.plate("data"):
            z1_mean, z1_std = self.encoder_z1(xs)
            zs1 = pyro.sample("z1", dist.Normal(z1_mean, z1_std).to_event(1))

            if ys is None:
                alpha = self.encoder_y(zs1)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            z2_mean, z2_std = self.encoder_z2(zs1, ys)
            zs2 = pyro.sample("z2", dist.Normal(z2_mean, z2_std).to_event(1))
            pass
Example #24
    def guide(self, data, label=None):

        pyro.module("encoder", self)

        with pyro.plate("data"):

            label_prior, deformation_loc, deformation_scale = self.encoder(
                data)

            pyro.sample("label", dist.OneHotCategorical(label_prior))
            pyro.sample("deformation",
                        dist.Normal(deformation_loc, deformation_scale))
Example #25
def test_support_dims(dim, probs):
    probs = modify_params_using_dims(probs, dim)
    d = dist.OneHotCategorical(probs)
    support = d.enumerate_support()
    for s in support:
        assert_correct_dimensions(s, probs)
    n = len(support)
    assert support.shape == (n,) + d.batch_shape + d.event_shape
    support_expanded = d.enumerate_support(expand=True)
    assert support_expanded.shape == (n,) + d.batch_shape + d.event_shape
    support_unexpanded = d.enumerate_support(expand=False)
    assert support_unexpanded.shape == (n,) + (1,) * len(d.batch_shape) + d.event_shape
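For concreteness, a small sketch (with hypothetical probs) of the enumerate_support shapes this test checks:

import torch
import pyro.distributions as dist

probs = torch.ones(4, 3) / 3                  # batch_shape (4,), event_shape (3,)
d = dist.OneHotCategorical(probs)
support = d.enumerate_support()               # expand=True by default
assert support.shape == (3, 4, 3)             # (n,) + batch_shape + event_shape
lazy = d.enumerate_support(expand=False)
assert lazy.shape == (3, 1, 3)                # (n,) + (1,)*len(batch_shape) + event_shape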
Example #26
    def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y) = categorical(I/10.)       # which digit (semi-supervised)
        p(x|y,z) = bernoulli(mu(y,z))   # an image
        mu is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        batch_size = xs.size(0)
        with pyro.iarange("independent"):
            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=1))

            # if the label y (which digit to write) is unsupervised, sample it from the
            # constant prior; otherwise, observe the given value (i.e. score it against the constant prior)
            alpha_prior = Variable(torch.ones([batch_size, self.y_dim]) / (1.0 * self.y_dim))
            if ys is None:
                ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior))
            else:
                pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            mu = self.decoder.forward(zs, ys)
            pyro.sample("x", dist.Bernoulli(mu).reshape(extra_event_dims=1), obs=xs)
Example #27
    def model_classify(self, xs, ys=None):
        """
        this model is used to add an auxiliary (supervised) loss as described in the
        Kingma et al., "Semi-Supervised Learning with Deep Generative Models".
        """
        # register all pytorch (sub)modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.plate("data"):
            # this here is the extra term to yield an auxiliary loss that we do gradient descent on
            if ys is not None:
                alpha = self.encoder_y.forward(xs)
                with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                    pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
Example #28
 def model_aux(self,
               xs: torch.Tensor,
               ys: Optional[torch.Tensor] = None,
               **kwargs: float) -> None:
     """
     Models an auxiliary (supervised) loss
     """
     pyro.module("ss_vae", self)
     with pyro.plate("data"):
         # the extra term to yield an auxiliary loss
         aux_loss_multiplier = kwargs.get("aux_loss_multiplier", 20)
         if ys is not None:
             alpha = self.encoder_y.forward(xs)
             with pyro.poutine.scale(scale=aux_loss_multiplier):
                 pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
Example #29
def test_auxsvi_trainer_cls(invariances):
    data_dim = (5, 8, 8)
    train_unsup = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    train_sup = train_unsup + .1 * torch.randn_like(train_unsup)
    labels = dist.OneHotCategorical(torch.ones(data_dim[0], 3)).sample()
    loader_unsup, loader_sup, loader_val = utils.init_ssvae_dataloaders(
        train_unsup, (train_sup, labels), (train_sup, labels), batch_size=2)
    vae = models.ssiVAE(data_dim[1:], 2, 3, invariances)
    trainer = trainers.auxSVItrainer(vae)
    weights_before = dc(vae.state_dict())
    for _ in range(2):
        trainer.step(loader_unsup, loader_sup, loader_val)
    weights_after = vae.state_dict()
    assert_(not torch.isnan(tt(trainer.history["training_loss"])).any())
    assert_(not assert_weights_equal(weights_before, weights_after))
Example #30
    def model(self,
              batch,
              reversed_batch,
              batch_mask,
              batch_seqlens,
              kl_anneal=1.0):
        """
        the model defines p(x_{1:T}|z_{1:T}) and p(z_{1:T})
        """
        # maximum duration of batch
        Tmax = batch.size(1)

        # register torch submodules w/ pyro
        pyro.module("dmm", self)

        # setup recursive conditioning for p(z_t|z_{t-1})
        z_prev = self.z_0.expand(batch.size(0), self.z_0.size(0))

        # sample conditionally independent text across the batch
        with pyro.plate("z_batch", len(batch)):
            # sample latent vars z and observed x w/ multiple samples from the guide for each z
            for t in pyro.markov(range(1, Tmax + 1)):

                # compute params of diagonal gaussian p(z_t|z_{t-1})
                z_loc, z_scale = self.transition(z_prev)

                # sample latent variable
                with poutine.scale(scale=kl_anneal):
                    z_t = pyro.sample(
                        "z_%d" % t,
                        dist.Normal(z_loc, z_scale).mask(
                            batch_mask[:, t - 1:t]).to_event(1),
                    )

                # compute emission probability from latent variable
                emission_prob = self.emitter(z_t)

                # observe x_t according to the Categorical distribution defined by the emitter probability
                pyro.sample(
                    "obs_x_%d" % t,
                    dist.OneHotCategorical(emission_prob).mask(
                        batch_mask[:, t - 1:t]).to_event(1),
                    obs=batch[:, t - 1, :],
                )

                # set conditional var for next time step
                z_prev = z_t
        pass