Ejemplo n.º 1
0
def smooth_kl_divergence(p, q):
    """Compute KL(p || q), the divergence of "approximate" distribution q from
    "true" distribution p.

    Both inputs hold relative log-likelihoods; each is mapped through
    safe_softmax to a sum-to-one distribution before comparison. The
    per-row divergence is summed over axis 1 and returned with keepdims=True.
    """
    p_probs = safe_softmax(p)
    q_probs = safe_softmax(q)
    # KL(p || q) = cross_entropy(p, q) - entropy(p) = sum p * (log p - log q)
    log_ratio = T.log(p_probs) - T.log(q_probs)
    return T.sum(log_ratio * p_probs, axis=1, keepdims=True)
Ejemplo n.º 2
0
def smooth_cross_entropy(p, q):
    """Compute the cross-entropy H(p, q) of "approximate" distribution q
    under "true" distribution p.

    p and q hold relative log-likelihoods and are each converted to sum-to-one
    distributions with safe_softmax. The result is summed over axis 1 with
    keepdims=True.
    """
    p_probs = safe_softmax(p)
    log_q = T.log(safe_softmax(q))
    # H(p, q) = entropy(p) + KL(p || q) = -sum p * log q
    return -T.sum(p_probs * log_q, axis=1, keepdims=True)
Ejemplo n.º 3
0
def smooth_cross_entropy(p, q):
    """Return the row-wise cross-entropy H(p, q) between "true" distribution p
    and "approximate" distribution q.

    Inputs are relative log-likelihoods; safe_softmax turns each into a
    sum-to-one distribution. Output keeps the reduced axis (axis=1).
    """
    true_dist = safe_softmax(p)
    approx_dist = safe_softmax(q)
    # equivalently: entropy(p) + kl_divergence(p, q)
    weighted_log_q = true_dist * T.log(approx_dist)
    return -T.sum(weighted_log_q, axis=1, keepdims=True)
Ejemplo n.º 4
0
def smooth_kl_divergence(p, q):
    """Row-wise KL-divergence KL(p || q) from "approximate" q to "true" p.

    p and q are relative log-likelihood encodings; safe_softmax converts
    them to sum-to-one distributions first. Sums over axis 1, keepdims=True.
    """
    true_dist = safe_softmax(p)
    approx_dist = safe_softmax(q)
    # equals cross_entropy(p, q) - entropy(p)
    pointwise = (T.log(true_dist) - T.log(approx_dist)) * true_dist
    return T.sum(pointwise, axis=1, keepdims=True)
Ejemplo n.º 5
0
def smooth_js_divergence(p, q):
    """
    Compute the row-wise Jensen-Shannon divergence between p and q.

    p and q are log-space encodings; each is passed through safe_softmax.
    The result is the average of KL(p || m) and KL(q || m), where m is the
    equal-weight mixture of the two distributions. Summed over axis 1 with
    keepdims=True.
    """
    p_probs = safe_softmax(p)
    q_probs = safe_softmax(q)
    log_mix = T.log((p_probs + q_probs) / 2.0)
    kl_p_mix = T.sum(p_probs * (T.log(p_probs) - log_mix), axis=1, keepdims=True)
    kl_q_mix = T.sum(q_probs * (T.log(q_probs) - log_mix), axis=1, keepdims=True)
    return (kl_p_mix + kl_q_mix) / 2.0
Ejemplo n.º 6
0
def smooth_js_divergence(p, q):
    """
    Jensen-Shannon divergence between the log-space encodings p and q.

    Both inputs go through safe_softmax; the divergence is
    0.5 * (KL(p || m) + KL(q || m)) with m = (p + q) / 2, reduced over
    axis 1 with keepdims=True.
    """
    dists = [safe_softmax(p), safe_softmax(q)]
    log_mean = T.log((dists[0] + dists[1]) / 2.0)
    halves = [T.sum(d * (T.log(d) - log_mean), axis=1, keepdims=True)
              for d in dists]
    return (halves[0] + halves[1]) / 2.0
