Example #1
    def model(self, *args, **kwargs):

        I, N = self._data['data'].shape

        # fall back to the full data size when no minibatch size is set
        batch = self._params['batch_size'] if self._params['batch_size'] else N

        weights = pyro.sample('mixture_weights', dist.Dirichlet((1 / self._params['K']) * torch.ones(self._params['K'])))
        cat_vector = torch.arange(1, self._params['hidden_dim'] + 1, dtype=torch.float)  # avoids the numpy round-trip

        with pyro.plate('segments', I):
            segment_factor = pyro.sample('segment_factor', dist.Gamma(self._params['theta_scale'], self._params['theta_rate']))
            with pyro.plate('components', self._params['K']):
                cc = pyro.sample("CNV_probabilities", dist.Dirichlet(self.create_dirichlet_init_values()))

        with pyro.plate('data', N, batch):

            # p(x | z_i) = Poisson(marg(cc * theta * segment_factor))

            segment_fact_cat = torch.matmul(segment_factor.reshape([I, 1]), cat_vector.reshape([1, self._params['hidden_dim']]))
            segment_fact_marg = segment_fact_cat * cc
            segment_fact_marg = torch.sum(segment_fact_marg, dim=-1)

            # p(z_i | D, X) = lk(z_i) * p(z_i | X) / sum_z_i(lk(z_i) * p(z_i | X))
            # log p(z_i | D, X) = log lk(z_i) + log p(z_i | X) - log_sum_exp(log lk(z_i) + log p(z_i | X))

            pyro.factor("lk", self.likelihood(segment_fact_marg, weights, self._params['theta']))
Example #2
def model(transition_alphas, emission_alphas, lengths,
          sequences=None, batch_size=None):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if sequences is not None:
            num_sequences, max_length, data_dim = map(int, sequences.shape)
            assert lengths.shape == (num_sequences,)
            assert lengths.max() <= max_length
        else:
            data_dim = emission_alphas.size(1)
            num_sequences = int(lengths.shape[0])
            max_length = int(lengths.max())
    transition_probs = pyro.sample('transition_probs',
                                   dist.Dirichlet(transition_alphas).to_event(1))
    emission_probs = pyro.sample('emission_probs',
                                 dist.Dirichlet(emission_alphas).to_event(2))
    element_plate = pyro.plate('elements', data_dim, dim=-1)
    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        state = 0
        for t in pyro.markov(range(max_length)):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}', dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                obs_element = Vindex(sequences)[batch, t] if sequences is not None else None
                with element_plate:
                    element = pyro.sample(f'element_{t}',
                                          dist.Categorical(emission_probs[state.squeeze(-1)]),
                                          obs=obs_element)
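
Because state is sampled with infer={'enumerate': 'parallel'}, this model must be paired with an ELBO that marginalizes discrete sites. A minimal, self-contained training sketch with toy shapes (the alphas, lengths, and sequences below are assumptions, not part of the original example):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta

hidden_dim, data_dim, num_categories = 4, 3, 2
transition_alphas = torch.ones(hidden_dim, hidden_dim)
emission_alphas = torch.ones(hidden_dim, data_dim, num_categories)
lengths = torch.tensor([5, 7])
sequences = torch.zeros(2, 7, data_dim)  # category indices in {0, 1}

# The discrete states are enumerated out, so the guide only needs to cover
# the two global Dirichlet sites.
guide = AutoDelta(poutine.block(model, expose=['transition_probs', 'emission_probs']))
svi = SVI(model, guide, pyro.optim.Adam({'lr': 0.01}),
          TraceEnum_ELBO(max_plate_nesting=2))
loss = svi.step(transition_alphas, emission_alphas, lengths, sequences, batch_size=2)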
Example #3
    def model(self, *args, **kwargs):
        I, N = self._data['data'].shape
        batch = self._params['batch_size'] if self._params['batch_size'] else N
        weights = pyro.sample(
            'mixture_weights',
            dist.Dirichlet(
                (1 / self._params['K']) * torch.ones(self._params['K'])))
        with pyro.plate('segments', I):
            with pyro.plate('components', self._params['K']):
                cnv_probs = pyro.sample(
                    "cnv_probs",
                    dist.Dirichlet(self._params['probs'] * 1 /
                                   torch.ones(self._params['hidden_dim'])))

        with pyro.plate("data2", N, batch):
            theta = pyro.sample(
                'norm_factor',
                dist.Gamma(self._params['theta_scale'],
                           self._params['theta_rate']))

        with pyro.plate('data', N, batch):
            assignment = pyro.sample('assignment',
                                     dist.Categorical(weights),
                                     infer={"enumerate": "parallel"})
            for i in pyro.plate('segments2', I):
                cc = pyro.sample('copy_number_{}'.format(i),
                                 dist.Categorical(
                                     Vindex(cnv_probs)[assignment, i, :]),
                                 infer={"enumerate": "parallel"})
                pyro.sample('obs_{}'.format(i),
                            dist.Poisson((cc * theta * self._data['mu'][i]) +
                                         1e-8),
                            obs=self._data['data'][i, :])
Example #4
def model(data, batch_size=32):
    # num_topics, num_words, num_documents, words_per_doc and device are
    # assumed to be module-level globals.
    alpha = torch.ones(num_topics, device=device)
    eta = torch.ones(num_words, device=device)

    with pyro.plate("topic_loop", num_topics):
        # beta: [num_topics, num_words]; samples inherit the device of eta
        beta = pyro.sample("beta", dist.Dirichlet(eta))

    # Pass batch_size as the subsample size so the argument is actually used;
    # `ind` then indexes the subsampled documents.
    with pyro.plate("document_loop", num_documents, batch_size) as ind:
        # theta: [num_documents, num_topics]
        theta = pyro.sample("theta", dist.Dirichlet(alpha))
        with pyro.plate("word_loop", words_per_doc):
            # z: [words_per_doc, num_documents]
            z = pyro.sample("z", dist.Categorical(theta))
            return pyro.sample("obs",
                               dist.Categorical(Vindex(beta)[z]),
                               obs=data[:, ind])
Example #5
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample("probs_w",
                              dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                                  .expand_by([hidden_dim])
                                  .to_event(2))
        probs_y = pyro.sample("probs_y",
                              dist.Beta(0.1, 0.9)
                                  .expand([hidden_dim, hidden_dim, data_dim])
                                  .to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
Example #6
    def model(self, docs=None, doc_sum=None):
        # Globals.
        with pyro.plate("topics", self.num_topics):
            a = torch.tensor(1. / self.num_topics, device=self.device)
            b = torch.tensor(1., device=self.device)
            topic_weights = pyro.sample("topic_weights", dist.Gamma(a, b))

            alpha = torch.ones(self.vocab_size,
                               device=self.device) / self.vocab_size
            topic_words = pyro.sample("topic_words", dist.Dirichlet(alpha))

        # Locals.
        # We will use nested plates. Pyro convention is to count from the right
        # by using negative indices like -1, -2. This means documents must be at
        # the rightmost dimension, followed by words. For this reason, we transpose
        # the data:
        docs = docs.transpose(0, 1)

        with pyro.plate('documents', docs.shape[-1]):
            doc_topics = pyro.sample("doc_topics",
                                     dist.Dirichlet(topic_weights))
            with pyro.plate("words", docs.shape[-2]):
                word_topics = pyro.sample("word_topics",
                                          dist.Categorical(doc_topics),
                                          infer={"enumerate": "parallel"})
                data = pyro.sample("doc_words",
                                   dist.Categorical(topic_words[word_topics]),
                                   obs=docs)

        return topic_words
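
The dim-convention comment in this model can be made concrete with a quick standalone shape check (made-up sizes, not part of the original class):

import torch
import pyro
import pyro.distributions as dist

num_docs, num_words, num_topics = 5, 7, 3
with pyro.plate("documents", num_docs):              # occupies dim=-1
    doc_topics = pyro.sample("doc_topics", dist.Dirichlet(torch.ones(num_topics)))
    with pyro.plate("words", num_words):             # occupies dim=-2
        word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics))

assert doc_topics.shape == (num_docs, num_topics)    # documents on the right
assert word_topics.shape == (num_words, num_docs)    # words to the left of documents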
Example #7
File: lda.py Project: xidulu/pyro
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data
Example #8
            def guide_ret(*args, **kwargs):
                I, N = self._data['data'].shape
                batch = self._params['batch_size'] if self._params['batch_size'] else N

                param_weights = pyro.param(
                    "param_weights",
                    lambda: torch.ones(self._params['K']) / self._params['K'],
                    constraint=constraints.simplex)
                hidden_vals = pyro.param(
                    "param_hidden_weights",
                    lambda: self.create_dirichlet_init_values(),
                    constraint=constraints.simplex)
                gamma_scale = pyro.param(
                    "param_gamma_scale",
                    lambda: torch.mean(
                        self._data['data'] / (2 * self._data['mu'].reshape(
                            self._data['data'].shape[0], 1)),
                        axis=0) * self._params['gamma_multiplier'],
                    constraint=constraints.positive)
                gamma_rate = pyro.param(
                    "param_rate",
                    lambda: torch.ones(1) * self._params['gamma_multiplier'],
                    constraint=constraints.positive)
                weights = pyro.sample('mixture_weights',
                                      dist.Dirichlet(param_weights))

                with pyro.plate('segments', I):
                    with pyro.plate('components', self._params['K']):
                        pyro.sample("cnv_probs", dist.Dirichlet(hidden_vals))

                with pyro.plate("data2", N, batch):
                    pyro.sample('norm_factor',
                                dist.Gamma(gamma_scale, gamma_rate))
Example #9
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights is not part of the standard LDA plate notation; it
        # appears to encode the relative importance of topics, possibly
        # following the amortized LDA example.
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        # TODO category weights might not be necessary in our model
        category_topics = pyro.sample("category_topics",
                                      dist.Dirichlet(topic_weights))

    doc_category_list = []
    doc_word_list = []

    # Locals.
    for index, doc in enumerate(pyro.plate("documents", args.num_docs)):
        if doc_word_data is not None:
            cur_doc_word_data = doc_word_data[doc]
        else:
            cur_doc_word_data = None

        if category_data is not None:
            cur_category_data = category_data[doc]
        else:
            cur_category_data = None

        doc_category_list.append(
            pyro.sample("doc_categories_{}".format(doc),
                        dist.Categorical(category_weights),
                        obs=cur_category_data))

        with pyro.plate("words_{}".format(doc), args.num_words_per_doc[doc]):
            word_topics = pyro.sample(
                "word_topics_{}".format(doc),
                dist.Categorical(category_topics[int(
                    doc_category_list[index].item())]))
            # TODO Enum parallel/sequential optimizing?

            doc_word_list.append(
                pyro.sample("doc_words_{}".format(doc),
                            dist.Categorical(topic_words[word_topics]),
                            obs=cur_doc_word_data))

    results = {
        "topic_weights": topic_weights,
        "topic_words": topic_words,
        "doc_word_data": doc_word_list,
        "category_weights": category_weights,
        "category_topics": category_topics,
        "doc_category_data": doc_category_list
    }

    return results
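
One possible answer to the enumeration TODO in the word loop above, borrowed from the vectorized variant of this model later in this section (a sketch, not the original author's choice): mark word_topics for parallel enumeration so that TraceEnum_ELBO can marginalize it instead of sampling it.

            # Hypothetical replacement for the word_topics site above;
            # requires TraceEnum_ELBO at inference time.
            word_topics = pyro.sample(
                "word_topics_{}".format(doc),
                dist.Categorical(category_topics[int(doc_category_list[index].item())]),
                infer={"enumerate": "parallel"})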
Example #10
def parametrized_guide(doc_word_data, category_data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param("topic_weights_posterior",
                                         lambda: torch.ones(args.num_topics),
                                         constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    category_weights_posterior = pyro.param(
        "category_weights_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    category_topics_posterior = pyro.param(
        "category_topics_posterior",
        lambda: torch.ones(args.num_categories, args.num_topics),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("categories", args.num_categories):
        pyro.sample("category_weights",
                    dist.Gamma(category_weights_posterior, 1.))
        pyro.sample("category_topics",
                    dist.Dirichlet(category_topics_posterior))

    # The model draws doc_categories over num_categories, so the posterior
    # needs one weight per category (num_topics here looked like a slip).
    doc_category_posterior = pyro.param("doc_category_posterior",
                                        lambda: torch.ones(args.num_categories),
                                        constraint=constraints.positive)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        pyro.sample("doc_categories", dist.Categorical(doc_category_posterior))
Example #11
    def model(self, *args, **kwargs):
        I, N = self._data['data'].shape
        batch = self._params['batch_size'] if self._params['batch_size'] else N
        weights = pyro.sample('mixture_weights',
                              dist.Dirichlet(torch.ones(self._params['K'])))

        with pyro.plate('components', self._params['K']):
            probs_z = pyro.sample(
                "cnv_probs",
                dist.Dirichlet(self._params['t'] *
                               torch.eye(self._params['hidden_dim']) +
                               (1 - self._params['t'])).to_event(1))

        with pyro.plate("data2", N, batch):
            theta = pyro.sample(
                'norm_factor',
                dist.Gamma(self._params['theta_scale'],
                           self._params['theta_rate']))

        with pyro.plate('data', N, batch):
            z = 0
            assignment = pyro.sample('assignment',
                                     dist.Categorical(weights),
                                     infer={"enumerate": "parallel"})
            for i in pyro.markov(range(I)):
                z = pyro.sample("z_{}".format(i),
                                dist.Categorical(
                                    Vindex(probs_z)[assignment, z]),
                                infer={"enumerate": "parallel"})
                pyro.sample('obs_{}'.format(i),
                            dist.Poisson((z * theta * self._data['mu'][i]) +
                                         1e-8),
                            obs=self._data['data'][i, :])
Example #12
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
Example #13
    def model(self, *args, **kwargs):
        I = self._data['segments']

        pi = pyro.sample("pi", dist.Dirichlet(self._params['init_probs']))

        probs_z = pyro.sample("cnv_probs",
                              dist.Dirichlet((1 - self._params['t']) * torch.eye(self._params['hidden_dim']) +
                                             self._params['t']).to_event(1))
        probs_y = torch.tensor([[2., 64., 32., 21.5, 16., 43.],
                                [64., 64., 64., 64., 64., 64.]])

        z = pyro.sample("z_0", dist.Categorical(pi),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_0", dist.Beta(probs_y[0, z], probs_y[1, z]),
                    obs=self._data['data'][0, 0])

        for i in pyro.markov(range(1, I)):
            z = pyro.sample("z_{}".format(i), dist.Categorical(Vindex(probs_z)[z]),
                            infer={"enumerate": "parallel"})
            pyro.sample("y_{}".format(i), dist.Beta(probs_y[0, z], probs_y[1, z]),
                        obs=self._data['data'][i, 0])
Example #14
def model(K=None,
          M=None,
          N=None,
          V=None,
          alpha=None,
          beta=None,
          doc=None,
          w=None):
    theta = sample('theta', ImproperUniform(shape=(M, K)))
    phi = sample('phi', ImproperUniform(shape=(K, V)))
    for m in range(1, M + 1):
        sample('theta' + '__{}'.format(m - 1) + '__1',
               dist.Dirichlet(alpha),
               obs=theta[m - 1])
    for k in range(1, K + 1):
        sample('phi' + '__{}'.format(k - 1) + '__2',
               dist.Dirichlet(beta),
               obs=phi[k - 1])
    for n in range(1, N + 1):
        gamma = zeros(K)
        for k in range(1, K + 1):
            gamma[k - 1] = log(theta[doc[n - 1] - 1, k - 1]) + log(
                phi[k - 1, w[n - 1] - 1])
        sample('expr' + '__{}'.format(n) + '__3',
               dist.Exponential(1.0),
               obs=-log_sum_exp(gamma))
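
The final loop encodes the LDA marginal likelihood through an indirection: since Exponential(1.0).log_prob(v) = -v, observing v = -log_sum_exp(gamma) adds exactly log_sum_exp(gamma) to the joint log-density. In plain Pyro the same increment could be written directly with pyro.factor (a sketch, not the compiler's actual output):

        # Equivalent, more direct formulation of the observation trick above:
        pyro.factor('expr__{}__3'.format(n), log_sum_exp(gamma))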
Example #15
    def guide(self, data):
        alpha_posterior = pyro.param("alpha_posterior",
                                     lambda: torch.ones(self.n_topics),
                                     constraint=positive)
        beta_posterior = pyro.param(
            "beta_posterior",
            lambda: torch.ones(self.n_topics, self.vocab_size),
            constraint=greater_than(0.5))

        with pyro.plate("topics", self.n_topics):
            alpha = pyro.sample("alpha", dist.Gamma(alpha_posterior, 1.))
            betas = pyro.sample("beta", dist.Dirichlet(beta_posterior))

        theta = None
        z = None

        for d in pyro.plate("doc_loop", len(data)):
            gamma_q = pyro.param(f"gamma_{d}",
                                 torch.ones(self.n_topics),
                                 constraint=positive)
            theta = pyro.sample(f"theta_{d}", dist.Dirichlet(gamma_q))
            nwords = len(data[d])
            for w in pyro.plate(f"word_loop_{d}", nwords):
                phi_q = pyro.param(f"phi{d}_{w}",
                                   torch.ones(self.n_topics),
                                   constraint=positive)
                z = pyro.sample(f"z{d}_{w}", dist.Categorical(phi_q))
        return theta, z, alpha, betas
Example #16
def model(data):
    initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim)))
    with pyro.plate("states", dim):
        transition = pyro.sample("transition",
                                 dist.Dirichlet(torch.ones(dim, dim)))
        emission_loc = pyro.sample(
            "emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim)))
        emission_scale = pyro.sample(
            "emission_scale",
            dist.LogNormal(torch.zeros(dim), torch.ones(dim)))
    x = None
    with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
        for t, y in pyro.markov(enumerate(data)):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(initialize if x is None else transition[x]),
                infer={"enumerate": "parallel"},
            )
            pyro.sample(
                "y_{}".format(t),
                dist.Normal(emission_loc[x], emission_scale[x]),
                obs=y,
            )
Example #17
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights is not part of the standard LDA plate notation; it
        # appears to encode the relative importance of topics, possibly
        # following the amortized LDA example.
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        category_topics = pyro.sample("category_topics",
                                      dist.Dirichlet(topic_weights))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if doc_word_data is not None:
            with pyro.util.ignore_jit_warnings():
                assert doc_word_data.shape == (args.num_words_per_doc,
                                               args.num_docs
                                               )  # Forces the 64x1000 shape
            doc_word_data = doc_word_data[:, ind]

        if category_data is not None:
            category_data = category_data[ind]

        category_data = pyro.sample("doc_categories",
                                    dist.Categorical(category_weights),
                                    obs=category_data)

        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics",
                                      dist.Categorical(
                                          category_topics[category_data]),
                                      infer={"enumerate": "parallel"})
            doc_word_data = pyro.sample("doc_words",
                                        dist.Categorical(
                                            topic_words[word_topics]),
                                        obs=doc_word_data)

    results = {
        "topic_weights": topic_weights,
        "topic_words": topic_words,
        "doc_word_data": doc_word_data,
        "category_weights": category_weights,
        "category_topics": category_topics,
        "category_data": category_data
    }

    return results
Example #18
def guide(data):
    qalpha0 = pyro.param("qalpha0",
                         torch.ones(nd, nz, 1, Td),
                         constraint=constraints.positive)
    qalpha1 = pyro.param("qalpha1",
                         torch.ones(nz, 1, nw, ntr),
                         constraint=constraints.positive)
    # CHANGE: use the fact that a batched Dirichlet draws independent Dirichlets
    pyro.sample("latent0", pdist.Dirichlet(concentration=qalpha0.view(nd, -1)))
    pyro.sample("latent1", pdist.Dirichlet(concentration=qalpha1.view(nz, -1)))
Example #19
    def model(self, enc_in, dec_in, dec_out, T, N):
        pyro.module("AE", self.AE)
        theta_t = pyro.sample('theta_t', dist.Dirichlet(torch.ones(K) * 10))
        theta_d = pyro.sample('theta_d', dist.Dirichlet(torch.ones(K) * 10))
        # pyro.iarange is the pre-0.3 spelling of pyro.plate
        with pyro.plate('data.loop', N, dim=-1) as i:
            z_t = pyro.sample('z_t', dist.Categorical(theta_t.expand(N, K)))
            z_d = pyro.sample('z_d', dist.Categorical(theta_d.expand(N, K)))
            pi = self.AE.decode([z_t, z_d, dec_in[i]])
            for t in range(T - 1):
                pyro.sample('y_{}_{}'.format(i, t),
                            dist.Categorical(pi[:, t, :]),
                            obs=dec_out[i, t])
Example #20
def test_dirichlet_multinomial(sample_shape, batch_shape):
    concentration = torch.randn(batch_shape + (3,)).exp()
    total = 10
    probs = torch.tensor([0.2, 0.3, 0.5])
    obs = dist.Multinomial(total, probs).sample(sample_shape + batch_shape)

    f = dist.Dirichlet(concentration)
    g = dist.Dirichlet(1 + obs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(f.log_prob(x) + g.log_prob(x), fg.log_prob(x) + log_normalizer)
Example #21
def unsupervised_hmm(words):
    with pyro.plate("prob_plate", num_categories):
        transition_prob = pyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emission_prob = pyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    transition_log_prob = transition_prob.log()
    emission_log_prob = emission_prob.log()
    log_prob = emission_log_prob[:, words[0]]
    for t in range(1, len(words)):
        log_prob = forward_log_prob(log_prob, words[t], transition_log_prob, emission_log_prob)
    prob = log_prob.logsumexp(dim=0).exp()
    # a trick to inject an additional log_prob into model's log_prob
    pyro.sample("forward_prob", dist.Bernoulli(prob), obs=torch.tensor(1.))
Example #22
def model(data):
    s0 = (nd, nz, 1, Td)
    s1 = (nz, 1, nw, ntr)
    alpha0 = torch.ones(*s0).cpu()
    alpha1 = torch.ones(*s1).cpu()
    z = pyro.sample("latent0", pdist.Dirichlet(concentration=alpha0.view(nd, -1)))
    motifs = pyro.sample("latent1", pdist.Dirichlet(concentration=alpha1.view(nz, -1)))

    z = z.reshape(*s0)
    motifs = motifs.reshape(*s1)
    p = p_w_ta_d(z, motifs)
    with pyro.iarange("data", len(data)):
        zts = pyro.sample("zts", pdist.Categorical(probs=z))
        pyro.sample("observe", pdist.Multinomial(probs=p), obs=data)
Example #23
def model(data=None, num_words_per_doc=None, args=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Changed from a vectorized plate ("with pyro.plate('documents', args.num_docs) as ind:")
    # to sequential iteration, to support a varying number of words per
    # document (num_words_per_doc).
    for doc in pyro.plate("documents", args.num_docs):
        doc_topics = pyro.sample("doc_topics_{}".format(doc), dist.Dirichlet(topic_weights))
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            word_topics = pyro.sample("word_topics_{}".format(doc), dist.Categorical(doc_topics))
            pyro.sample("doc_words_{}".format(doc), dist.Categorical(topic_words[word_topics]), obs=data[doc])
    return topic_weights, topic_words
Example #24
def guide(data):
    qalpha0 = pyro.param("qalpha0", torch.ones(nd, nz, 1, Td).cpu(), constraint=constraints.positive)  # z_ts table
    global step_motif_count
    if flag_ISM:
        qalpha1 = pyro.param("qalpha1", init_motif, constraint=constraints.positive)  # motif
        if step_motif_count % 5 == 0:
            tem_motif.append(qalpha1)
    else:
        qalpha1 = pyro.param("qalpha1", torch.ones(nz, 1, nw, ntr).cpu(), constraint=constraints.positive)  # motif
        if step_motif_count % 5 == 0:
            tem_motif.append(qalpha1)

    # CHANGE: use the fact that a batched Dirichlet draws independent Dirichlets
    pyro.sample("latent0", pdist.Dirichlet(concentration=qalpha0.view(nd, -1)))
    pyro.sample("latent1", pdist.Dirichlet(concentration=qalpha1.view(nz, -1)))
Example #25
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
Example #26
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note that since each tone depends on all tones at a previous time step
                # the tones at different time steps now need to live in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(logits=tones_generator(x, y)),
                        obs=sequences[batch, t])
Example #27
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
Example #28
File: hmm.py Project: pyro-ppl/pyro
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        lengths = lengths[batch]
        y = sequences[batch] if args.jit else sequences[batch, :lengths.max()]
        x = torch.arange(args.hidden_dim)
        t = torch.arange(y.size(1))
        init_logits = torch.full((args.hidden_dim, ), -float("inf"))
        init_logits[0] = 0
        trans_logits = probs_x.log()
        with ignore_jit_warnings():
            obs_dist = dist.Bernoulli(
                logits=tones_generator(x, y.unsqueeze(-2))).to_event(1)
            obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1))
            hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
        pyro.sample("y", hmm_dist, obs=y)
Example #29
def test_dirichlet_shape():
    # Updated from the pre-0.2 Pyro API: ng_ones is now torch.ones, and
    # batch_shape / event_shape are properties rather than methods.
    alpha = torch.ones(3, 2) / 2
    d = dist.Dirichlet(alpha)
    assert d.batch_shape == (3,)
    assert d.event_shape == (2,)
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
Example #30
def one_hot_model(pseudocounts, classes=None):
    probs_prior = dist.Dirichlet(pseudocounts)
    probs = pyro.sample("probs", probs_prior)
    with pyro.plate("classes",
                    classes.size(0) if classes is not None else 1,
                    dim=-1):
        return pyro.sample("obs", dist.OneHotCategorical(probs), obs=classes)