示例#1
0
 def _construct_compute_fe_terms(self):
     """
     Compile a one-pass sampler for the per-example variational free-energy
     terms and wrap it in a multi-sample averaging estimator.

     The compiled theano function takes a data matrix. It substitutes that
     matrix for self.Xd, replaces self.Xc and self.Xm with all-zero tensors
     of the same shape, and returns [log-likelihood, summed KLd] for each
     row from a single stochastic pass.

     Returns:
         fe_term_estimator(X, sample_count) -> [mean_nll, mean_kld], where
         mean_nll is the negated log-likelihood and mean_kld the KLd term,
         each averaged over `sample_count` passes.
     """
     # symbolic stand-ins wired into the model graph through `givens`
     x_data = T.matrix()
     zeros_c = T.zeros_like(x_data)
     zeros_m = T.zeros_like(x_data)
     # reconstruction log-likelihood under the configured output model
     if self.x_type != 'bernoulli':
         ll_expr = log_prob_gaussian2(self.x, self.xg, \
                 log_vars=self.bounded_logvar)
     else:
         ll_expr = log_prob_bernoulli(self.x, self.xg)
     # gaussian_kld between q_z_given_x's output Gaussian and the prior,
     # summed over axis 1 to get one value per example
     kld_expr = T.sum(gaussian_kld(self.q_z_given_x.output_mean, \
             self.q_z_given_x.output_logvar, \
             self.prior_mean, self.prior_logvar), axis=1)
     one_pass = theano.function(inputs=[x_data], \
             outputs=[ll_expr, kld_expr], \
             givens={self.Xd: x_data, self.Xc: zeros_c, self.Xm: zeros_m})

     def fe_term_estimator(X, sample_count):
         n_rows = X.shape[0]
         ll_acc = np.zeros((n_rows,))
         kld_acc = np.zeros((n_rows,))
         for _ in range(sample_count):
             ll_draw, kld_draw = one_pass(X)
             ll_acc = ll_acc + ll_draw.ravel()
             kld_acc = kld_acc + kld_draw.ravel()
         denom = float(sample_count)
         return [-ll_acc / denom, kld_acc / denom]

     return fe_term_estimator
    def _construct_compute_fe_terms(self):
        """
        Compile a one-pass sampler for the per-example free-energy terms and
        return a Python wrapper that averages it over several passes.

        The compiled theano function takes self.x_in directly and returns
        [log-likelihood of self.x_in under self.xg, summed gaussian_kld of
        (self.z_mean, self.z_logvar) against the prior] for each row.

        Returns:
            fe_term_estimator(X, sample_count) -> [mean_nll, mean_kld]
        """
        # pick the observation log-likelihood for the configured x_type
        if self.x_type == 'bernoulli':
            log_lik = log_prob_bernoulli(self.x_in, self.xg)
        else:
            log_lik = log_prob_gaussian2(self.x_in, self.xg, \
                    log_vars=self.bounded_logvar)
        # per-example KLd: elementwise gaussian_kld summed across axis 1
        per_dim_kld = gaussian_kld(self.z_mean, self.z_logvar, \
                                   self.prior_mean, self.prior_logvar)
        kld_total = T.sum(per_dim_kld, axis=1)
        # one stochastic forward pass -> [log-likelihood, KLd] per row
        single_draw = theano.function(inputs=[self.x_in], \
                                      outputs=[log_lik, kld_total])

        def fe_term_estimator(X, sample_count):
            # cast input to theano's float type before feeding the function
            X = to_fX(X)
            ll_running = np.zeros((X.shape[0], ))
            kld_running = np.zeros((X.shape[0], ))
            for _ in range(sample_count):
                ll_batch, kld_batch = single_draw(X)
                ll_running = ll_running + ll_batch.ravel()
                kld_running = kld_running + kld_batch.ravel()
            n_draws = float(sample_count)
            mean_nll = -ll_running / n_draws
            mean_kld = kld_running / n_draws
            return [mean_nll, mean_kld]

        return fe_term_estimator
 def _construct_compute_fe_terms(self):
     """
     Build an estimator for the two free-energy terms of each input row.

     A theano function is compiled that, for one stochastic pass over
     self.x_in, yields the reconstruction log-likelihood and the KLd
     (gaussian_kld of self.z_mean/self.z_logvar vs. the prior, summed
     over axis 1). The returned wrapper averages both over many passes.

     Returns:
         fe_term_estimator(X, sample_count) -> [mean_nll, mean_kld]
     """
     bernoulli_out = (self.x_type == 'bernoulli')
     if bernoulli_out:
         ll_sym = log_prob_bernoulli(self.x_in, self.xg)
     else:
         ll_sym = log_prob_gaussian2(self.x_in, self.xg, \
                 log_vars=self.bounded_logvar)
     kld_sym = T.sum(gaussian_kld(self.z_mean, self.z_logvar, \
                                  self.prior_mean, self.prior_logvar), axis=1)
     sample_once = theano.function(inputs=[self.x_in], \
                                   outputs=[ll_sym, kld_sym])

     def fe_term_estimator(X, sample_count):
         X = to_fX(X)
         rows = X.shape[0]
         ll_acc, kld_acc = np.zeros((rows,)), np.zeros((rows,))
         for _ in range(sample_count):
             ll_part, kld_part = sample_once(X)
             ll_acc = ll_acc + ll_part.ravel()
             kld_acc = kld_acc + kld_part.ravel()
         count = float(sample_count)
         return [-ll_acc / count, kld_acc / count]

     return fe_term_estimator