Ejemplo n.º 7
0
def smooth_softmax(p, lam_smooth=1e-3):
    """
    Return a "smoothed" softmax of p.

    Adds lam_smooth to every softmax probability and renormalizes by
    1 + K * lam_smooth (K = p.shape[1]), so each row still sums to one but no
    entry can reach zero. This keeps KL-divergence, cross-entropy, etc. well
    behaved when comparing strongly peaked categorical distributions.
    """
    n_categories = T.cast(p.shape[1], 'floatX')
    renorm = 1.0 + (n_categories * lam_smooth)
    return (safe_softmax(p) + lam_smooth) / renorm
Ejemplo n.º 8
0
 def _construct_nll_costs(self, yi):
     """
     Construct the categorical negative log-likelihood part of the cost.

     Returns a column (yi.shape[0] x 1) of per-row NLLs of the labels in yi
     under safe_softmax(self.y_out). Rows where yi == 0 are masked to a
     probability of 1, so they contribute -log(1) = 0.
     """
     probs = safe_softmax(self.y_out)
     n_rows = yi.shape[0]
     rows = T.arange(n_rows)
     # the selected column is yi - 1, so class labels appear to be 1-based
     cols = yi.flatten() - 1
     labeled = T.neq(yi, 0).reshape((n_rows, 1))
     # replace rows with yi == 0 by all-ones so the picked entry is exactly 1
     padded = (probs * labeled) + (1. - labeled)
     nlls = -T.log(padded[rows, cols])
     return nlls.reshape((n_rows, 1))
Ejemplo n.º 9
0
 def _construct_sample_posterior(self):
     """
     Construct a function for sampling from the categorical distribution
     resulting from taking a softmax of the output of this PeaNet.

     Returns a callable sampler(x) that evaluates the softmax of
     self.output_proto on x and returns a one-hot matrix (dtype
     theano.config.floatX, same shape as the probabilities) holding one
     sampled category per row.
     """
     func = theano.function([self.Xd], \
             outputs=safe_softmax(self.output_proto))
     # this function is based on "roulette wheel" sampling
     def sampler(x):
         y_probs = func(x)
         y_cumsum = np.cumsum(y_probs, axis=1)
         rand_vals = npr.rand(y_probs.shape[0], 1)
         y_bin = np.zeros(y_probs.shape)
         if y_probs.shape[1] > 0:
             # pick the first column whose cumulative mass exceeds the draw
             hits = y_cumsum > rand_vals
             picks = np.argmax(hits, axis=1)
             # rounding can leave the final cumsum slightly below the draw;
             # send such rows to the last category rather than leaving them
             # with no sample at all
             picks[~np.any(hits, axis=1)] = y_probs.shape[1] - 1
             y_bin[np.arange(y_probs.shape[0]), picks] = 1.0
         y_bin = y_bin.astype(theano.config.floatX)
         return y_bin
     return sampler
Ejemplo n.º 10
0
    def _construct_sample_posterior(self):
        """
        Construct a function for sampling from the categorical distribution
        resulting from taking a softmax of the output of this PeaNet.

        Returns a callable sampler(x) that evaluates the softmax of
        self.output_proto on x and returns a one-hot matrix (dtype
        theano.config.floatX, same shape as the probabilities) holding one
        sampled category per row.
        """
        func = theano.function([self.Xd], \
                outputs=safe_softmax(self.output_proto))

        # this function is based on "roulette wheel" sampling
        def sampler(x):
            y_probs = func(x)
            y_cumsum = np.cumsum(y_probs, axis=1)
            rand_vals = npr.rand(y_probs.shape[0], 1)
            y_bin = np.zeros(y_probs.shape)
            if y_probs.shape[1] > 0:
                # pick the first column whose cumulative mass exceeds the draw
                hits = y_cumsum > rand_vals
                picks = np.argmax(hits, axis=1)
                # rounding can leave the final cumsum slightly below the draw;
                # send such rows to the last category rather than leaving
                # them with no sample at all
                picks[~np.any(hits, axis=1)] = y_probs.shape[1] - 1
                y_bin[np.arange(y_probs.shape[0]), picks] = 1.0
            y_bin = y_bin.astype(theano.config.floatX)
            return y_bin

        return sampler
