def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
    # rowwise grouping is only available for inference that supports AEVB
    cls, kw = parametric_grouped_approxes
    with three_var_model:
        try:
            approx = Approximation([cls([three_var_model.one], rowwise=True, **kw), Group(None, vfam='mf')])
            inference = pm.KLqp(approx)
            approx = inference.fit(3, obj_n_mc=2)
            approx.sample(10)
            approx.sample_node(
                three_var_model.one
            ).eval()
        except pm.opvi.BatchedGroupError:
            pytest.skip('Does not support rowwise grouping')
def test_clear_cache():
    import pickle

    with pm.Model():
        pm.Normal("n", 0, 1)
        inference = ADVI()
        inference.fit(n=10)
        assert any(len(c) != 0 for c in inference.approx._cache.values())
        inference.approx._cache.clear()
        # the approximation's cache should now be empty
        assert all(len(c) == 0 for c in inference.approx._cache.values())
        new_a = pickle.loads(pickle.dumps(inference.approx))
        assert not hasattr(new_a, "_cache")
        inference_new = pm.KLqp(new_a)
        inference_new.fit(n=10)
        assert any(len(c) != 0 for c in inference_new.approx._cache.values())
        inference_new.approx._cache.clear()
        assert all(len(c) == 0 for c in inference_new.approx._cache.values())
def test_clear_cache():
    import pickle
    pymc3.memoize.clear_cache()
    assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY)
    with pm.Model():
        pm.Normal('n', 0, 1)
        inference = ADVI()
        inference.fit(n=10)
        assert any(len(c) != 0 for c in inference.approx._cache.values())
        pymc3.memoize.clear_cache(inference.approx)
        # the approximation's cache should now be empty
        assert all(len(c) == 0 for c in inference.approx._cache.values())
        new_a = pickle.loads(pickle.dumps(inference.approx))
        assert not hasattr(new_a, '_cache')
        inference_new = pm.KLqp(new_a)
        inference_new.fit(n=10)
        assert any(len(c) != 0 for c in inference_new.approx._cache.values())
        pymc3.memoize.clear_cache(inference_new.approx)
        assert all(len(c) == 0 for c in inference_new.approx._cache.values())
def test_sample_aevb(three_var_aevb_approx, aevb_initial):
    pm.KLqp(three_var_aevb_approx).fit(1, more_replacements={
        aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]
    })
    aevb_initial.set_value(np.random.rand(7, 7).astype('float32'))
    trace = three_var_aevb_approx.sample(500)
    assert set(trace.varnames) == {'one', 'one_log__', 'two', 'three'}
    assert len(trace) == 500
    assert trace[0]['one'].shape == (7, 2)
    assert trace[0]['two'].shape == (10, )
    assert trace[0]['three'].shape == (10, 1, 2)

    aevb_initial.set_value(np.random.rand(13, 7).astype('float32'))
    trace = three_var_aevb_approx.sample(500)
    assert set(trace.varnames) == {'one', 'one_log__', 'two', 'three'}
    assert len(trace) == 500
    assert trace[0]['one'].shape == (13, 2)
    assert trace[0]['two'].shape == (10,)
    assert trace[0]['three'].shape == (10, 1, 2)
def test_sample_aevb(three_var_aevb_approx, aevb_initial):
    pm.KLqp(three_var_aevb_approx).fit(
        1, more_replacements={aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]}
    )
    aevb_initial.set_value(np.random.rand(7, 7).astype("float32"))
    trace = three_var_aevb_approx.sample(500, return_inferencedata=False)
    assert set(trace.varnames) == {"one", "one_log__", "two", "three"}
    assert len(trace) == 500
    assert trace[0]["one"].shape == (7, 2)
    assert trace[0]["two"].shape == (10,)
    assert trace[0]["three"].shape == (10, 1, 2)

    aevb_initial.set_value(np.random.rand(13, 7).astype("float32"))
    trace = three_var_aevb_approx.sample(500, return_inferencedata=False)
    assert set(trace.varnames) == {"one", "one_log__", "two", "three"}
    assert len(trace) == 500
    assert trace[0]["one"].shape == (13, 2)
    assert trace[0]["two"].shape == (10,)
    assert trace[0]["three"].shape == (10, 1, 2)
Example #6
with model:
    approx = pm.fit(
        20000,
        method='fullrank_advi',
        callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
    trace = approx.sample(1000)
pm.traceplot(trace, varnames=['l', 'eta'])
#%%
with model:
    group_1 = pm.Group([l, eta],
                       vfam='fr')  # latent1 has full rank approximation
    group_other = pm.Group(None,
                           vfam='mf')  # other variables have mean field Q
    approx = pm.Approximation([group_1, group_other])
    pm.KLqp(approx).fit(
        100000,
        callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
    trace = approx.sample(1000)
#%% prediction
#nx=50
#x = np.linspace(0, 100, nx)
#y = np.linspace(0, 100, nx)
#xv, yv = np.meshgrid(x, y)
#x_pred = np.vstack((yv.flatten(),xv.flatten())).T

# add the GP conditional to the model, given the new X values
with pm.Model() as predi_model:
    # hyper-parameter priors
    l = pm.HalfNormal('l', sd=.1)
    eta = pm.HalfCauchy('eta', beta=3.)
    cov_func = eta**2 * pm.gp.cov.Matern32(D, ls=l * np.ones(D))
Example #7
def init_nuts(init='auto',
              chains=1,
              n_init=500000,
              model=None,
              random_seed=None,
              progressbar=True,
              **kwargs):
    """Set up the mass matrix initialization for NUTS.

    NUTS convergence and sampling speed are extremely dependent on the
    choice of mass/scaling matrix. This function implements different
    methods for choosing or adapting the mass matrix.

    Parameters
    ----------
    init : str
        Initialization method to use.

        * auto : Choose a default initialization method automatically.
          Currently, this is `'jitter+adapt_diag'`, but this can change in
          the future. If you depend on the exact behaviour, choose an
          initialization method explicitly.
        * adapt_diag : Start with an identity mass matrix and then adapt
          a diagonal based on the variance of the tuning samples. All
          chains use the test value (usually the prior mean) as starting
          point.
        * jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
          in [-1, 1] to the starting point in each chain.
        * advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
          mass matrix based on the sample variance of the tuning samples.
        * advi+adapt_diag_grad : Run ADVI and then adapt the resulting
          diagonal mass matrix based on the variance of the gradients
          during tuning. This is **experimental** and might be removed
          in a future release.
        * advi : Run ADVI to estimate posterior mean and diagonal mass
          matrix.
        * advi_map: Initialize ADVI with MAP and use MAP as starting point.
        * map : Use the MAP as starting point. This is discouraged.
        * nuts : Run NUTS and estimate posterior mean and mass matrix from
          the trace.
    chains : int
        Number of jobs to start.
    n_init : int
        Number of iterations for the initializer. If 'advi', this is the
        number of ADVI iterations; if 'nuts', the number of draws.
    model : Model (optional if in `with` context)
    progressbar : bool
        Whether or not to display a progressbar for advi sampling.
    **kwargs : keyword arguments
        Extra keyword arguments are forwarded to pymc3.NUTS.

    Returns
    -------
    start : pymc3.model.Point
        Starting point for sampler
    nuts_sampler : pymc3.step_methods.NUTS
        Instantiated and initialized NUTS sampler object
    """
    model = pm.modelcontext(model)

    vars = kwargs.get('vars', model.vars)
    if set(vars) != set(model.vars):
        raise ValueError('Must use init_nuts on all variables of a model.')
    if not pm.model.all_continuous(vars):
        raise ValueError('init_nuts can only be used for models with only '
                         'continuous variables.')

    if not isinstance(init, str):
        raise TypeError('init must be a string.')

    if init is not None:
        init = init.lower()

    if init == 'auto':
        init = 'jitter+adapt_diag'

    pm._log.info('Initializing NUTS using {}...'.format(init))

    if random_seed is not None:
        random_seed = int(np.atleast_1d(random_seed)[0])
        np.random.seed(random_seed)

    cb = [
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2,
                                                diff='absolute'),
        pm.callbacks.CheckParametersConvergence(tolerance=1e-2,
                                                diff='relative'),
    ]

    if init == 'adapt_diag':
        start = [model.test_point] * chains
        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
        var = np.ones_like(mean)
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, mean, var, 10)
    elif init == 'jitter+adapt_diag':
        start = []
        for _ in range(chains):
            mean = {var: val.copy() for var, val in model.test_point.items()}
            for val in mean.values():
                val[...] += 2 * np.random.rand(*val.shape) - 1
            start.append(mean)
        mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
        var = np.ones_like(mean)
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, mean, var, 10)
    elif init == 'advi+adapt_diag_grad':
        approx = pm.fit(
            random_seed=random_seed,
            n=n_init,
            method='advi',
            model=model,
            callbacks=cb,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
        )  # type: pm.MeanField
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())
        cov = model.dict_to_array(stds)**2
        mean = approx.bij.rmap(approx.mean.get_value())
        mean = model.dict_to_array(mean)
        weight = 50
        potential = quadpotential.QuadPotentialDiagAdaptGrad(
            model.ndim, mean, cov, weight)
    elif init == 'advi+adapt_diag':
        approx = pm.fit(
            random_seed=random_seed,
            n=n_init,
            method='advi',
            model=model,
            callbacks=cb,
            progressbar=progressbar,
            obj_optimizer=pm.adagrad_window,
        )  # type: pm.MeanField
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())
        cov = model.dict_to_array(stds)**2
        mean = approx.bij.rmap(approx.mean.get_value())
        mean = model.dict_to_array(mean)
        weight = 50
        potential = quadpotential.QuadPotentialDiagAdapt(
            model.ndim, mean, cov, weight)
    elif init == 'advi':
        approx = pm.fit(random_seed=random_seed,
                        n=n_init,
                        method='advi',
                        model=model,
                        callbacks=cb,
                        progressbar=progressbar,
                        obj_optimizer=pm.adagrad_window)  # type: pm.MeanField
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())
        cov = model.dict_to_array(stds)**2
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'advi_map':
        start = pm.find_MAP(include_transformed=True)
        approx = pm.MeanField(model=model, start=start)
        pm.fit(random_seed=random_seed,
               n=n_init,
               method=pm.KLqp(approx),
               callbacks=cb,
               progressbar=progressbar,
               obj_optimizer=pm.adagrad_window)
        start = approx.sample(draws=chains)
        start = list(start)
        stds = approx.bij.rmap(approx.std.eval())
        cov = model.dict_to_array(stds)**2
        potential = quadpotential.QuadPotentialDiag(cov)
    elif init == 'map':
        start = pm.find_MAP(include_transformed=True)
        cov = pm.find_hessian(point=start)
        start = [start] * chains
        potential = quadpotential.QuadPotentialFull(cov)
    elif init == 'nuts':
        init_trace = pm.sample(draws=n_init,
                               step=pm.NUTS(),
                               tune=n_init // 2,
                               random_seed=random_seed)
        cov = np.atleast_1d(pm.trace_cov(init_trace))
        start = list(np.random.choice(init_trace, chains))
        potential = quadpotential.QuadPotentialFull(cov)
    else:
        raise NotImplementedError(
            'Initializer {} is not supported.'.format(init))

    step = pm.NUTS(potential=potential, **kwargs)

    return start, step
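

# Usage sketch (illustrative assumption, not part of the original listing):
# init_nuts returns a list of per-chain starting points plus an initialized
# NUTS step method, which can be passed straight to pm.sample. The toy model
# below exists only for this demonstration.
def _demo_init_nuts():
    with pm.Model():
        pm.Normal('x', mu=0., sd=1.)
        start, step = init_nuts(init='jitter+adapt_diag', chains=2)
        trace = pm.sample(500, tune=500, step=step, start=start, chains=2)
    return trace
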
    def train_pymc3(docs_te, docs_tr, n_samples_te, n_samples_tr, n_words,
                    n_topics, n_tokens):
        """
        Return:
            PyMC3 LDA results
        
        Parameters:
            docs_tr: training documents (processed)
            docs_te: testing documents (processed)
            n_samples_te: number of testing docs
            n_samples_tr: number of training docs
            n_words: size of vocabulary
            n_topics: number of topics to learn
            n_tokens: number of non-zero datapoints in processed training tf matrix
            
        """

        # Log-likelihood of documents for LDA
        def logp_lda_doc(beta, theta):
            """
            Returns the log-likelihood function for given documents.

            K : number of topics in the model
            V : number of words (size of vocabulary)
            D : number of documents (in a mini-batch)

            Parameters
            ----------
            beta : tensor (K x V)
              Word distribution.
            theta : tensor (D x K)
              Topic distributions for the documents.
            """
            def ll_docs_f(docs):
                dixs, vixs = docs.nonzero()
                vfreqs = docs[dixs, vixs]
                ll_docs = vfreqs * pmmath.logsumexp(
                    tt.log(theta[dixs]) + tt.log(beta.T[vixs]),
                    axis=1).ravel()

                # Per-word log-likelihood times no. of tokens in the whole dataset
                return tt.sum(ll_docs) / (tt.sum(vfreqs) + 1e-9) * n_tokens

            return ll_docs_f
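
        # For a single token of word v in document d, the likelihood is
        # p(v | theta_d, beta) = sum_k theta[d, k] * beta[k, v]; ll_docs_f
        # evaluates this in log space with logsumexp, weights it by the
        # observed term frequency, and rescales the mini-batch total to the
        # whole corpus via n_tokens.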

        # fit the pymc3 LDA

        # We have a sparse dataset. It's better to have a dense batch so that all words occur in it.
        minibatch_size = 128

        # defining minibatch
        doc_t_minibatch = pm.Minibatch(docs_tr.toarray(), minibatch_size)
        doc_t = shared(docs_tr.toarray()[:minibatch_size])

        with pm.Model() as model:
            theta = Dirichlet(
                'theta',
                a=pm.floatX((1.0 / n_topics) * np.ones(
                    (minibatch_size, n_topics))),
                shape=(minibatch_size, n_topics),
                transform=t_stick_breaking(1e-9),
                # do not forget scaling
                total_size=n_samples_tr)
            beta = Dirichlet('beta',
                             a=pm.floatX((1.0 / n_topics) * np.ones(
                                 (n_topics, n_words))),
                             shape=(n_topics, n_words),
                             transform=t_stick_breaking(1e-9))
            # Note that the likelihood is defined with scaling, so no additional `total_size` kwarg is needed here
            doc = pm.DensityDist('doc',
                                 logp_lda_doc(beta, theta),
                                 observed=doc_t)

        # Encoder
        class LDAEncoder:
            """Encode (term-frequency) document vectors to variational means and (log-transformed) stds.
            """
            def __init__(self,
                         n_words,
                         n_hidden,
                         n_topics,
                         p_corruption=0,
                         random_seed=1):
                rng = np.random.RandomState(random_seed)
                self.n_words = n_words
                self.n_hidden = n_hidden
                self.n_topics = n_topics
                self.w0 = shared(0.01 * rng.randn(n_words, n_hidden).ravel(),
                                 name='w0')
                self.b0 = shared(0.01 * rng.randn(n_hidden), name='b0')
                self.w1 = shared(0.01 * rng.randn(n_hidden, 2 *
                                                  (n_topics - 1)).ravel(),
                                 name='w1')
                self.b1 = shared(0.01 * rng.randn(2 * (n_topics - 1)),
                                 name='b1')
                self.rng = MRG_RandomStreams(seed=random_seed)
                self.p_corruption = p_corruption

            def encode(self, xs):
                if 0 < self.p_corruption:
                    dixs, vixs = xs.nonzero()
                    mask = tt.set_subtensor(
                        tt.zeros_like(xs)[dixs, vixs],
                        self.rng.binomial(size=dixs.shape,
                                          n=1,
                                          p=1 - self.p_corruption))
                    xs_ = xs * mask
                else:
                    xs_ = xs

                w0 = self.w0.reshape((self.n_words, self.n_hidden))
                w1 = self.w1.reshape((self.n_hidden, 2 * (self.n_topics - 1)))
                hs = tt.tanh(xs_.dot(w0) + self.b0)
                zs = hs.dot(w1) + self.b1
                zs_mean = zs[:, :(self.n_topics - 1)]
                zs_rho = zs[:, (self.n_topics - 1):]
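                # PyMC3's mean-field local_rv replacement expects the keys
                # 'mu' and 'rho'; rho is mapped to a positive standard
                # deviation internally, so the encoder can output
                # unconstrained values here.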
                return {'mu': zs_mean, 'rho': zs_rho}

            def get_params(self):
                return [self.w0, self.b0, self.w1, self.b1]

        # instantiate the encoder

        encoder = LDAEncoder(n_words=n_words,
                             n_hidden=100,
                             n_topics=n_topics,
                             p_corruption=0.0)
        local_RVs = OrderedDict([(theta, encoder.encode(doc_t))])

        # get parameters
        encoder_params = encoder.get_params()

        # Train pymc3 Model
        η = .1
        s = shared(η)

        def reduce_rate(a, h, i):
            s.set_value(η / ((i / minibatch_size) + 1)**.7)

        with model:
            approx = pm.MeanField(local_rv=local_RVs)
            approx.scale_cost_to_minibatch = False
            inference = pm.KLqp(approx)
        inference.fit(10000,
                      callbacks=[reduce_rate],
                      obj_optimizer=pm.sgd(learning_rate=s),
                      more_obj_params=encoder_params,
                      total_grad_norm_constraint=200,
                      more_replacements={doc_t: doc_t_minibatch})

        # Extracting characteristic words
        doc_t.set_value(docs_tr.toarray())
        samples = pm.sample_approx(approx, draws=100)
        beta_pymc3 = samples['beta'].mean(axis=0)

        # Predictive distribution
        def calc_pp(ws, thetas, beta, wix):
            """
            Parameters
            ----------
            ws: ndarray (N,)
                Number of times the held-out word appeared in N documents.
            thetas: ndarray, shape=(N, K)
                Topic distributions for N documents.
            beta: ndarray, shape=(K, V)
                Word distributions for K topics.
            wix: int
                Index of the held-out word

            Return
            ------
            Log probability of held-out words.
            """
            return ws * np.log(thetas.dot(beta[:, wix]))
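
        # calc_pp returns count-weighted log-probabilities per document;
        # eval_lda below sums them across held-out words and divides by the
        # total held-out word count, giving a per-word predictive
        # log-probability.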

        def eval_lda(transform, beta, docs_te, wixs):
            """Evaluate LDA model by log predictive probability.

            Parameters
            ----------
            transform: Python function
                Transform document vectors to posterior mean of topic proportions.
            wixs: iterable of int
                Word indices to be held-out.
            """
            lpss = []
            docs_ = deepcopy(docs_te)
            thetass = []
            wss = []
            total_words = 0
            for wix in wixs:
                ws = docs_te[:, wix].ravel()
                if 0 < ws.sum():
                    # Hold-out
                    docs_[:, wix] = 0

                    # Topic distributions
                    thetas = transform(docs_)

                    # Predictive log probability
                    lpss.append(calc_pp(ws, thetas, beta, wix))

                    docs_[:, wix] = ws
                    thetass.append(thetas)
                    wss.append(ws)
                    total_words += ws.sum()
                else:
                    thetass.append(None)
                    wss.append(None)

            # Log-probability
            lp = np.sum(np.hstack(lpss)) / total_words

            return {'lp': lp, 'thetass': thetass, 'beta': beta, 'wss': wss}

        inp = tt.matrix(dtype='int64')
        sample_vi_theta = theano.function([inp],
                                          approx.sample_node(
                                              approx.model.theta,
                                              100,
                                              more_replacements={
                                                  doc_t: inp
                                              }).mean(0))

        def transform_pymc3(docs):
            return sample_vi_theta(docs)

        result_pymc3 = eval_lda(transform_pymc3, beta_pymc3, docs_te.toarray(),
                                np.arange(100))
        print('Predictive log prob (pm3) = {}'.format(result_pymc3['lp']))

        return result_pymc3
    def test_algorithm_performance(sklearn_lda, beta_sklearn, beta_pymc3,
                                   docs_te):
        """
        Return:
            CPU time to train the model & Predictive log probability
        """
        def eval_lda(transform, beta, docs_te, wixs):
            # NOTE: duplicated from eval_lda in train_pymc3 above
            """Evaluate LDA model by log predictive probability.

            Parameters
            ----------
            transform: Python function
                Transform document vectors to posterior mean of topic proportions.
            wixs: iterable of int
                Word indices to be held-out.
            """
            lpss = []
            docs_ = deepcopy(docs_te)
            thetass = []
            wss = []
            total_words = 0
            for wix in wixs:
                ws = docs_te[:, wix].ravel()
                if 0 < ws.sum():
                    # Hold-out
                    docs_[:, wix] = 0

                    # Topic distributions
                    thetas = transform(docs_)

                    # Predictive log probability
                    lpss.append(calc_pp(ws, thetas, beta, wix))

                    docs_[:, wix] = ws
                    thetass.append(thetas)
                    wss.append(ws)
                    total_words += ws.sum()
                else:
                    thetass.append(None)
                    wss.append(None)

            # Log-probability
            lp = np.sum(np.hstack(lpss)) / total_words

            return {'lp': lp, 'thetass': thetass, 'beta': beta, 'wss': wss}

        η = .1
        s = shared(η)

        def reduce_rate(a, h, i):
            s.set_value(η / ((i / minibatch_size) + 1)**.7)

        with model:
            approx = pm.MeanField(local_rv=local_RVs)
            approx.scale_cost_to_minibatch = False
            inference = pm.KLqp(approx)
        inference.fit(10000,
                      callbacks=[reduce_rate],
                      obj_optimizer=pm.sgd(learning_rate=s),
                      more_obj_params=encoder_params,
                      total_grad_norm_constraint=200,
                      more_replacements={doc_t: doc_t_minibatch})

        inp = tt.matrix(dtype='int64')
        sample_vi_theta = theano.function([inp],
                                          approx.sample_node(
                                              approx.model.theta,
                                              100,
                                              more_replacements={
                                                  doc_t: inp
                                              }).mean(0))

        def transform_pymc3(docs):
            return sample_vi_theta(docs)

        ############ PYMC3 ############
        print('Training Pymc3...')
        t0 = time()
        result_pymc3 = eval_lda(transform_pymc3, beta_pymc3, docs_te.toarray(),
                                np.arange(100))
        pymc3_time = time() - t0
        print('Predictive log prob (pm3) = {}'.format(result_pymc3['lp']))

        ############ sklearn ############
        print('')
        print('')
        print('Training Sklearn...')

        def transform_sklearn(docs):
            thetas = sklearn_lda.transform(docs)
            return thetas / thetas.sum(axis=1)[:, np.newaxis]

        t0 = time()
        result_sklearn = eval_lda(transform_sklearn, beta_sklearn,
                                  docs_te.toarray(), np.arange(100))
        sklearn_time = time() - t0
        print('Predictive log prob (sklearn) = {}'.format(
            result_sklearn['lp']))

        # save the model times
        times = {
            'pymc3 training time': pymc3_time,
            'sklearn training time': sklearn_time,
        }

        return times
Example #10
def run_lda(args):
    tf_vectorizer, docs_tr, docs_te = prepare_sparse_matrix_nonlabel(args.n_tr, args.n_te, args.n_word)
    feature_names = tf_vectorizer.get_feature_names()
    doc_tr_minibatch = pm.Minibatch(docs_tr.toarray(), args.bsz)
    doc_tr = shared(docs_tr.toarray()[:args.bsz])

    def log_prob(beta, theta):
        """Returns the log-likelihood function for given documents.

        K : number of topics in the model
        V : number of words (size of vocabulary)
        D : number of documents (in a mini-batch)

        Parameters
        ----------
        beta : tensor (K x V)
            Word distributions.
        theta : tensor (D x K)
            Topic distributions for documents.
        """

        def ll_docs_f(docs):
            dixs, vixs = docs.nonzero()
            vfreqs = docs[dixs, vixs]
            ll_docs = (vfreqs * pmmath.logsumexp(tt.log(theta[dixs]) + tt.log(beta.T[vixs]),
                                                 axis=1).ravel())

            return tt.sum(ll_docs) / (tt.sum(vfreqs) + 1e-9)

        return ll_docs_f

    with pm.Model() as model:
        beta = Dirichlet("beta",
                         a=pm.floatX((1. / args.n_topic) * np.ones((args.n_topic, args.n_word))),
                         shape=(args.n_topic, args.n_word), )

        theta = Dirichlet("theta",
                          a=pm.floatX((10. / args.n_topic) * np.ones((args.bsz, args.n_topic))),
                          shape=(args.bsz, args.n_topic), total_size=args.n_tr, )

        doc = pm.DensityDist("doc", log_prob(beta, theta), observed=doc_tr)

    encoder = ThetaEncoder(n_words=args.n_word, n_hidden=100, n_topics=args.n_topic)
    local_RVs = OrderedDict([(theta, encoder.encode(doc_tr))])
    encoder_params = encoder.get_params()

    s = shared(args.lr)

    def reduce_rate(a, h, i):
        s.set_value(args.lr / ((i / args.bsz) + 1) ** 0.7)

    with model:
        approx = pm.MeanField(local_rv=local_RVs)
        approx.scale_cost_to_minibatch = False
        inference = pm.KLqp(approx)

    inference.fit(args.n_iter,
                  callbacks=[reduce_rate, pm.callbacks.CheckParametersConvergence(diff="absolute")],
                  obj_optimizer=pm.adam(learning_rate=s),
                  more_obj_params=encoder_params,
                  total_grad_norm_constraint=200,
                  more_replacements={ doc_tr: doc_tr_minibatch }, )

    doc_tr.set_value(docs_tr.toarray())
    inp = tt.matrix(dtype="int64")
    sample_vi_theta = theano.function([inp],
        approx.sample_node(approx.model.theta, args.n_sample, more_replacements={doc_tr: inp}), )

    test = docs_te.toarray()
    test_n = test.sum(1)

    beta_pymc3 = pm.sample_approx(approx, draws=args.n_sample)['beta']
    theta_pymc3 = sample_vi_theta(test)

    assert beta_pymc3.shape == (args.n_sample, args.n_topic, args.n_word)
    assert theta_pymc3.shape == (args.n_sample, args.n_te, args.n_topic)

    beta_mean = beta_pymc3.mean(0)
    theta_mean = theta_pymc3.mean(0)

    pred_rate = theta_mean.dot(beta_mean)
    pp_test = (test * np.log(pred_rate)).sum(1) / test_n

    posteriors = {'theta': theta_pymc3, 'beta': beta_pymc3}

    log_top_words(beta_pymc3.mean(0), feature_names, n_top_words=args.n_top_word)
    save_elbo(approx.hist)
    save_pp(pp_test)
    save_draws(posteriors)
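

# Usage sketch (assumption, not from the original source): run_lda expects an
# argparse-style namespace whose attributes match those referenced above; the
# values below are placeholders for illustration only.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(n_tr=5000, n_te=500, n_word=1000, n_topic=20, bsz=128,
                     lr=0.1, n_iter=10000, n_sample=100, n_top_word=10)
    run_lda(args)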