示例#4
0
 def _log_prob_wrapper(self, x_true, x_apprx):
     """
     Return the negated log-probability of x_true given x_apprx.

     Dispatches on self.x_type: 'bernoulli' uses log_prob_bernoulli; any
     other value uses log_prob_gaussian2 with self.bounded_logvar as the
     log-variance.
     """
     if self.x_type != 'bernoulli':
         log_p = log_prob_gaussian2(x_true, x_apprx, \
                 log_vars=self.bounded_logvar)
     else:
         log_p = log_prob_bernoulli(x_true, x_apprx)
     return -log_p
示例#5
0
 def _construct_nll_costs(self):
     """
     Build the negative log-likelihood of self.x under the model output
     self.xg, i.e. the reconstruction part of the cost to minimize.
     """
     is_binary = (self.x_type == 'bernoulli')
     if is_binary:
         log_lik = log_prob_bernoulli(self.x, self.xg)
     else:
         log_lik = log_prob_gaussian2(self.x, self.xg, \
                 log_vars=self.bounded_logvar)
     return -log_lik
示例#6
0
 def _log_prob_wrapper(self, x_true, x_apprx):
     """
     Negative log-probability of x_true under prediction x_apprx.

     'bernoulli' outputs go through log_prob_bernoulli; everything else goes
     through log_prob_gaussian2 using self.bounded_logvar.
     """
     use_bernoulli = self.x_type == 'bernoulli'
     ll_value = log_prob_bernoulli(x_true, x_apprx) if use_bernoulli else \
             log_prob_gaussian2(x_true, x_apprx, log_vars=self.bounded_logvar)
     return -ll_value