Ejemplo n.º 11
0
    def __init__(self, rng=None, \
            Xd=None, Yd=None, Xc=None, Xm=None, \
            g_net=None, i_net=None, p_net=None, \
            data_dim=None, prior_dim=None, label_dim=None, \
            batch_size=None, \
            params=None, shared_param_dicts=None):
        """
        Build the graph, costs and SGD updates for a GITrip.

        Parameters:
            rng: random state whose randint() seeds this GITrip's RandStream;
                also passed to the sub-network clones.
            Xd, Yd, Xc, Xm: symbolic inputs. Xd is multiplied by the input
                mask; Xd/Xc/Xm feed the continuous inferencer clone.
            g_net, i_net, p_net: generator, continuous inferencer and
                categorical inferencer whose shared-parameter clones are used.
            data_dim, prior_dim, label_dim: dimensions checked against the
                networks' layer sizes.
            batch_size: number of stacked label_dim identity matrices in Ic.
            params: optional dict stored as self.params ({} if None).
            shared_param_dicts: None for an original GITrip; otherwise the
                dict of an original whose shared hyperparameter variables and
                input mask this clone reuses.
        """
        # TODO: refactor for use with "encoded" inferencer/generator
        assert(not (i_net.use_encoder or g_net.use_encoder))

        # setup a rng for this GITrip
        self.rng = RandStream(rng.randint(100000))
        # record the optional dict of extra parameters
        if params is None:
            self.params = {}
        else:
            self.params = params

        # record the dimensionality of the data handled by this GITrip
        self.data_dim = data_dim
        self.label_dim = label_dim
        self.prior_dim = prior_dim
        self.batch_size = batch_size

        # determine whether this GITrip is a clone or an original. This must
        # happen before the input mask is applied to Xd, so that a clone's
        # graph is built on the mask it shares with the original.
        if shared_param_dicts is None:
            # not a clone: we will make a dict for the shared parameters
            self.shared_param_dicts = {}
            self.is_clone = False
        else:
            # a clone: shared parameters come from the given dict
            self.shared_param_dicts = shared_param_dicts
            self.is_clone = True

        # mask for disabling and/or reweighting input dimensions. clones reuse
        # the original's mask so that changing it affects every clone's graph.
        if self.is_clone:
            self.input_mask = self.shared_param_dicts['git_input_mask']
        else:
            row_mask = np.ones((self.data_dim,)).astype(theano.config.floatX)
            self.input_mask = theano.shared(value=row_mask, name='git_input_mask')

        # record the symbolic variables that will provide inputs to the
        # computation graph created to describe this GITrip
        self.Xd = self.input_mask * Xd
        self.Yd = Yd
        self.Xc = Xc
        self.Xm = Xm

        # construct a vertically-repeated identity matrix for marginalizing
        # over possible values of the categorical latent variable.
        Ic = np.vstack([np.identity(label_dim) for i in range(batch_size)])
        self.Ic = theano.shared(value=Ic.astype(theano.config.floatX), name='git_Ic')
        # create "shared-parameter" clones of the continuous and categorical
        # inferencers that this GITrip will be built on.
        self.IN = i_net.shared_param_clone(rng=rng, \
                Xd=self.Xd, Xc=self.Xc, Xm=self.Xm)
        self.PN = p_net.shared_param_clone(rng=rng, Xd=self.Xd)
        # create symbolic variables for the approximate posteriors over the
        # continuous and categorical latent variables
        self.Xp = self.IN.output
        self.Yp = safe_softmax(self.PN.output_spawn[0])
        self.Yp_proto = safe_softmax(self.PN.output_proto)
        # create a symbolic variable structured to allow easy "marginalization"
        # over possible settings of the categorical latent variable. the left
        # matrix (i.e. self.Ic) comprises batch_size copies of the label_dim
        # dimensional identity matrix stacked on top of each other, and the
        # right matrix comprises a single sample from the approximate posterior
        # over the continuous latent variables for each of batch_size examples
        # with each sample repeated label_dim times.
        self.XYp = T.horizontal_stack(self.Ic, T.repeat(self.Xp, \
                self.label_dim, axis=0))
        # pipe the "convenient marginalization" matrix into a shared parameter
        # clone of the generator network
        self.GN = g_net.shared_param_clone(rng=rng, Xp=self.XYp)
        # capture a handle for sampled reconstructions from the generator
        self.Xg = self.GN.output

        # we will be assuming one proto-net in the pseudo-ensemble represented
        # by self.PN, and either one or two spawn-nets for that proto-net.
        assert(len(self.PN.proto_nets) == 1)
        assert((len(self.PN.spawn_nets) == 1) or \
                (len(self.PN.spawn_nets) == 2))
        # output of the generator and input to the inferencer should both be
        # equal to self.data_dim
        assert(self.data_dim == self.GN.mlp_layers[-1].out_dim)
        assert(self.data_dim == self.IN.shared_layers[0].in_dim)
        assert(self.data_dim == self.PN.proto_nets[0][0].in_dim)
        # mu/sigma outputs of self.IN should be equal to prior_dim, output of
        # self.PN should be equal to label_dim, and input of self.GN should be
        # equal to prior_dim + label_dim
        assert(self.prior_dim == self.IN.mu_layers[-1].out_dim)
        assert(self.prior_dim == self.IN.sigma_layers[-1].out_dim)
        assert(self.label_dim == self.PN.proto_nets[0][-1].out_dim)
        assert((self.prior_dim + self.label_dim) == self.GN.mlp_layers[0].in_dim)

        if not self.is_clone:
            # shared var learning rate for generator and inferencer
            zero_ary = np.zeros((1,)).astype(theano.config.floatX)
            self.lr_gn = theano.shared(value=zero_ary, name='git_lr_gn')
            self.lr_in = theano.shared(value=zero_ary, name='git_lr_in')
            self.lr_pn = theano.shared(value=zero_ary, name='git_lr_pn')
            # shared var momentum parameters for generator and inferencer
            self.mo_gn = theano.shared(value=zero_ary, name='git_mo_gn')
            self.mo_in = theano.shared(value=zero_ary, name='git_mo_in')
            self.mo_pn = theano.shared(value=zero_ary, name='git_mo_pn')
            # init parameters for controlling learning dynamics
            self.set_all_sgd_params()
            # init shared var for weighting nll of data given posterior sample
            self.lam_nll = theano.shared(value=zero_ary, name='git_lam_nll')
            self.set_lam_nll(lam_nll=1.0)
            # init shared var for weighting posterior KL-div from prior
            self.lam_kld = theano.shared(value=zero_ary, name='git_lam_kld')
            self.set_lam_kld(lam_kld=1.0)
            # init shared var for weighting semi-supervised classification
            self.lam_cat = theano.shared(value=zero_ary, name='git_lam_cat')
            self.set_lam_cat(lam_cat=0.0)
            # init shared var for weighting ensemble agreement regularization
            self.lam_pea = theano.shared(value=zero_ary, name='git_lam_pea')
            self.set_lam_pea(lam_pea=0.0)
            # init shared var for weighting entropy regularization on the
            # inferred posteriors over the categorical variable of interest
            self.lam_ent = theano.shared(value=zero_ary, name='git_lam_ent')
            self.set_lam_ent(lam_ent=0.0)
            # init shared var for weighting dirichlet regularization on the
            # inferred posteriors over the categorical variable of interest
            self.lam_dir = theano.shared(value=zero_ary, name='git_lam_dir')
            self.set_lam_dir(lam_dir=0.0)
            # init shared var for controlling l2 regularization on params
            self.lam_l2w = theano.shared(value=zero_ary, name='git_lam_l2w')
            self.set_lam_l2w(lam_l2w=1e-3)
            # record shared parameters that are to be shared among clones
            self.shared_param_dicts['git_lr_gn'] = self.lr_gn
            self.shared_param_dicts['git_lr_in'] = self.lr_in
            self.shared_param_dicts['git_lr_pn'] = self.lr_pn
            self.shared_param_dicts['git_mo_gn'] = self.mo_gn
            self.shared_param_dicts['git_mo_in'] = self.mo_in
            self.shared_param_dicts['git_mo_pn'] = self.mo_pn
            self.shared_param_dicts['git_lam_nll'] = self.lam_nll
            self.shared_param_dicts['git_lam_kld'] = self.lam_kld
            self.shared_param_dicts['git_lam_cat'] = self.lam_cat
            self.shared_param_dicts['git_lam_pea'] = self.lam_pea
            self.shared_param_dicts['git_lam_ent'] = self.lam_ent
            self.shared_param_dicts['git_lam_dir'] = self.lam_dir
            self.shared_param_dicts['git_lam_l2w'] = self.lam_l2w
            self.shared_param_dicts['git_input_mask'] = self.input_mask
        else:
            # use some shared parameters that are shared among all clones of
            # some "base" GITrip (the input mask was already taken above)
            self.lr_gn = self.shared_param_dicts['git_lr_gn']
            self.lr_in = self.shared_param_dicts['git_lr_in']
            self.lr_pn = self.shared_param_dicts['git_lr_pn']
            self.mo_gn = self.shared_param_dicts['git_mo_gn']
            self.mo_in = self.shared_param_dicts['git_mo_in']
            self.mo_pn = self.shared_param_dicts['git_mo_pn']
            self.lam_nll = self.shared_param_dicts['git_lam_nll']
            self.lam_kld = self.shared_param_dicts['git_lam_kld']
            self.lam_cat = self.shared_param_dicts['git_lam_cat']
            self.lam_pea = self.shared_param_dicts['git_lam_pea']
            self.lam_ent = self.shared_param_dicts['git_lam_ent']
            self.lam_dir = self.shared_param_dicts['git_lam_dir']
            self.lam_l2w = self.shared_param_dicts['git_lam_l2w']

        # Grab the full set of "optimizable" parameters from the generator
        # and inferencer networks that we'll be working with.
        self.gn_params = [p for p in self.GN.mlp_params]
        self.in_params = [p for p in self.IN.mlp_params]
        self.pn_params = [p for p in self.PN.proto_params]

        ###################################
        # CONSTRUCT THE COSTS TO OPTIMIZE #
        ###################################
        self.data_nll_cost = self.lam_nll[0] * self._construct_data_nll_cost()
        self.post_kld_cost = self.lam_kld[0] * self._construct_post_kld_cost()
        self.post_cat_cost = self.lam_cat[0] * self._construct_post_cat_cost()
        self.post_pea_cost = self.lam_pea[0] * self._construct_post_pea_cost()
        self.post_ent_cost = self.lam_ent[0] * self._construct_post_ent_cost()
        self.post_dir_cost = self.lam_dir[0] * self._construct_post_dir_cost()
        self.other_reg_costs = self._construct_other_reg_cost()
        self.other_reg_cost = self.other_reg_costs[0]
        self.joint_cost = self.data_nll_cost + self.post_kld_cost + self.post_cat_cost + \
                self.post_pea_cost + self.post_ent_cost + self.post_dir_cost + \
                self.other_reg_cost

        # Initialize momentums for mini-batch SGD updates. All parameters need
        # to be safely nestled in their lists by now.
        self.joint_moms = OrderedDict()
        self.gn_moms = OrderedDict()
        self.in_moms = OrderedDict()
        self.pn_moms = OrderedDict()
        for net_params, net_moms in ((self.gn_params, self.gn_moms), \
                (self.in_params, self.in_moms), \
                (self.pn_params, self.pn_moms)):
            for p in net_params:
                p_mo = np.zeros(p.get_value(borrow=True).shape) + 5.0
                net_moms[p] = theano.shared(value=p_mo.astype(theano.config.floatX))
                self.joint_moms[p] = net_moms[p]

        # Now, construct updates for the generator, the continuous inferencer
        # and the categorical inferencer (in that order). Each net's updates
        # are also collected in self.joint_updates.
        self.joint_updates = OrderedDict()
        self.gn_updates = OrderedDict()
        self.in_updates = OrderedDict()
        self.pn_updates = OrderedDict()
        # NOTE: nothing is appended to grad_sq_sums at present, so
        # grad_sq_sum is a sum over an empty list.
        self.grad_sq_sums = []
        update_groups = (
            (self.gn_params, self.gn_moms, self.gn_updates, self.lr_gn, self.mo_gn),
            (self.in_params, self.in_moms, self.in_updates, self.lr_in, self.mo_in),
            (self.pn_params, self.pn_moms, self.pn_updates, self.lr_pn, self.mo_pn),
        )
        for net_params, net_moms, net_updates, lr, mo in update_groups:
            for var in net_params:
                # clipped gradient of the joint cost w.r.t. var, holding the
                # generator's dist_mean/dist_cov constant
                var_grad = T.grad(self.joint_cost, var, \
                        consider_constant=[self.GN.dist_mean, self.GN.dist_cov]).clip(-1.0,1.0)
                # running average of squared gradients (RMSprop-style)
                var_mom = net_moms[var]
                net_updates[var_mom] = (mo[0] * var_mom) + \
                        ((1.0 - mo[0]) * (var_grad**2.0))
                self.joint_updates[var_mom] = net_updates[var_mom]
                # step scaled by the root of the running average
                net_updates[var] = var - (lr[0] * (var_grad / T.sqrt(var_mom + 1e-2)))
                self.joint_updates[var] = net_updates[var]
        # Record the sum of squared gradients (for NaN checking)
        self.grad_sq_sum = T.sum(self.grad_sq_sums)

        # Construct the joint batch-based training function.
        self.train_joint = self._construct_train_joint()
        return