示例#7
0
 def _construct_nll_costs(self):
     """
     Negative log-likelihood of self.x given self.xg (reconstruction cost).

     The likelihood family is chosen by self.x_type: 'bernoulli' or,
     otherwise, a Gaussian with log-variance self.bounded_logvar.
     """
     if self.x_type != 'bernoulli':
         return -log_prob_gaussian2(self.x, self.xg, \
                 log_vars=self.bounded_logvar)
     return -log_prob_bernoulli(self.x, self.xg)
 def _construct_nll_costs(self, xo):
     """
     Build the negative log-likelihood part of the free energy.

     Targets xo are scored against self.obs_transform(self.x_gen). Only
     self.x_gen is scored here; no averaging over refinement steps happens
     in this method.
     """
     x_hat = self.obs_transform(self.x_gen)
     if self.x_type != 'bernoulli':
         log_lik = log_prob_gaussian2(xo, x_hat, log_vars=self.bounded_logvar)
     else:
         log_lik = log_prob_bernoulli(xo, x_hat)
     return -log_lik
 def _construct_nll_costs(self, si, xo, nll_mask):
     """
     Build the flattened negative log-likelihood of xo given state si.

     si is mapped to observation space via self._from_si_to_x, and nll_mask
     is forwarded to the log-prob helper so NLL is only checked where
     nll_mask == 1.
     """
     x_hat = self._from_si_to_x(si)
     if self.x_type != "bernoulli":
         log_lik = log_prob_gaussian2(xo, x_hat, \
                 log_vars=self.bounded_logvar, mask=nll_mask)
     else:
         log_lik = log_prob_bernoulli(xo, x_hat, mask=nll_mask)
     return -log_lik.flatten()
示例#10
0
    def compute_log_prob(self, Xd=None):
        """
        Score the data Xd against the output distributions currently held in
        self.output (bernoulli) or self.output_mu / self.output_logvar
        (otherwise), restricted by self.output_mask. All entries of Xd are
        scored.

        Returns the value of log_prob_bernoulli / log_prob_gaussian2 as-is,
        with no sign change.

        NOTE(review): the original docstring called this a *negative* log
        likelihood, yet no negation is applied here. Confirm the sign
        convention of the log_prob_* helpers this module imports.
        NOTE(review): other call sites pass `log_vars=` to log_prob_gaussian2.
        Confirm that `les_logvars` is the keyword the imported helper expects.
        """
        if self.out_type != 'bernoulli':
            return log_prob_gaussian2(Xd, self.output_mu, \
                    les_logvars=self.output_logvar, mask=self.output_mask)
        return log_prob_bernoulli(Xd, self.output, mask=self.output_mask)
示例#11
0
 def _construct_nll_costs(self, si, xo, nll_mask):
     """
     Flattened negative log-likelihood of targets xo given state si.

     The state is converted with self._from_si_to_x, and nll_mask is passed
     through as the helper's `mask`, so NLL is only checked where
     nll_mask == 1.
     """
     predicted = self._from_si_to_x(si)
     bernoulli_obs = (self.x_type == 'bernoulli')
     ll_costs = log_prob_bernoulli(xo, predicted, mask=nll_mask) \
             if bernoulli_obs else \
             log_prob_gaussian2(xo, predicted, \
                     log_vars=self.bounded_logvar, mask=nll_mask)
     return -ll_costs.flatten()
 def _construct_nll_costs(self, si, xo, xm):
     """
     Build the flattened negative log-likelihood of xo given state si,
     skipping entries flagged in xm.

     The mask handed to the log-prob helpers is 1 - xm, so NLL is measured
     only where that inverted mask is 1 (i.e. where xm is 0).
     """
     x_hat = self._si_as_x(si)
     keep = 1.0 - xm  # nonzero where the observation should be scored
     if self.x_type == 'bernoulli':
         log_lik = log_prob_bernoulli(xo, x_hat, mask=keep)
     else:
         log_lik = log_prob_gaussian2(xo, x_hat, \
                 log_vars=self.bounded_logvar, mask=keep)
     return -log_lik.flatten()
示例#13
0
 def _construct_nll_costs(self, si, xo, xm):
     """
     Negative log-likelihood part of the free energy, one value per entry
     after flattening.

     si is converted with self._si_as_x. Entries where xm is 1 are excluded
     by passing the inverted mask (1 - xm) to the log-prob helper.
     """
     recon = self._si_as_x(si)
     inv_mask = 1.0 - xm
     if self.x_type != 'bernoulli':
         ll_vals = log_prob_gaussian2(xo, recon, \
                 log_vars=self.bounded_logvar, mask=inv_mask)
     else:
         ll_vals = log_prob_bernoulli(xo, recon, mask=inv_mask)
     return -ll_vals.flatten()
示例#14
0
    def masked_log_prob(self, Xc=None, Xm=None):
        """
        Score the data in Xc against the output distributions currently held
        in self.output (bernoulli) or self.output_mu / self.output_logvar
        (otherwise).

        Entries of Xc are selected by the mask Xm: where Xm[i] == 1, Xc[i]
        is not measured.

        NOTE(review): the helper result is returned without negation.
        Confirm the sign convention of the log_prob_* helpers. Also confirm
        that `les_logvars` matches the imported log_prob_gaussian2 signature.
        """
        # The log-pdf helpers score only entries whose mask is nonzero, so
        # the mask is inverted to skip entries where Xm == 1.
        keep_mask = 1.0 - Xm
        if self.out_type != 'bernoulli':
            return log_prob_gaussian2(Xc, self.output_mu, \
                    les_logvars=self.output_logvar, mask=keep_mask)
        return log_prob_bernoulli(Xc, self.output, mask=keep_mask)
示例#15
0
    def _construct_compute_fe_terms(self):
        """
        Compile a sampler for the per-example log-likelihood and KLd terms of
        the variational free-energy, and wrap it in an averaging estimator.

        The compiled theano function takes a data matrix. It substitutes it
        for self.Xd and binds self.Xc and self.Xm to all-zero tensors of the
        same shape.

        Returns:
            fe_term_estimator(X, sample_count) -> [mean_nll, mean_kld], the
            negated log-likelihood and the KLd, each averaged over
            `sample_count` stochastic passes.
        """
        # symbolic input plus matching zero tensors for the `givens` map
        data_in = T.matrix()
        zero_c = T.zeros_like(data_in)
        zero_m = T.zeros_like(data_in)
        # observation log-likelihood for the configured output type
        if self.x_type == 'bernoulli':
            ll_node = log_prob_bernoulli(self.x, self.xg)
        else:
            ll_node = log_prob_gaussian2(self.x, self.xg, \
                    log_vars=self.bounded_logvar)
        # elementwise gaussian_kld (q_z_given_x vs. prior) summed over axis 1
        kld_elems = gaussian_kld(self.q_z_given_x.output_mean, \
                self.q_z_given_x.output_logvar, \
                self.prior_mean, self.prior_logvar)
        kld_node = T.sum(kld_elems, axis=1)
        draw_terms = theano.function(inputs=[data_in], \
                outputs=[ll_node, kld_node], \
                givens={self.Xd: data_in, self.Xc: zero_c, self.Xm: zero_m})

        def fe_term_estimator(X, sample_count):
            ll_total = np.zeros((X.shape[0], ))
            kld_total = np.zeros((X.shape[0], ))
            for _ in range(sample_count):
                ll_once, kld_once = draw_terms(X)
                ll_total = ll_total + ll_once.ravel()
                kld_total = kld_total + kld_once.ravel()
            n_samples = float(sample_count)
            return [-ll_total / n_samples, kld_total / n_samples]

        return fe_term_estimator