Ejemplo n.º 12
0
    def __init__(self, rng=None, \
            Xd=None, Yd=None, Xc=None, Xm=None, \
            g_net=None, i_net=None, p_net=None, \
            data_dim=None, prior_dim=None, label_dim=None, \
            batch_size=None, \
            params=None, shared_param_dicts=None):
        """
        Stack a continuous inferencer, a generator and a label inferencer,
        then build the joint cost and its SGD updates.

        The continuous inferencer (clone of i_net) reads Xd/Xc/Xm. Its output
        Xp feeds both the generator (clone of g_net) and the label inferencer
        (clone of p_net). Pass shared_param_dicts from an existing GIStack to
        make this instance a clone that reuses its learning-rate, momentum and
        cost-weight shared variables. `params` is not read by this constructor.
        """
        # private symbolic random stream, seeded from the caller's rng
        self.rng = RandStream(rng.randint(100000))
        # symbolic inputs to the computation graph
        self.Xd = Xd
        self.Yd = Yd
        self.Xc = Xc
        self.Xm = Xm
        # sizes of the data, label and latent spaces, and the batch size
        self.data_dim = data_dim
        self.label_dim = label_dim
        self.prior_dim = prior_dim
        self.batch_size = batch_size

        # continuous inferencer: data -> latent sample Xp
        self.IN = i_net.shared_param_clone(rng=rng, \
                Xd=self.Xd, Xc=self.Xc, Xm=self.Xm)
        self.Xp = self.IN.output
        # generator: latent sample -> reconstruction Xg
        self.GN = g_net.shared_param_clone(rng=rng, Xp=self.Xp)
        self.Xg = self.GN.output
        # label inferencer: latent sample -> class scores. Yp is built from the
        # first spawn-net's output; the proto-net's output is kept as Yp_proto.
        self.PN = p_net.shared_param_clone(rng=rng, Xd=self.Xp)
        self.Yp = safe_softmax(self.PN.output_spawn[0])
        self.Yp_proto = safe_softmax(self.PN.output_proto)

        # the pseudo-ensemble must have exactly one proto-net and one or two
        # spawn-nets
        assert len(self.PN.proto_nets) == 1
        assert len(self.PN.spawn_nets) in (1, 2)
        # the generator's output and the continuous inferencer's input both
        # live in data space
        assert self.data_dim == self.GN.mlp_layers[-1].out_dim
        assert self.data_dim == self.IN.shared_layers[0].in_dim
        # the latent space links IN's mu/sigma outputs to GN's and PN's inputs,
        # and PN emits label_dim scores
        assert self.prior_dim == self.IN.mu_layers[-1].out_dim
        assert self.prior_dim == self.IN.sigma_layers[-1].out_dim
        assert self.prior_dim == self.GN.mlp_layers[0].in_dim
        assert self.prior_dim == self.PN.proto_nets[0][0].in_dim
        assert self.label_dim == self.PN.proto_nets[0][-1].out_dim

        # an original owns a fresh dict of shared hyperparameters; a clone
        # borrows the dict handed to it
        self.is_clone = shared_param_dicts is not None
        self.shared_param_dicts = shared_param_dicts if self.is_clone else {}

        if self.is_clone:
            # pull every hyperparameter from the "base" GIStack's dict
            self.lr_gn = self.shared_param_dicts['gis_lr_gn']
            self.lr_in = self.shared_param_dicts['gis_lr_in']
            self.lr_pn = self.shared_param_dicts['gis_lr_pn']
            self.mo_gn = self.shared_param_dicts['gis_mo_gn']
            self.mo_in = self.shared_param_dicts['gis_mo_in']
            self.mo_pn = self.shared_param_dicts['gis_mo_pn']
            self.lam_nll = self.shared_param_dicts['gis_lam_nll']
            self.lam_kld = self.shared_param_dicts['gis_lam_kld']
            self.lam_cat = self.shared_param_dicts['gis_lam_cat']
            self.lam_pea = self.shared_param_dicts['gis_lam_pea']
            self.lam_ent = self.shared_param_dicts['gis_lam_ent']
            self.lam_l2w = self.shared_param_dicts['gis_lam_l2w']
        else:
            one_zero = np.zeros((1,)).astype(theano.config.floatX)
            # learning rates and momentum coefficients for GN / IN / PN
            self.lr_gn = theano.shared(value=one_zero, name='gis_lr_gn')
            self.lr_in = theano.shared(value=one_zero, name='gis_lr_in')
            self.lr_pn = theano.shared(value=one_zero, name='gis_lr_pn')
            self.mo_gn = theano.shared(value=one_zero, name='gis_mo_gn')
            self.mo_in = theano.shared(value=one_zero, name='gis_mo_in')
            self.mo_pn = theano.shared(value=one_zero, name='gis_mo_pn')
            self.set_all_sgd_params()
            # weights on the individual cost terms, each given its default
            self.lam_nll = theano.shared(value=one_zero, name='gis_lam_nll')
            self.set_lam_nll(lam_nll=1.0)
            self.lam_kld = theano.shared(value=one_zero, name='gis_lam_kld')
            self.set_lam_kld(lam_kld=1.0)
            self.lam_cat = theano.shared(value=one_zero, name='gis_lam_cat')
            self.set_lam_cat(lam_cat=0.0)
            self.lam_pea = theano.shared(value=one_zero, name='gis_lam_pea')
            self.set_lam_pea(lam_pea=0.0)
            self.lam_ent = theano.shared(value=one_zero, name='gis_lam_ent')
            self.set_lam_ent(lam_ent=0.0)
            self.lam_l2w = theano.shared(value=one_zero, name='gis_lam_l2w')
            self.set_lam_l2w(lam_l2w=1e-3)
            # publish them so clones can find them
            self.shared_param_dicts['gis_lr_gn'] = self.lr_gn
            self.shared_param_dicts['gis_lr_in'] = self.lr_in
            self.shared_param_dicts['gis_lr_pn'] = self.lr_pn
            self.shared_param_dicts['gis_mo_gn'] = self.mo_gn
            self.shared_param_dicts['gis_mo_in'] = self.mo_in
            self.shared_param_dicts['gis_mo_pn'] = self.mo_pn
            self.shared_param_dicts['gis_lam_nll'] = self.lam_nll
            self.shared_param_dicts['gis_lam_kld'] = self.lam_kld
            self.shared_param_dicts['gis_lam_cat'] = self.lam_cat
            self.shared_param_dicts['gis_lam_pea'] = self.lam_pea
            self.shared_param_dicts['gis_lam_ent'] = self.lam_ent
            self.shared_param_dicts['gis_lam_l2w'] = self.lam_l2w

        # trainable parameters of each network
        self.gn_params = list(self.GN.mlp_params)
        self.in_params = list(self.IN.mlp_params)
        self.pn_params = list(self.PN.proto_params)

        # weighted cost terms and their sum
        self.data_nll_cost = self.lam_nll[0] * self._construct_data_nll_cost()
        self.post_kld_cost = self.lam_kld[0] * self._construct_post_kld_cost()
        self.post_cat_cost = self.lam_cat[0] * self._construct_post_cat_cost()
        self.post_pea_cost = self.lam_pea[0] * self._construct_post_pea_cost()
        self.post_ent_cost = self.lam_ent[0] * self._construct_post_ent_cost()
        self.other_reg_cost = self._construct_other_reg_cost()
        self.joint_cost = self.data_nll_cost + self.post_kld_cost + self.post_cat_cost + \
                self.post_pea_cost + self.post_ent_cost + self.other_reg_cost

        # per-parameter accumulators, initialised to 2.0, grouped by network
        # (GN, IN, PN) and also collected in joint_moms
        self.joint_moms = OrderedDict()
        self.gn_moms = OrderedDict()
        self.in_moms = OrderedDict()
        self.pn_moms = OrderedDict()
        for plist, acc_dict in ((self.gn_params, self.gn_moms), \
                (self.in_params, self.in_moms), \
                (self.pn_params, self.pn_moms)):
            for prm in plist:
                init_val = np.zeros(prm.get_value(borrow=True).shape) + 2.0
                acc_dict[prm] = theano.shared(value=init_val.astype(theano.config.floatX))
                self.joint_moms[prm] = acc_dict[prm]

        # updates for every parameter and accumulator, processed for GN, then
        # IN, then PN; each net's updates are mirrored into joint_updates
        self.joint_updates = OrderedDict()
        self.gn_updates = OrderedDict()
        self.in_updates = OrderedDict()
        self.pn_updates = OrderedDict()
        groups = (
            (self.gn_params, self.gn_moms, self.gn_updates, self.lr_gn, self.mo_gn),
            (self.in_params, self.in_moms, self.in_updates, self.lr_in, self.mo_in),
            (self.pn_params, self.pn_moms, self.pn_updates, self.lr_pn, self.mo_pn),
        )
        for plist, acc_dict, upd_dict, step_size, decay in groups:
            for prm in plist:
                # clipped gradient, treating GN's dist_mean/dist_cov as fixed
                grad = T.grad(self.joint_cost, prm, \
                        consider_constant=[self.GN.dist_mean, self.GN.dist_cov]).clip(-1.0,1.0)
                acc = acc_dict[prm]
                # running average of squared gradients (RMSprop-style)
                upd_dict[acc] = (decay[0] * acc) + ((1.0 - decay[0]) * (grad**2.0))
                self.joint_updates[acc] = upd_dict[acc]
                # gradient step scaled by the root of that average
                upd_dict[prm] = prm - (step_size[0] * (grad / T.sqrt(acc + 1e-2)))
                self.joint_updates[prm] = upd_dict[prm]

        # only the joint training function is built here
        self.train_joint = self._construct_train_joint()
        return