示例#16
0
    def __init__(self, rng=None, \
            x_in=None, x_out=None, \
            p_s0_given_z=None, \
            p_hi_given_si=None, \
            p_sip1_given_si_hi=None, \
            q_z_given_x=None, \
            q_hi_given_x_si=None, \
            obs_dim=None, \
            z_dim=None, h_dim=None, \
            ir_steps=4, params=None, \
            shared_param_dicts=None):
        """
        Build the symbolic graph, costs, parameter updates and compiled
        functions for this MultiStageModel.

        Parameters:
            rng: object with a numpy-style randint(); it seeds the theano
                 RandStream. When None, a fresh np.random.RandomState() is
                 used.
            x_in: symbolic input matrix. Dropout is applied to it before
                  q_z_given_x.
            x_out: symbolic target matrix used by every NLL term.
            p_s0_given_z: net applied to z to produce the initial state s0.
            p_hi_given_si: net giving mean/logvar of hi from the
                           transformed state.
            p_sip1_given_si_hi: net mapping hi to (input-gate, forget-gate,
                                input) values for the state update.
            q_z_given_x: net giving mean/logvar/sample of z from x.
            q_hi_given_x_si: net giving mean/logvar of hi from
                             [x_out - obs(si), obs(si)].
            obs_dim, z_dim, h_dim: sizes of the observation, z and hi spaces.
            ir_steps: stored as self.ir_steps. The scan built here iterates
                      over the leading axis of self.hi_zmuv.
            params: dict with required key 'x_type' ('bernoulli' or
                    'gaussian') and optional 'obs_transform' ('sigmoid' or
                    'none').
            shared_param_dicts: optional dict with 'p_z_mean', 'p_z_logvar'
                                and 'obs_logvar' shared variables to reuse.
        """
        # setup a rng for this MultiStageModel
        if rng is None:
            # the default used to crash on rng.randint; use a fresh RNG
            rng = np.random.RandomState()
        self.rng = RandStream(rng.randint(100000))

        # grab the user-provided parameters
        self.params = params
        self.x_type = self.params['x_type']
        assert((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))

        def _sigmoid_squash(x):
            # soft-clipped sigmoid used to map states into (0, 1)
            return T.nnet.sigmoid(20.0 * T.tanh(0.05 * x))

        if 'obs_transform' in self.params:
            assert((self.params['obs_transform'] == 'sigmoid') or \
                    (self.params['obs_transform'] == 'none'))
            if self.params['obs_transform'] == 'sigmoid':
                self.obs_transform = _sigmoid_squash
            else:
                self.obs_transform = lambda x: x
        else:
            self.obs_transform = _sigmoid_squash
        if self.x_type == 'bernoulli':
            # bernoulli outputs always use the squashing transform,
            # regardless of params['obs_transform']
            self.obs_transform = _sigmoid_squash
        self.shared_param_dicts = shared_param_dicts

        # record the dimensions of various spaces relevant to this model
        self.obs_dim = obs_dim
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.ir_steps = ir_steps

        # grab handles to the relevant InfNets
        self.q_z_given_x = q_z_given_x
        self.q_hi_given_x_si = q_hi_given_x_si
        self.p_s0_given_z = p_s0_given_z
        self.p_hi_given_si = p_hi_given_si
        self.p_sip1_given_si_hi = p_sip1_given_si_hi

        # record the symbolic variables that will provide inputs to the
        # computation graph created to describe this MultiStageModel
        self.x_in = x_in
        self.x_out = x_out
        self.hi_zmuv = T.tensor3() # for ZMUV Gaussian samples to use in scan

        # setup switching variable for changing between sampling/training
        zero_ary = to_fX( np.zeros((1,)) )
        self.train_switch = theano.shared(value=zero_ary, name='msm_train_switch')
        self.set_train_switch(1.0)
        # setup a variable for controlling dropout noise
        self.drop_rate = theano.shared(value=zero_ary, name='msm_drop_rate')
        self.set_drop_rate(0.0)
        # this weight balances l1 vs. l2 penalty on posterior KLds
        self.lam_kld_l1l2 = theano.shared(value=zero_ary, name='msm_lam_kld_l1l2')
        self.set_lam_kld_l1l2(1.0)

        if self.shared_param_dicts is None:
            # initialize "optimizable" parameters specific to this MSM
            init_vec = to_fX( np.zeros((self.z_dim,)) )
            self.p_z_mean = theano.shared(value=init_vec, name='msm_p_z_mean')
            self.p_z_logvar = theano.shared(value=init_vec, name='msm_p_z_logvar')
            # NOTE: obs_logvar is a single shape-(1,) value shared by all
            # observation dims (an obs_dim-sized vector was previously built
            # here but never used)
            self.obs_logvar = theano.shared(value=zero_ary, name='msm_obs_logvar')
            self.bounded_logvar = 8.0 * T.tanh((1.0/8.0) * self.obs_logvar)
            self.shared_param_dicts = {}
            self.shared_param_dicts['p_z_mean'] = self.p_z_mean
            self.shared_param_dicts['p_z_logvar'] = self.p_z_logvar
            self.shared_param_dicts['obs_logvar'] = self.obs_logvar
        else:
            self.p_z_mean = self.shared_param_dicts['p_z_mean']
            self.p_z_logvar = self.shared_param_dicts['p_z_logvar']
            self.obs_logvar = self.shared_param_dicts['obs_logvar']
            self.bounded_logvar = 8.0 * T.tanh((1.0/8.0) * self.obs_logvar)

        # setup a function for computing reconstruction NEGATIVE log
        # likelihood (note the -1.0 factor on the log-prob helpers)
        if self.x_type == 'bernoulli':
            self.log_prob_func = lambda xo, xh: \
                    (-1.0 * log_prob_bernoulli(xo, xh))
        else:
            self.log_prob_func = lambda xo, xh: \
                    (-1.0 * log_prob_gaussian2(xo, xh, \
                     log_vars=self.bounded_logvar))

        # get a drop mask that drops things with probability p, rescaling
        # kept entries by 1 / (1 - p)
        # NOTE(review): the mask is shaped like x_out but applied to x_in;
        # this assumes both have the same shape
        drop_scale = 1. / (1. - self.drop_rate[0])
        drop_rnd = self.rng.uniform(size=self.x_out.shape, \
                low=0.0, high=1.0, dtype=theano.config.floatX)
        drop_mask = drop_scale * (drop_rnd > self.drop_rate[0])

        #############################
        # Setup self.z and self.s0. #
        #############################
        print("Building MSM step 0...")
        drop_x = drop_mask * self.x_in
        self.q_z_mean, self.q_z_logvar, self.z = \
                self.q_z_given_x.apply(drop_x, do_samples=True)
        # get initial observation state
        self.s0, _ = self.p_s0_given_z.apply(self.z, do_samples=False)

        # gather KLd and NLL for the initialization step
        self.init_klds = gaussian_kld(self.q_z_mean, self.q_z_logvar, \
                                      self.p_z_mean, self.p_z_logvar)
        # log_prob_func already returns an NLL; negating it again (as was
        # done before) produced a log-likelihood. Compute it the same way
        # as nlli in ir_step_func.
        self.init_nlls = \
                self.log_prob_func(self.x_out, self.obs_transform(self.s0))

        ##################################################
        # Setup the iterative generation loop using scan #
        ##################################################
        def ir_step_func(hi_zmuv, sim1):
            # get variables used throughout this refinement step
            sim1_obs = self.obs_transform(sim1) # transform state -> obs
            grad_ll = self.x_out - sim1_obs

            # get samples of next hi, conditioned on current si
            hi_p_mean, hi_p_logvar = self.p_hi_given_si.apply( \
                    sim1_obs, do_samples=False)
            # now we build the model for variational hi given si
            hi_q_mean, hi_q_logvar = self.q_hi_given_x_si.apply( \
                    T.horizontal_stack(grad_ll, sim1_obs), \
                    do_samples=False)
            hi_q = (T.exp(0.5 * hi_q_logvar) * hi_zmuv) + hi_q_mean
            hi_p = (T.exp(0.5 * hi_p_logvar) * hi_zmuv) + hi_p_mean

            # make hi samples that can be switched between hi_p and hi_q
            hi = ( ((self.train_switch[0] * hi_q) + \
                    ((1.0 - self.train_switch[0]) * hi_p)) )

            # p_sip1_given_si_hi receives only hi; the previous state sim1
            # enters through the gated update below
            ig_vals, fg_vals, in_vals = self.p_sip1_given_si_hi.apply(hi)

            # get the transformed values (for an LSTM style update)
            i_gate = 1.0 * T.nnet.sigmoid(ig_vals + 2.0)
            f_gate = 1.0 * T.nnet.sigmoid(fg_vals + 2.0)
            # perform an LSTM-like update of the state sim1 -> si
            si = (in_vals * i_gate) + (sim1 * f_gate)

            # compute generator NLL for this step
            nlli = self.log_prob_func(self.x_out, self.obs_transform(si))
            # compute relevant KLds for this step
            kldi_q2p = gaussian_kld(hi_q_mean, hi_q_logvar, \
                                    hi_p_mean, hi_p_logvar)
            kldi_p2q = gaussian_kld(hi_p_mean, hi_p_logvar, \
                                    hi_q_mean, hi_q_logvar)
            return si, nlli, kldi_q2p, kldi_p2q

        # only the state si is fed back between steps
        init_values = [self.s0, None, None, None]

        self.scan_results, self.scan_updates = theano.scan(ir_step_func, \
                outputs_info=init_values, sequences=self.hi_zmuv)

        self.si = self.scan_results[0]
        self.nlli = self.scan_results[1]
        self.kldi_q2p = self.scan_results[2]
        self.kldi_p2q = self.scan_results[3]

        ######################################################################
        # ALL SYMBOLIC VARS NEEDED FOR THE OBJECTIVE SHOULD NOW BE AVAILABLE #
        ######################################################################

        # shared var learning rate for generator and inferencer
        # (zero_ary is reused; theano.shared copies its value)
        self.lr_1 = theano.shared(value=zero_ary, name='msm_lr_1')
        self.lr_2 = theano.shared(value=zero_ary, name='msm_lr_2')
        # shared var momentum parameters for generator and inferencer
        self.mom_1 = theano.shared(value=zero_ary, name='msm_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='msm_mom_2')
        # init parameters for controlling learning dynamics
        self.set_sgd_params()
        # init shared var for weighting nll of data given posterior sample
        self.lam_nll = theano.shared(value=zero_ary, name='msm_lam_nll')
        self.set_lam_nll(lam_nll=1.0)
        # init shared var for weighting prior kld against reconstruction
        self.lam_kld_z = theano.shared(value=zero_ary, name='msm_lam_kld_z')
        self.lam_kld_q2p = theano.shared(value=zero_ary, name='msm_lam_kld_q2p')
        self.lam_kld_p2q = theano.shared(value=zero_ary, name='msm_lam_kld_p2q')
        self.set_lam_kld(lam_kld_z=1.0, lam_kld_q2p=0.7, lam_kld_p2q=0.3)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='msm_lam_l2w')
        self.set_lam_l2w(1e-5)

        # Grab all of the "optimizable" parameters in "group 1"
        self.q_params = []
        self.q_params.extend(self.q_z_given_x.mlp_params)
        self.q_params.extend(self.q_hi_given_x_si.mlp_params)
        # Grab all of the "optimizable" parameters in "group 2"
        self.p_params = [self.p_z_mean, self.p_z_logvar]
        self.p_params.extend(self.p_hi_given_si.mlp_params)
        self.p_params.extend(self.p_sip1_given_si_hi.mlp_params)
        self.p_params.extend(self.p_s0_given_z.mlp_params)

        # Make a joint list of parameters group 1/2
        self.joint_params = self.q_params + self.p_params

        #################################
        # CONSTRUCT THE KLD-BASED COSTS #
        #################################
        self.kld_z_q2p, self.kld_z_p2q, self.kld_hi_q2p, self.kld_hi_p2q = \
                self._construct_kld_costs(p=1.0)
        self.kld_z = (self.lam_kld_q2p[0] * self.kld_z_q2p) + \
                     (self.lam_kld_p2q[0] * self.kld_z_p2q)
        self.kld_hi = (self.lam_kld_q2p[0] * self.kld_hi_q2p) + \
                      (self.lam_kld_p2q[0] * self.kld_hi_p2q)
        self.kld_costs = (self.lam_kld_z[0] * self.kld_z) + self.kld_hi
        # now do l2 KLd costs
        self.kl2_z_q2p, self.kl2_z_p2q, self.kl2_hi_q2p, self.kl2_hi_p2q = \
                self._construct_kld_costs(p=2.0)
        self.kl2_z = (self.lam_kld_q2p[0] * self.kl2_z_q2p) + \
                     (self.lam_kld_p2q[0] * self.kl2_z_p2q)
        self.kl2_hi = (self.lam_kld_q2p[0] * self.kl2_hi_q2p) + \
                      (self.lam_kld_p2q[0] * self.kl2_hi_p2q)
        self.kl2_costs = (self.lam_kld_z[0] * self.kl2_z) + self.kl2_hi
        # compute joint l1/l2 KLd cost
        self.kld_l1l2_costs = (self.lam_kld_l1l2[0] * self.kld_costs) + \
                ((1.0 - self.lam_kld_l1l2[0]) * self.kl2_costs)
        # compute "mean" (rather than per-input) costs
        self.kld_cost = T.mean(self.kld_costs)
        self.kl2_cost = T.mean(self.kl2_costs)
        self.kld_l1l2_cost = T.mean(self.kld_l1l2_costs)
        #################################
        # CONSTRUCT THE NLL-BASED COSTS #
        #################################
        # only the NLL after the final refinement step enters the cost
        self.nll_costs = self.nlli[-1]
        self.nll_cost = self.lam_nll[0] * T.mean(self.nll_costs)
        ########################################
        # CONSTRUCT THE REST OF THE JOINT COST #
        ########################################
        param_reg_cost = self._construct_reg_costs()
        self.reg_cost = self.lam_l2w[0] * param_reg_cost
        self.joint_cost = self.nll_cost + self.kld_l1l2_cost + \
                          self.reg_cost
        ##############################
        # CONSTRUCT A PER-INPUT COST #
        ##############################
        self.obs_costs = self.nll_costs + self.kld_l1l2_costs

        # Get the gradient of the joint cost for all optimizable parameters
        print("Computing gradients of self.joint_cost...")
        grad_list = T.grad(self.joint_cost, self.joint_params)
        self.joint_grads = OrderedDict(zip(self.joint_params, grad_list))

        # Construct the updates for the generator and inferencer networks
        self.q_updates = get_adam_updates(params=self.q_params, \
                grads=self.joint_grads, alpha=self.lr_1, \
                beta1=self.mom_1, beta2=self.mom_2, \
                mom2_init=1e-3, smoothing=1e-5, max_grad_norm=10.0)
        self.p_updates = get_adam_updates(params=self.p_params, \
                grads=self.joint_grads, alpha=self.lr_2, \
                beta1=self.mom_1, beta2=self.mom_2, \
                mom2_init=1e-3, smoothing=1e-5, max_grad_norm=10.0)
        self.joint_updates = OrderedDict()
        self.joint_updates.update(self.q_updates)
        self.joint_updates.update(self.p_updates)
        # add scan updates, which seem to be required
        self.joint_updates.update(self.scan_updates)

        # Construct a function for jointly training the generator/inferencer
        print("Compiling cost computer...")
        self.compute_raw_klds = self._construct_raw_klds()
        print("Compiling training function...")
        self.train_joint = self._construct_train_joint()
        print("Compiling free-energy sampler...")
        self.compute_fe_terms = self._construct_compute_fe_terms()
        print("Compiling open-loop model sampler...")
        self.sample_from_prior = self._construct_sample_from_prior()
        print("Compiling data-guided model sampler...")
        self.sample_from_input = self._construct_sample_from_input()
        return