Example #1
0
    def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None,
                 i_net=None, g_net=None, d_net=None, chain_len=None,
                 data_dim=None, prior_dim=None, params=None):
        """Build the chain of VAE steps, the discriminator, and all updates.

        The constructor unrolls ``chain_len`` inferencer/generator clones
        into a self-looped chain, attaches a discriminator to the stacked
        target/generated samples, builds the costs and their gradients, and
        finally compiles the joint training function.

        Args:
            rng: object exposing ``randint``; seeds this object's RandStream
                and is forwarded to every cloned network.
            Xd: symbolic var for inputting samples that initialize the chain.
            Xc: symbolic var for controlling subsets of the state variables.
            Xm: symbolic var for masking subsets of the state variables.
            Xt: symbolic var for samples from the target distribution.
            i_net: inferencer network, handed to OneStageModel as
                ``q_z_given_x``.
            g_net: generator network, handed to OneStageModel as
                ``p_x_given_z``.
            d_net: discriminator network; cloned via ``shared_param_clone``.
            chain_len: integer number of times to cycle the VAE loop.
            data_dim: forwarded to OneStageModel as ``x_dim``.
            prior_dim: forwarded to OneStageModel as ``z_dim``.
            params: dict of options. Optional keys: ``cost_decay`` (default
                0.1), ``chain_type`` ('walkback' or 'walkout', default
                'walkout'), ``xt_transform`` ('sigmoid' or 'none', default
                'sigmoid'), ``logvar_bound`` (default 10). Required keys:
                ``x_type`` ('bernoulli' or 'gaussian') and ``lam_l2d``.

        Raises:
            KeyError: if ``x_type`` or ``lam_l2d`` is missing from params.
            AssertionError: if an option holds an unsupported value.
        """
        self.rng = RandStream(rng.randint(100000))
        self.data_dim = data_dim
        self.prior_dim = prior_dim
        self.prior_mean = 0.0
        self.prior_logvar = 0.0

        # normalise the options dict once; everything below reads self.params
        self.params = {} if params is None else params
        self.cost_decay = self.params.get('cost_decay', 0.1)
        self.chain_type = self.params.get('chain_type', 'walkout')
        assert self.chain_type in ('walkback', 'walkout'), \
            "chain_type must be 'walkback' or 'walkout'"
        xt_kind = self.params.get('xt_transform', 'sigmoid')
        assert xt_kind in ('sigmoid', 'none'), \
            "xt_transform must be 'sigmoid' or 'none'"
        if xt_kind == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x
        self.logvar_bound = self.params.get('logvar_bound', 10)
        #
        # x_type: this tells if we're using bernoulli or gaussian model for
        #         the observations (required option, no default)
        #
        self.x_type = self.params['x_type']
        assert self.x_type in ('bernoulli', 'gaussian'), \
            "x_type must be 'bernoulli' or 'gaussian'"

        # symbolic var for inputting samples for initializing the VAE chain
        self.Xd = Xd
        # symbolic var for masking subsets of the state variables
        self.Xm = Xm
        # symbolic var for controlling subsets of the state variables
        self.Xc = Xc
        # symbolic var for inputting samples from the target distribution
        self.Xt = Xt
        # integer number of times to cycle the VAE loop
        self.chain_len = chain_len

        # symbolic vector of row indices for data inputs
        self.It = T.arange(self.Xt.shape[0])
        # symbolic vector of row indices for generated inputs; these follow
        # the Xt rows in the stacked discriminator input built below
        self.Id = T.arange(self.chain_len * self.Xd.shape[0]) + \
            self.Xt.shape[0]

        # get a clone of the desired VAE, for easy access
        self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm,
                                 p_x_given_z=g_net, q_z_given_x=i_net,
                                 x_dim=self.data_dim, z_dim=self.prior_dim,
                                 params=self.params)
        self.IN = self.OSM.q_z_given_x
        self.GN = self.OSM.p_x_given_z
        self.transform_x_to_z = self.OSM.transform_x_to_z
        self.transform_z_to_x = self.OSM.transform_z_to_x
        self.bounded_logvar = self.OSM.bounded_logvar

        # self-loop some clones of the main VAE into a chain.
        # ** All VAEs in the chain share the same Xc and Xm, which are the
        #    symbolic inputs for providing the observed portion of the input
        #    and a mask indicating which part of the input is "observed".
        #    These inputs are used for training "reconstruction" policies.
        self.IN_chain = []
        self.GN_chain = []
        self.Xg_chain = []
        _Xd = self.Xd
        print("Unrolling chain...")
        for i in range(self.chain_len):
            # create a VAE infer/generate pair with _Xd as input and with
            # masking variables shared by all VAEs in this chain
            _IN = self.IN.shared_param_clone(
                rng=rng,
                Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm),
                build_funcs=False)
            _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output,
                                             build_funcs=False)
            # the transformed generator mean feeds the next step of the chain
            _Xd = self.xt_transform(_GN.output_mean)
            self.IN_chain.append(_IN)
            self.GN_chain.append(_GN)
            self.Xg_chain.append(_Xd)
            print("    step {}...".format(i))

        # make a clone of the desired discriminator network, which will try
        # to discriminate between samples from the training data and samples
        # generated by the self-looped VAE chain.
        self.DN = d_net.shared_param_clone(
            rng=rng, Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

        zero_ary = np.zeros((1,)).astype(theano.config.floatX)

        def _zero_shared(name):
            # one-element floatX shared variable initialised to zero
            return theano.shared(value=zero_ary, name=name)

        # init shared var for weighting nll of data given posterior sample
        self.lam_chain_nll = _zero_shared('vcg_lam_chain_nll')
        self.set_lam_chain_nll(lam_chain_nll=1.0)
        # init shared var for weighting posterior KL-div from prior
        self.lam_chain_kld = _zero_shared('vcg_lam_chain_kld')
        self.set_lam_chain_kld(lam_chain_kld=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = _zero_shared('vcg_lam_l2w')
        self.set_lam_l2w(lam_l2w=1e-4)
        # shared var learning rates for all networks
        self.lr_dn = _zero_shared('vcg_lr_dn')
        self.lr_gn = _zero_shared('vcg_lr_gn')
        self.lr_in = _zero_shared('vcg_lr_in')
        # shared var momentum parameters for all networks
        self.mom_1 = _zero_shared('vcg_mom_1')
        self.mom_2 = _zero_shared('vcg_mom_2')
        self.it_count = _zero_shared('vcg_it_count')
        # shared var weights for adversarial classification objective
        self.dw_dn = _zero_shared('vcg_dw_dn')
        self.dw_gn = _zero_shared('vcg_dw_gn')
        # init parameters for controlling learning dynamics
        self.set_all_sgd_params()

        self.set_disc_weights()  # init adversarial cost weights for GN/DN
        # set a shared var for regularizing the output of the discriminator
        # (read from self.params so the None-normalisation above applies)
        self.lam_l2d = theano.shared(
            value=(zero_ary + self.params['lam_l2d']), name='vcg_lam_l2d')

        # Grab the full set of "optimizable" parameters from the generator
        # and discriminator networks that we'll be working with. We need to
        # ignore parameters in the final layers of the proto-networks in the
        # discriminator network (a generalized pseudo-ensemble). We ignore
        # them because this object requires that they be "bypassed" in favor
        # of some binary classification layers that it manages itself.
        self.dn_params = []
        for pn in self.DN.proto_nets:
            for pnl in pn[0:-1]:
                self.dn_params.extend(pnl.params)
        self.in_params = list(self.IN.mlp_params)
        self.in_params.append(self.OSM.output_logvar)
        self.gn_params = list(self.GN.mlp_params)
        self.joint_params = self.in_params + self.gn_params + self.dn_params

        # Now construct a binary discriminator layer for each proto-net in the
        # discriminator network. And, add their params to optimization list.
        self._construct_disc_layers(rng)
        self.disc_reg_cost = self.lam_l2d[0] * \
            T.sum([dl.act_l2_sum for dl in self.disc_layers])

        # Construct costs for the generator and discriminator networks based
        # on adversarial binary classification
        self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

        # first, build the cost to be optimized by the discriminator network,
        # in general this will be treated somewhat independently of the
        # optimization of the generator and inferencer networks.
        self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
            self.disc_reg_cost

        # construct costs relevant to the optimization of the generator and
        # inferencer networks
        self.chain_nll_cost = self.lam_chain_nll[0] * \
            self._construct_chain_nll_cost(cost_decay=self.cost_decay)
        self.chain_kld_cost = self.lam_chain_kld[0] * \
            self._construct_chain_kld_cost(cost_decay=self.cost_decay)
        self.other_reg_cost = self._construct_other_reg_cost()
        self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
            self.chain_kld_cost + self.other_reg_cost
        # compute total cost on the discriminator and VB generator/inferencer
        self.joint_cost = self.dn_cost + self.osm_cost

        # Get the gradients for all optimizable parameters: discriminator
        # params follow dn_cost, inferencer/generator params follow osm_cost.
        self.joint_grads = OrderedDict()
        grad_jobs = (('DN', self.dn_cost, self.dn_params),
                     ('IN', self.osm_cost, self.in_params),
                     ('GN', self.osm_cost, self.gn_params))
        for tag, cost, group in grad_jobs:
            print("Computing VCGLoop {} cost gradients...".format(tag))
            grad_list = T.grad(cost, group, disconnected_inputs='warn')
            self.joint_grads.update(zip(group, grad_list))

        # construct the updates for the discriminator, generator and
        # inferencer networks. all networks share the same first/second
        # moment momentum and iteration count. the networks each have their
        # own learning rates, which lets you turn their learning on/off.
        shared_update_args = dict(
            grads=self.joint_grads, beta1=self.mom_1, beta2=self.mom_2,
            it_count=self.it_count, mom2_init=1e-3, smoothing=1e-8,
            max_grad_norm=10.0)
        self.dn_updates = get_param_updates(
            params=self.dn_params, alpha=self.lr_dn, **shared_update_args)
        self.gn_updates = get_param_updates(
            params=self.gn_params, alpha=self.lr_gn, **shared_update_args)
        self.in_updates = get_param_updates(
            params=self.in_params, alpha=self.lr_in, **shared_update_args)

        # bag up all the updates required for training
        self.joint_updates = OrderedDict()
        for net_updates in (self.dn_updates, self.gn_updates,
                            self.in_updates):
            self.joint_updates.update(net_updates)
        # construct an update for tracking the mean KL divergence of
        # approximate posteriors for this chain (0.98/0.02 moving average
        # of the per-step mean KLd, averaged over the chain)
        new_kld_mean = (0.98 * self.IN.kld_mean) + \
            ((0.02 / self.chain_len) *
             sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
        self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

        # construct the function for training on training data
        print("Compiling VCGLoop theano functions....")
        self.train_joint = self._construct_train_joint()
Example #2
0
    def __init__(self, rng=None,
                 Xd=None, Xc=None, Xm=None,
                 p_x_given_z=None, q_z_given_x=None,
                 x_dim=None, z_dim=None,
                 params=None):
        """Wire an inferencer/generator pair into a single trainable model.

        Clones the given networks onto the masked input, builds the NLL, KLd
        and regularization costs, derives gradients and parameter updates,
        and compiles the helper functions exposed as attributes.

        Args:
            rng: object exposing ``randint``; seeds this object's RandStream
                and is forwarded to the cloned networks.
            Xd, Xc, Xm: symbolic inputs combined through ``apply_mask``.
            p_x_given_z: generator network to clone.
            q_z_given_x: inferencer network to clone.
            x_dim: recorded as ``self.x_dim``.
            z_dim: recorded as ``self.z_dim``.
            params: dict of options. ``x_type`` is required ('bernoulli' or
                'gaussian'); ``xt_transform`` ('sigmoid' or 'none') and
                ``logvar_bound`` are optional and default to sigmoid / 10.
        """
        # private random stream seeded from the caller's rng
        self.rng = RandStream(rng.randint(100000))

        # user-provided options, with an empty dict standing in for None
        self.params = params if params is not None else {}

        transform_kind = self.params.get('xt_transform', 'sigmoid')
        assert (transform_kind == 'sigmoid') or (transform_kind == 'none')
        if transform_kind == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x

        self.logvar_bound = self.params.get('logvar_bound', 10)

        # observation model: bernoulli or gaussian (no default is provided)
        self.x_type = self.params['x_type']
        assert self.x_type in ('bernoulli', 'gaussian')

        # dimensions of the observation and latent spaces
        self.x_dim = x_dim
        self.z_dim = z_dim

        # isotropic Gaussian prior over z
        self.prior_mean = 0.0
        self.prior_logvar = 0.0

        # symbolic inputs feeding the computation graph of this model
        self.Xd, self.Xc, self.Xm = Xd, Xc, Xm
        self.batch_reps = T.lscalar()
        self.x = apply_mask(self.Xd, self.Xc, self.Xm)

        # ----------------------------------------------------------------
        # Computation graph providing the values used in the objective
        # ----------------------------------------------------------------
        # inferencer: latent code given the (masked) input
        inferencer = q_z_given_x.shared_param_clone(rng=rng, Xd=self.x)
        self.q_z_given_x = inferencer
        self.z = inferencer.output
        self.z_mean = inferencer.output_mean
        self.z_logvar = inferencer.output_logvar
        # generator: observation given the sampled latent code
        generator = p_x_given_z.shared_param_clone(rng=rng, Xd=self.z)
        self.p_x_given_z = generator
        self.xt = generator.output_mean  # deterministic output is used
        # final generator output, conditioned on z
        self.xg = (T.nnet.sigmoid(self.xt) if self.x_type == 'bernoulli'
                   else self.xt_transform(self.xt))

        # the bias of the last sigma layer acts as the output log-variance
        self.output_logvar = generator.sigma_layers[-1].b
        scaled_logvar = self.output_logvar[0] / self.logvar_bound
        self.bounded_logvar = self.logvar_bound * T.tanh(scaled_logvar)

        # ----------------------------------------------------------------
        # Shared variables steering the optimization
        # ----------------------------------------------------------------
        zero_ary = np.zeros((1, )).astype(theano.config.floatX)

        def _zero_shared(name):
            # one-element floatX shared variable initialised to zero
            return theano.shared(value=zero_ary, name=name)

        # learning rate for generator and inferencer
        self.lr_1 = _zero_shared('osm_lr_1')
        # momentum parameters and iteration counter
        self.mom_1 = _zero_shared('osm_mom_1')
        self.mom_2 = _zero_shared('osm_mom_2')
        self.it_count = _zero_shared('osm_it_count')
        self.set_sgd_params()
        # weight on the nll of data given a posterior sample
        self.lam_nll = _zero_shared('osm_lam_nll')
        self.set_lam_nll(lam_nll=1.0)
        # weights on the two KLd terms
        self.lam_kld_1 = _zero_shared('osm_lam_kld_1')
        self.lam_kld_2 = _zero_shared('osm_lam_kld_2')
        self.set_lam_kld(lam_kld_1=1.0, lam_kld_2=0.0)
        # weight on l2 regularization of the parameters
        self.lam_l2w = _zero_shared('osm_lam_l2w')
        self.set_lam_l2w(1e-4)

        # "group 1" holds every optimizable parameter of both networks;
        # the joint list is the very same list object
        self.group_1_params = (list(self.q_z_given_x.mlp_params) +
                               list(self.p_x_given_z.mlp_params))
        self.joint_params = self.group_1_params

        # ----------------------------------------------------------------
        # Costs to optimize
        # ----------------------------------------------------------------
        self.nll_costs = self.lam_nll[0] * self._construct_nll_costs()
        self.nll_cost = T.mean(self.nll_costs)
        self.kld_costs_1, self.kld_costs_2 = self._construct_kld_costs()
        weighted_kld_1 = self.lam_kld_1[0] * self.kld_costs_1
        weighted_kld_2 = self.lam_kld_2[0] * self.kld_costs_2
        self.kld_costs = weighted_kld_1 + weighted_kld_2
        self.kld_cost = T.mean(self.kld_costs)
        # only the parameter term of the regularization pair enters the cost
        _, l2_param_cost = self._construct_reg_costs()
        self.reg_cost = self.lam_l2w[0] * l2_param_cost
        self.joint_cost = self.nll_cost + self.kld_cost + self.reg_cost

        # gradient of the joint cost for every optimizable parameter
        print("Computing OneStageModel cost gradients...")
        self.joint_grads = OrderedDict()
        for param, grad in zip(self.joint_params,
                               T.grad(self.joint_cost, self.joint_params)):
            self.joint_grads[param] = grad

        # updates for the generator and inferencer networks
        self.joint_updates = get_param_updates(
            params=self.joint_params, grads=self.joint_grads,
            alpha=self.lr_1, beta1=self.mom_1, beta2=self.mom_2,
            it_count=self.it_count, mom2_init=1e-3, smoothing=1e-8,
            max_grad_norm=10.0)

        # compile the training function and the auxiliary helpers
        print("Compiling OneStageModel theano functions...")
        self.train_joint = self._construct_train_joint()
        self.compute_fe_terms = self._construct_compute_fe_terms()
        self.compute_post_klds = self._construct_compute_post_klds()
        self.sample_from_prior = self._construct_sample_from_prior()
        self.transform_x_to_z = theano.function(
            [self.q_z_given_x.Xd], outputs=self.q_z_given_x.output_mean)
        self.transform_z_to_x = theano.function(
            [self.p_x_given_z.Xd],
            outputs=self.xt_transform(self.p_x_given_z.output_mean))
        # handles on the first inferencer and last generator weight matrices
        self.inf_weights = self.q_z_given_x.shared_layers[0].W
        self.gen_weights = self.p_x_given_z.mu_layers[-1].W
Example #3
0
    def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None,
                 i_net=None, g_net=None, d_net=None, chain_len=None,
                 data_dim=None, prior_dim=None, params=None):
        """Build the chain of VAE steps, the discriminator, and all updates.

        The constructor unrolls ``chain_len`` inferencer/generator clones
        into a self-looped chain, attaches a discriminator to the stacked
        target/generated samples, builds the costs and their gradients, and
        finally compiles the joint training function.

        Args:
            rng: object exposing ``randint``; seeds this object's RandStream
                and is forwarded to every cloned network.
            Xd: symbolic var for inputting samples that initialize the chain.
            Xc: symbolic var for controlling subsets of the state variables.
            Xm: symbolic var for masking subsets of the state variables.
            Xt: symbolic var for samples from the target distribution.
            i_net: inferencer network, handed to OneStageModel as
                ``q_z_given_x``.
            g_net: generator network, handed to OneStageModel as
                ``p_x_given_z``.
            d_net: discriminator network; cloned via ``shared_param_clone``.
            chain_len: integer number of times to cycle the VAE loop.
            data_dim: forwarded to OneStageModel as ``x_dim``.
            prior_dim: forwarded to OneStageModel as ``z_dim``.
            params: dict of options. Optional keys: ``cost_decay`` (default
                0.1), ``chain_type`` ('walkback' or 'walkout', default
                'walkout'), ``xt_transform`` ('sigmoid' or 'none', default
                'sigmoid'), ``logvar_bound`` (default 10). Required keys:
                ``x_type`` ('bernoulli' or 'gaussian') and ``lam_l2d``.

        Raises:
            KeyError: if ``x_type`` or ``lam_l2d`` is missing from params.
            AssertionError: if an option holds an unsupported value.
        """
        self.rng = RandStream(rng.randint(100000))
        self.data_dim = data_dim
        self.prior_dim = prior_dim
        self.prior_mean = 0.0
        self.prior_logvar = 0.0

        # normalise the options dict once; everything below reads self.params
        self.params = {} if params is None else params
        self.cost_decay = self.params.get('cost_decay', 0.1)
        self.chain_type = self.params.get('chain_type', 'walkout')
        assert self.chain_type in ('walkback', 'walkout'), \
            "chain_type must be 'walkback' or 'walkout'"
        xt_kind = self.params.get('xt_transform', 'sigmoid')
        assert xt_kind in ('sigmoid', 'none'), \
            "xt_transform must be 'sigmoid' or 'none'"
        if xt_kind == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x
        self.logvar_bound = self.params.get('logvar_bound', 10)
        #
        # x_type: this tells if we're using bernoulli or gaussian model for
        #         the observations (required option, no default)
        #
        self.x_type = self.params['x_type']
        assert self.x_type in ('bernoulli', 'gaussian'), \
            "x_type must be 'bernoulli' or 'gaussian'"

        # symbolic var for inputting samples for initializing the VAE chain
        self.Xd = Xd
        # symbolic var for masking subsets of the state variables
        self.Xm = Xm
        # symbolic var for controlling subsets of the state variables
        self.Xc = Xc
        # symbolic var for inputting samples from the target distribution
        self.Xt = Xt
        # integer number of times to cycle the VAE loop
        self.chain_len = chain_len

        # symbolic vector of row indices for data inputs
        self.It = T.arange(self.Xt.shape[0])
        # symbolic vector of row indices for generated inputs; these follow
        # the Xt rows in the stacked discriminator input built below
        self.Id = T.arange(self.chain_len * self.Xd.shape[0]) + \
            self.Xt.shape[0]

        # get a clone of the desired VAE, for easy access
        self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm,
                                 p_x_given_z=g_net, q_z_given_x=i_net,
                                 x_dim=self.data_dim, z_dim=self.prior_dim,
                                 params=self.params)
        self.IN = self.OSM.q_z_given_x
        self.GN = self.OSM.p_x_given_z
        self.transform_x_to_z = self.OSM.transform_x_to_z
        self.transform_z_to_x = self.OSM.transform_z_to_x
        self.bounded_logvar = self.OSM.bounded_logvar

        # self-loop some clones of the main VAE into a chain.
        # ** All VAEs in the chain share the same Xc and Xm, which are the
        #    symbolic inputs for providing the observed portion of the input
        #    and a mask indicating which part of the input is "observed".
        #    These inputs are used for training "reconstruction" policies.
        self.IN_chain = []
        self.GN_chain = []
        self.Xg_chain = []
        _Xd = self.Xd
        print("Unrolling chain...")
        for i in range(self.chain_len):
            # create a VAE infer/generate pair with _Xd as input and with
            # masking variables shared by all VAEs in this chain
            _IN = self.IN.shared_param_clone(
                rng=rng,
                Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm),
                build_funcs=False)
            _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output,
                                             build_funcs=False)
            # the transformed generator mean feeds the next step of the chain
            _Xd = self.xt_transform(_GN.output_mean)
            self.IN_chain.append(_IN)
            self.GN_chain.append(_GN)
            self.Xg_chain.append(_Xd)
            print("    step {}...".format(i))

        # make a clone of the desired discriminator network, which will try
        # to discriminate between samples from the training data and samples
        # generated by the self-looped VAE chain.
        self.DN = d_net.shared_param_clone(
            rng=rng, Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

        zero_ary = np.zeros((1, )).astype(theano.config.floatX)

        def _zero_shared(name):
            # one-element floatX shared variable initialised to zero
            return theano.shared(value=zero_ary, name=name)

        # init shared var for weighting nll of data given posterior sample
        self.lam_chain_nll = _zero_shared('vcg_lam_chain_nll')
        self.set_lam_chain_nll(lam_chain_nll=1.0)
        # init shared var for weighting posterior KL-div from prior
        self.lam_chain_kld = _zero_shared('vcg_lam_chain_kld')
        self.set_lam_chain_kld(lam_chain_kld=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = _zero_shared('vcg_lam_l2w')
        self.set_lam_l2w(lam_l2w=1e-4)
        # shared var learning rates for all networks
        self.lr_dn = _zero_shared('vcg_lr_dn')
        self.lr_gn = _zero_shared('vcg_lr_gn')
        self.lr_in = _zero_shared('vcg_lr_in')
        # shared var momentum parameters for all networks
        self.mom_1 = _zero_shared('vcg_mom_1')
        self.mom_2 = _zero_shared('vcg_mom_2')
        self.it_count = _zero_shared('vcg_it_count')
        # shared var weights for adversarial classification objective
        self.dw_dn = _zero_shared('vcg_dw_dn')
        self.dw_gn = _zero_shared('vcg_dw_gn')
        # init parameters for controlling learning dynamics
        self.set_all_sgd_params()

        self.set_disc_weights()  # init adversarial cost weights for GN/DN
        # set a shared var for regularizing the output of the discriminator
        # (read from self.params so the None-normalisation above applies)
        self.lam_l2d = theano.shared(
            value=(zero_ary + self.params['lam_l2d']), name='vcg_lam_l2d')

        # Grab the full set of "optimizable" parameters from the generator
        # and discriminator networks that we'll be working with. We need to
        # ignore parameters in the final layers of the proto-networks in the
        # discriminator network (a generalized pseudo-ensemble). We ignore
        # them because this object requires that they be "bypassed" in favor
        # of some binary classification layers that it manages itself.
        self.dn_params = []
        for pn in self.DN.proto_nets:
            for pnl in pn[0:-1]:
                self.dn_params.extend(pnl.params)
        self.in_params = list(self.IN.mlp_params)
        self.in_params.append(self.OSM.output_logvar)
        self.gn_params = list(self.GN.mlp_params)
        self.joint_params = self.in_params + self.gn_params + self.dn_params

        # Now construct a binary discriminator layer for each proto-net in the
        # discriminator network. And, add their params to optimization list.
        self._construct_disc_layers(rng)
        self.disc_reg_cost = self.lam_l2d[0] * \
            T.sum([dl.act_l2_sum for dl in self.disc_layers])

        # Construct costs for the generator and discriminator networks based
        # on adversarial binary classification
        self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

        # first, build the cost to be optimized by the discriminator network,
        # in general this will be treated somewhat independently of the
        # optimization of the generator and inferencer networks.
        self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
            self.disc_reg_cost

        # construct costs relevant to the optimization of the generator and
        # inferencer networks
        self.chain_nll_cost = self.lam_chain_nll[0] * \
            self._construct_chain_nll_cost(cost_decay=self.cost_decay)
        self.chain_kld_cost = self.lam_chain_kld[0] * \
            self._construct_chain_kld_cost(cost_decay=self.cost_decay)
        self.other_reg_cost = self._construct_other_reg_cost()
        self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
            self.chain_kld_cost + self.other_reg_cost
        # compute total cost on the discriminator and VB generator/inferencer
        self.joint_cost = self.dn_cost + self.osm_cost

        # Get the gradients for all optimizable parameters: discriminator
        # params follow dn_cost, inferencer/generator params follow osm_cost.
        self.joint_grads = OrderedDict()
        grad_jobs = (('DN', self.dn_cost, self.dn_params),
                     ('IN', self.osm_cost, self.in_params),
                     ('GN', self.osm_cost, self.gn_params))
        for tag, cost, group in grad_jobs:
            print("Computing VCGLoop {} cost gradients...".format(tag))
            grad_list = T.grad(cost, group, disconnected_inputs='warn')
            self.joint_grads.update(zip(group, grad_list))

        # construct the updates for the discriminator, generator and
        # inferencer networks. all networks share the same first/second
        # moment momentum and iteration count. the networks each have their
        # own learning rates, which lets you turn their learning on/off.
        shared_update_args = dict(
            grads=self.joint_grads, beta1=self.mom_1, beta2=self.mom_2,
            it_count=self.it_count, mom2_init=1e-3, smoothing=1e-8,
            max_grad_norm=10.0)
        self.dn_updates = get_param_updates(
            params=self.dn_params, alpha=self.lr_dn, **shared_update_args)
        self.gn_updates = get_param_updates(
            params=self.gn_params, alpha=self.lr_gn, **shared_update_args)
        self.in_updates = get_param_updates(
            params=self.in_params, alpha=self.lr_in, **shared_update_args)

        # bag up all the updates required for training
        self.joint_updates = OrderedDict()
        for net_updates in (self.dn_updates, self.gn_updates,
                            self.in_updates):
            self.joint_updates.update(net_updates)
        # construct an update for tracking the mean KL divergence of
        # approximate posteriors for this chain (0.98/0.02 moving average
        # of the per-step mean KLd, averaged over the chain)
        new_kld_mean = (0.98 * self.IN.kld_mean) + \
            ((0.02 / self.chain_len) *
             sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
        self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

        # construct the function for training on training data
        print("Compiling VCGLoop theano functions....")
        self.train_joint = self._construct_train_joint()
    def __init__(self, rng=None,
                 Xd=None, Xc=None, Xm=None,
                 p_x_given_z=None, q_z_given_x=None,
                 x_dim=None, z_dim=None,
                 params=None):
        """
        Assemble a one-stage model from an inferencer and a generator.

        The constructor wires the two networks into one symbolic graph,
        builds the joint cost (nll + kld + l2 regularization), derives the
        parameter updates and compiles the theano functions.

        Parameters:
            rng: random source; must provide randint(). It seeds this
                 model's RandStream and is also handed to the
                 shared_param_clone() call of both networks.
            Xd, Xc, Xm: symbolic inputs, combined via apply_mask() to form
                        the model input self.x.
            p_x_given_z: generator network; cloned with Xd=self.z.
            q_z_given_x: inferencer network; cloned with Xd=self.x.
            x_dim, z_dim: stored as self.x_dim and self.z_dim.
            params: optional dict of settings. 'x_type' is looked up
                    unconditionally and must be 'bernoulli' or 'gaussian'.
                    'xt_transform' ('sigmoid' or 'none', default sigmoid)
                    and 'logvar_bound' (default 10) are optional.
        """
        # private random stream for this model, seeded from the given rng
        self.rng = RandStream(rng.randint(100000))

        # user-provided settings (an empty dict when none were given)
        self.params = {} if params is None else params

        # transform applied to the generator's raw output; sigmoid by default
        xt_name = 'sigmoid'
        if 'xt_transform' in self.params:
            xt_name = self.params['xt_transform']
            assert(xt_name in ('sigmoid', 'none'))
        if xt_name == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x

        # bound used when squashing the output log-variance through tanh
        self.logvar_bound = self.params['logvar_bound'] \
                if 'logvar_bound' in self.params else 10

        # x_type selects a bernoulli or gaussian model for the observations
        self.x_type = self.params['x_type']
        assert(self.x_type in ('bernoulli', 'gaussian'))

        # dimensions of the latent and observation spaces
        self.z_dim, self.x_dim = z_dim, x_dim

        # parameters of the isotropic Gaussian prior over z
        self.prior_mean = 0.0
        self.prior_logvar = 0.0

        # symbolic inputs feeding the computation graph of this model
        self.Xd, self.Xc, self.Xm = Xd, Xc, Xm
        self.batch_reps = T.lscalar()
        self.x = apply_mask(self.Xd, self.Xc, self.Xm)

        #####################################################################
        # Setup the computation graph that provides values in our objective #
        #####################################################################
        # inferencer: latent variables given (masked) inputs
        inferencer = q_z_given_x.shared_param_clone(rng=rng, Xd=self.x)
        self.q_z_given_x = inferencer
        self.z = inferencer.output
        self.z_mean = inferencer.output_mean
        self.z_logvar = inferencer.output_logvar
        # generator: observations given latent variables
        generator = p_x_given_z.shared_param_clone(rng=rng, Xd=self.z)
        self.p_x_given_z = generator
        self.xt = generator.output_mean  # use deterministic output
        # final generator output: bernoulli always goes through a sigmoid,
        # otherwise the configured transform is applied
        self.xg = T.nnet.sigmoid(self.xt) if self.x_type == 'bernoulli' \
                else self.xt_transform(self.xt)

        # self.output_logvar modifies the output distribution; it is read
        # from the bias of the generator's last sigma layer and squashed
        # into (-logvar_bound, logvar_bound) with tanh
        self.output_logvar = generator.sigma_layers[-1].b
        scaled_logvar = self.output_logvar[0] / self.logvar_bound
        self.bounded_logvar = self.logvar_bound * T.tanh(scaled_logvar)

        ######################################################################
        # ALL SYMBOLIC VARS NEEDED FOR THE OBJECTIVE SHOULD NOW BE AVAILABLE #
        ######################################################################

        # every hyperparameter lives in a one-element shared variable that
        # starts at zero and is then given its value by a setter
        zero_ary = np.zeros((1,)).astype(theano.config.floatX)

        def _zero_shared(var_name):
            return theano.shared(value=zero_ary, name=var_name)

        # learning rate for generator and inferencer
        self.lr_1 = _zero_shared('osm_lr_1')
        # momentum parameters and iteration counter
        self.mom_1 = _zero_shared('osm_mom_1')
        self.mom_2 = _zero_shared('osm_mom_2')
        self.it_count = _zero_shared('osm_it_count')
        # init parameters for controlling learning dynamics
        self.set_sgd_params()
        # weight on the nll of data given a posterior sample
        self.lam_nll = _zero_shared('osm_lam_nll')
        self.set_lam_nll(lam_nll=1.0)
        # weights on the two kld terms
        self.lam_kld_1 = _zero_shared('osm_lam_kld_1')
        self.lam_kld_2 = _zero_shared('osm_lam_kld_2')
        self.set_lam_kld(lam_kld_1=1.0, lam_kld_2=0.0)
        # weight on the l2 regularization of the parameters
        self.lam_l2w = _zero_shared('osm_lam_l2w')
        self.set_lam_l2w(1e-4)

        # "group 1" holds every optimizable parameter of both networks
        self.group_1_params = list(self.q_z_given_x.mlp_params) + \
                list(self.p_x_given_z.mlp_params)
        # the joint list is the very same list object
        self.joint_params = self.group_1_params

        ###################################
        # CONSTRUCT THE COSTS TO OPTIMIZE #
        ###################################
        self.nll_costs = self.lam_nll[0] * self._construct_nll_costs()
        self.nll_cost = T.mean(self.nll_costs)
        self.kld_costs_1, self.kld_costs_2 = self._construct_kld_costs()
        weighted_kld_1 = self.lam_kld_1[0] * self.kld_costs_1
        weighted_kld_2 = self.lam_kld_2[0] * self.kld_costs_2
        self.kld_costs = weighted_kld_1 + weighted_kld_2
        self.kld_cost = T.mean(self.kld_costs)
        # only the parameter regularizer enters the joint cost
        _, param_reg_cost = self._construct_reg_costs()
        self.reg_cost = self.lam_l2w[0] * param_reg_cost
        self.joint_cost = self.nll_cost + self.kld_cost + self.reg_cost

        # gradient of the joint cost for every optimizable parameter
        print("Computing OneStageModel cost gradients...")
        self.joint_grads = OrderedDict()
        grad_list = T.grad(self.joint_cost, self.joint_params)
        for param, grad in zip(self.joint_params, grad_list):
            self.joint_grads[param] = grad

        # updates for the generator and inferencer networks
        self.joint_updates = get_param_updates(
            params=self.joint_params,
            grads=self.joint_grads,
            alpha=self.lr_1,
            beta1=self.mom_1,
            beta2=self.mom_2,
            it_count=self.it_count,
            mom2_init=1e-3,
            smoothing=1e-8,
            max_grad_norm=10.0)

        # compile the functions used for training and evaluation
        print("Compiling OneStageModel theano functions...")
        self.train_joint = self._construct_train_joint()
        self.compute_fe_terms = self._construct_compute_fe_terms()
        self.compute_post_klds = self._construct_compute_post_klds()
        self.sample_from_prior = self._construct_sample_from_prior()
        self.transform_x_to_z = theano.function(
            [self.q_z_given_x.Xd],
            outputs=self.q_z_given_x.output_mean)
        self.transform_z_to_x = theano.function(
            [self.p_x_given_z.Xd],
            outputs=self.xt_transform(self.p_x_given_z.output_mean))
        # easy access points for some interesting parameters
        self.inf_weights = self.q_z_given_x.shared_layers[0].W
        self.gen_weights = self.p_x_given_z.mu_layers[-1].W
Example #5
0
    def __init__(self, rng=None, x_in=None, \
            p_s0_obs_given_z_obs=None, p_hi_given_si=None, p_sip1_given_si_hi=None, \
            p_x_given_si_hi=None, q_z_given_x=None, q_hi_given_x_si=None, \
            obs_dim=None, z_rnn_dim=None, z_obs_dim=None, h_dim=None, \
            model_init_obs=True, model_init_rnn=True, ir_steps=2, \
            params=None):
        """
        Build the graph, costs, updates and compiled functions of a
        multi-stage model with iterative refinement.

        An initial state s0 is produced from z, then refined for ir_steps
        steps. Each step samples hi (from the variational or the prior
        network, selected by self.train_switch) and adds the output of
        p_sip1_given_si_hi to the "obs" part of the state; the "rnn" part
        is carried over unchanged.

        Parameters:
            rng: random source; must provide randint(). It seeds this
                 model's RandStream and is passed to every
                 shared_param_clone() call.
            x_in: symbolic input, stored as self.x.
            p_s0_obs_given_z_obs: network giving the initial obs state from
                                  the "obs" part of z.
            p_hi_given_si: prior network for hi given the current state.
            p_sip1_given_si_hi: network giving the obs-state update from hi
                                and the rnn part of the state.
            p_x_given_si_hi: must be None (working with "latent" si is not
                             implemented yet).
            q_z_given_x: inferencer network for z given x.
            q_hi_given_x_si: variational network for hi given x and state.
            obs_dim: width of the obs part of each state si.
            z_rnn_dim, z_obs_dim: widths of the rnn and obs parts of z.
            h_dim: stored as self.h_dim.
            model_init_obs: if true, the initial obs state is the output
                            mean of p_s0_obs_given_z_obs; otherwise it is
                            the bias of that network's last mu layer.
            model_init_rnn: if true, the initial rnn state is the rnn part
                            of z; otherwise it is taken from the bias of
                            the last mu layer of q_z_given_x.
            ir_steps: number of refinement steps to build.
            params: dict of settings; it is indexed directly, so it must
                    not be None. 'x_type' is required ('bernoulli' or
                    'gaussian'); 'obs_transform' ('sigmoid' or 'none') is
                    optional and is overridden by sigmoid for bernoulli.
        """
        # setup a rng for this model, seeded from the given rng
        self.rng = RandStream(rng.randint(100000))

        # TODO: implement functionality for working with "latent" si
        assert (p_x_given_si_hi is None)

        # decide whether to initialize from a model or from a "constant"
        self.model_init_obs = model_init_obs
        self.model_init_rnn = model_init_rnn

        # grab the user-provided parameters
        self.params = params
        self.x_type = self.params['x_type']
        assert ((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))
        if 'obs_transform' in self.params:
            assert((self.params['obs_transform'] == 'sigmoid') or \
                    (self.params['obs_transform'] == 'none'))
            if self.params['obs_transform'] == 'sigmoid':
                self.obs_transform = lambda x: T.nnet.sigmoid(x)
            else:
                self.obs_transform = lambda x: x
        else:
            self.obs_transform = lambda x: T.nnet.sigmoid(x)
        # bernoulli observations always use a sigmoid, whatever was asked
        if self.x_type == 'bernoulli':
            self.obs_transform = lambda x: T.nnet.sigmoid(x)

        # record the dimensions of various spaces relevant to this model
        self.obs_dim = obs_dim
        self.rnn_dim = z_rnn_dim
        self.z_dim = z_rnn_dim + z_obs_dim
        self.z_rnn_dim = z_rnn_dim
        self.z_obs_dim = z_obs_dim
        self.h_dim = h_dim
        self.ir_steps = ir_steps

        # record the symbolic variables that will provide inputs to the
        # computation graph created to describe this MultiStageModel
        self.x = x_in
        self.batch_reps = T.lscalar()

        # one-element zero array used as the initial value of every shared
        # hyperparameter created below
        zero_ary = np.zeros((1, )).astype(theano.config.floatX)

        # setup switching variable for changing between sampling/training
        self.train_switch = theano.shared(value=zero_ary,
                                          name='msm_train_switch')
        self.set_train_switch(1.0)
        # setup a weight for pulling priors over hi given si towards a
        # shared global prior -- e.g. zero mean and unit variance.
        self.kzg_weight = theano.shared(value=zero_ary, name='msm_kzg_weight')
        self.set_kzg_weight(0.1)
        # this weight balances l1 vs. l2 penalty on posterior KLds
        self.l1l2_weight = theano.shared(value=zero_ary,
                                         name='msm_l1l2_weight')
        self.set_l1l2_weight(1.0)

        #############################
        # Setup self.z and self.s0. #
        #############################
        print("Building MSM step 0...")
        # a scale of 1.0 takes the initial state from the generative model,
        # a scale of 0.0 takes it from a learned bias instead
        obs_scale = 1.0 if self.model_init_obs else 0.0
        rnn_scale = 1.0 if self.model_init_rnn else 0.0
        self.q_z_given_x = q_z_given_x.shared_param_clone(rng=rng, Xd=self.x)
        self.z = self.q_z_given_x.output
        self.z_rnn = self.z[:, :self.z_rnn_dim]
        self.z_obs = self.z[:, self.z_rnn_dim:]
        self.p_s0_obs_given_z_obs = p_s0_obs_given_z_obs.shared_param_clone( \
                rng=rng, Xd=self.z_obs)
        _s0_obs_model = self.p_s0_obs_given_z_obs.output_mean
        _s0_obs_const = self.p_s0_obs_given_z_obs.mu_layers[-1].b
        self.s0_obs = (obs_scale * _s0_obs_model) + \
                ((1.0 - obs_scale) * _s0_obs_const)
        _s0_rnn_model = self.z_rnn
        _s0_rnn_const = self.q_z_given_x.mu_layers[-1].b[:self.z_rnn_dim]
        self.s0_rnn = (rnn_scale * _s0_rnn_model) + \
                ((1.0 - rnn_scale) * _s0_rnn_const)
        self.s0_jnt = T.horizontal_stack(self.s0_obs, self.s0_rnn)
        # output log-variance, squashed into (-8, 8) with tanh
        self.output_logvar = self.p_s0_obs_given_z_obs.sigma_layers[-1].b
        self.bounded_logvar = 8.0 * T.tanh((1.0 / 8.0) * self.output_logvar)

        ###############################################################
        # Setup the iterative refinement loop, starting from self.s0. #
        ###############################################################
        self.p_hi_given_si = []  # holds p_hi_given_si for each i
        self.p_sip1_given_si_hi = []  # holds p_sip1_given_si_hi for each i
        self.q_hi_given_x_si = []  # holds q_hi_given_x_si for each i
        self.si = [self.s0_jnt]  # holds si for each i
        self.hi = []  # holds hi for each i
        for i in range(self.ir_steps):
            print("Building MSM step {0:d}...".format(i + 1))
            _si = self.si[i]
            si_obs = _si[:, :self.obs_dim]
            si_rnn = _si[:, self.obs_dim:]
            # transformed obs state, built once and reused below
            si_obs_tr = self.obs_transform(si_obs)
            # get samples of next hi, conditioned on current si
            self.p_hi_given_si.append( \
                    p_hi_given_si.shared_param_clone(rng=rng, \
                    Xd=T.horizontal_stack(si_obs_tr, si_rnn)))
            hi_p = self.p_hi_given_si[i].output
            # now we build the model for variational hi given si; it also
            # sees the residual between the input and the current guess
            grad_ll = self.x - si_obs_tr
            self.q_hi_given_x_si.append(\
                    q_hi_given_x_si.shared_param_clone(rng=rng, \
                    Xd=T.horizontal_stack(grad_ll, si_obs_tr, si_rnn)))
            hi_q = self.q_hi_given_x_si[i].output
            # make hi samples that can be switched between hi_p and hi_q
            self.hi.append( ((self.train_switch[0] * hi_q) + \
                    ((1.0 - self.train_switch[0]) * hi_p)) )
            # p_sip1_given_si_hi is conditioned on hi and the "rnn" part of si.
            self.p_sip1_given_si_hi.append( \
                    p_sip1_given_si_hi.shared_param_clone(rng=rng, \
                    Xd=T.horizontal_stack(self.hi[i], si_rnn)))
            # construct the update from si_obs/si_rnn to sip1_obs/sip1_rnn
            sip1_obs = si_obs + self.p_sip1_given_si_hi[i].output_mean
            sip1_rnn = si_rnn
            sip1_jnt = T.horizontal_stack(sip1_obs, sip1_rnn)
            # record the updated state of the generative process
            self.si.append(sip1_jnt)
        # check that input/output dimensions of our models agree
        self._check_model_shapes()

        ######################################################################
        # ALL SYMBOLIC VARS NEEDED FOR THE OBJECTIVE SHOULD NOW BE AVAILABLE #
        ######################################################################

        # shared var learning rate for generator and inferencer
        self.lr_1 = theano.shared(value=zero_ary, name='msm_lr_1')
        self.lr_2 = theano.shared(value=zero_ary, name='msm_lr_2')
        # shared var momentum parameters for generator and inferencer
        self.mom_1 = theano.shared(value=zero_ary, name='msm_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='msm_mom_2')
        self.it_count = theano.shared(value=zero_ary, name='msm_it_count')
        # init parameters for controlling learning dynamics
        self.set_sgd_params()
        # init shared var for weighting nll of data given posterior sample
        self.lam_nll = theano.shared(value=zero_ary, name='msm_lam_nll')
        self.set_lam_nll(lam_nll=1.0)
        # init shared var for weighting prior kld against reconstruction
        self.lam_kld_1 = theano.shared(value=zero_ary, name='msm_lam_kld_1')
        self.lam_kld_2 = theano.shared(value=zero_ary, name='msm_lam_kld_2')
        self.set_lam_kld(lam_kld_1=1.0, lam_kld_2=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='msm_lam_l2w')
        self.set_lam_l2w(1e-5)

        # Grab all of the "optimizable" parameters in "group 1"
        self.group_1_params = []
        self.group_1_params.extend(self.q_z_given_x.mlp_params)
        self.group_1_params.extend(self.p_s0_obs_given_z_obs.mlp_params)
        # Grab all of the "optimizable" parameters in "group 2"
        self.group_2_params = []
        for i in range(self.ir_steps):
            self.group_2_params.extend(self.q_hi_given_x_si[i].mlp_params)
            self.group_2_params.extend(self.p_hi_given_si[i].mlp_params)
            self.group_2_params.extend(self.p_sip1_given_si_hi[i].mlp_params)
        # Make a joint list of parameters group 1/2
        self.joint_params = self.group_1_params + self.group_2_params

        #################################
        # CONSTRUCT THE KLD-BASED COSTS #
        #################################
        self.kld_z, self.kld_hi_cond, self.kld_hi_glob = \
                self._construct_kld_costs()
        self.kld_cost = (self.lam_kld_1[0] * T.mean(self.kld_z)) + \
                (self.lam_kld_2[0] * (T.mean(self.kld_hi_cond) + \
                (self.kzg_weight[0] * T.mean(self.kld_hi_glob))))
        #################################
        # CONSTRUCT THE NLL-BASED COSTS #
        #################################
        self.nll_costs = self._construct_nll_costs()
        self.nll_cost = self.lam_nll[0] * T.mean(self.nll_costs)
        ########################################
        # CONSTRUCT THE REST OF THE JOINT COST #
        ########################################
        param_reg_cost = self._construct_reg_costs()
        self.reg_cost = self.lam_l2w[0] * param_reg_cost
        self.joint_cost = self.nll_cost + self.kld_cost + self.reg_cost

        # Get the gradient of the joint cost for all optimizable parameters.
        # A single T.grad call over the whole list builds the backward graph
        # once, instead of once per parameter.
        grad_list = T.grad(self.joint_cost, self.joint_params)
        self.joint_grads = OrderedDict()
        for p, g in zip(self.joint_params, grad_list):
            self.joint_grads[p] = g

        # Construct the updates for the generator and inferencer networks;
        # the two groups differ only in their learning rate
        self.group_1_updates = get_param_updates(params=self.group_1_params, \
                grads=self.joint_grads, alpha=self.lr_1, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        self.group_2_updates = get_param_updates(params=self.group_2_params, \
                grads=self.joint_grads, alpha=self.lr_2, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        # merge both groups; group 2 entries win on any shared key
        self.joint_updates = OrderedDict()
        for k in self.group_1_updates:
            self.joint_updates[k] = self.group_1_updates[k]
        for k in self.group_2_updates:
            self.joint_updates[k] = self.group_2_updates[k]

        # Construct a function for jointly training the generator/inferencer
        print("Compiling training function...")
        self.train_joint = self._construct_train_joint()
        self.compute_post_klds = self._construct_compute_post_klds()
        self.compute_fe_terms = self._construct_compute_fe_terms()
        self.sample_from_prior = self._construct_sample_from_prior()
        # make easy access points for some interesting parameters
        self.inf_1_weights = self.q_z_given_x.shared_layers[0].W
        self.gen_1_weights = self.p_s0_obs_given_z_obs.mu_layers[-1].W
        self.inf_2_weights = self.q_hi_given_x_si[0].shared_layers[0].W
        self.gen_2_weights = self.p_sip1_given_si_hi[0].mu_layers[-1].W
        self.gen_inf_weights = self.p_hi_given_si[0].shared_layers[0].W
        return
    def __init__(self, rng=None, x_in=None, \
            p_s0_obs_given_z_obs=None, p_hi_given_si=None, p_sip1_given_si_hi=None, \
            p_x_given_si_hi=None, q_z_given_x=None, q_hi_given_x_si=None, \
            obs_dim=None, z_rnn_dim=None, z_obs_dim=None, h_dim=None, \
            model_init_obs=True, model_init_rnn=True, ir_steps=2, \
            params=None):
        """
        Build the graph, costs, updates and compiled functions of a
        multi-stage model with iterative refinement.

        An initial state s0 is produced from z, then refined for ir_steps
        steps. Each step samples hi (from the variational or the prior
        network, selected by self.train_switch) and adds the output of
        p_sip1_given_si_hi to the "obs" part of the state; the "rnn" part
        is carried over unchanged.

        Parameters:
            rng: random source; must provide randint(). It seeds this
                 model's RandStream and is passed to every
                 shared_param_clone() call.
            x_in: symbolic input, stored as self.x.
            p_s0_obs_given_z_obs: network giving the initial obs state from
                                  the "obs" part of z.
            p_hi_given_si: prior network for hi given the current state.
            p_sip1_given_si_hi: network giving the obs-state update from hi
                                and the rnn part of the state.
            p_x_given_si_hi: must be None (working with "latent" si is not
                             implemented yet).
            q_z_given_x: inferencer network for z given x.
            q_hi_given_x_si: variational network for hi given x and state.
            obs_dim: width of the obs part of each state si.
            z_rnn_dim, z_obs_dim: widths of the rnn and obs parts of z.
            h_dim: stored as self.h_dim.
            model_init_obs: if true, the initial obs state is the output
                            mean of p_s0_obs_given_z_obs; otherwise it is
                            the bias of that network's last mu layer.
            model_init_rnn: if true, the initial rnn state is the rnn part
                            of z; otherwise it is taken from the bias of
                            the last mu layer of q_z_given_x.
            ir_steps: number of refinement steps to build.
            params: dict of settings; it is indexed directly, so it must
                    not be None. 'x_type' is required ('bernoulli' or
                    'gaussian'); 'obs_transform' ('sigmoid' or 'none') is
                    optional and is overridden by sigmoid for bernoulli.
        """
        # setup a rng for this model, seeded from the given rng
        self.rng = RandStream(rng.randint(100000))

        # TODO: implement functionality for working with "latent" si
        assert(p_x_given_si_hi is None)

        # decide whether to initialize from a model or from a "constant"
        self.model_init_obs = model_init_obs
        self.model_init_rnn = model_init_rnn

        # grab the user-provided parameters
        self.params = params
        self.x_type = self.params['x_type']
        assert((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))
        if 'obs_transform' in self.params:
            assert((self.params['obs_transform'] == 'sigmoid') or \
                    (self.params['obs_transform'] == 'none'))
            if self.params['obs_transform'] == 'sigmoid':
                self.obs_transform = lambda x: T.nnet.sigmoid(x)
            else:
                self.obs_transform = lambda x: x
        else:
            self.obs_transform = lambda x: T.nnet.sigmoid(x)
        # bernoulli observations always use a sigmoid, whatever was asked
        if self.x_type == 'bernoulli':
            self.obs_transform = lambda x: T.nnet.sigmoid(x)

        # record the dimensions of various spaces relevant to this model
        self.obs_dim = obs_dim
        self.rnn_dim = z_rnn_dim
        self.z_dim = z_rnn_dim + z_obs_dim
        self.z_rnn_dim = z_rnn_dim
        self.z_obs_dim = z_obs_dim
        self.h_dim = h_dim
        self.ir_steps = ir_steps

        # record the symbolic variables that will provide inputs to the
        # computation graph created to describe this MultiStageModel
        self.x = x_in
        self.batch_reps = T.lscalar()

        # one-element zero array used as the initial value of every shared
        # hyperparameter created below
        zero_ary = np.zeros((1,)).astype(theano.config.floatX)

        # setup switching variable for changing between sampling/training
        self.train_switch = theano.shared(value=zero_ary, name='msm_train_switch')
        self.set_train_switch(1.0)
        # setup a weight for pulling priors over hi given si towards a
        # shared global prior -- e.g. zero mean and unit variance.
        self.kzg_weight = theano.shared(value=zero_ary, name='msm_kzg_weight')
        self.set_kzg_weight(0.1)
        # this weight balances l1 vs. l2 penalty on posterior KLds
        self.l1l2_weight = theano.shared(value=zero_ary, name='msm_l1l2_weight')
        self.set_l1l2_weight(1.0)

        #############################
        # Setup self.z and self.s0. #
        #############################
        print("Building MSM step 0...")
        # a scale of 1.0 takes the initial state from the generative model,
        # a scale of 0.0 takes it from a learned bias instead
        obs_scale = 1.0 if self.model_init_obs else 0.0
        rnn_scale = 1.0 if self.model_init_rnn else 0.0
        self.q_z_given_x = q_z_given_x.shared_param_clone(rng=rng, Xd=self.x)
        self.z = self.q_z_given_x.output
        self.z_rnn = self.z[:,:self.z_rnn_dim]
        self.z_obs = self.z[:,self.z_rnn_dim:]
        self.p_s0_obs_given_z_obs = p_s0_obs_given_z_obs.shared_param_clone( \
                rng=rng, Xd=self.z_obs)
        _s0_obs_model = self.p_s0_obs_given_z_obs.output_mean
        _s0_obs_const = self.p_s0_obs_given_z_obs.mu_layers[-1].b
        self.s0_obs = (obs_scale * _s0_obs_model) + \
                ((1.0 - obs_scale) * _s0_obs_const)
        _s0_rnn_model = self.z_rnn
        _s0_rnn_const = self.q_z_given_x.mu_layers[-1].b[:self.z_rnn_dim]
        self.s0_rnn = (rnn_scale * _s0_rnn_model) + \
                ((1.0 - rnn_scale) * _s0_rnn_const)
        self.s0_jnt = T.horizontal_stack(self.s0_obs, self.s0_rnn)
        # output log-variance, squashed into (-8, 8) with tanh
        self.output_logvar = self.p_s0_obs_given_z_obs.sigma_layers[-1].b
        self.bounded_logvar = 8.0 * T.tanh((1.0/8.0) * self.output_logvar)

        ###############################################################
        # Setup the iterative refinement loop, starting from self.s0. #
        ###############################################################
        self.p_hi_given_si = []       # holds p_hi_given_si for each i
        self.p_sip1_given_si_hi = []  # holds p_sip1_given_si_hi for each i
        self.q_hi_given_x_si = []     # holds q_hi_given_x_si for each i
        self.si = [self.s0_jnt]       # holds si for each i
        self.hi = []                  # holds hi for each i
        for i in range(self.ir_steps):
            print("Building MSM step {0:d}...".format(i+1))
            _si = self.si[i]
            si_obs = _si[:,:self.obs_dim]
            si_rnn = _si[:,self.obs_dim:]
            # transformed obs state, built once and reused below
            si_obs_tr = self.obs_transform(si_obs)
            # get samples of next hi, conditioned on current si
            self.p_hi_given_si.append( \
                    p_hi_given_si.shared_param_clone(rng=rng, \
                    Xd=T.horizontal_stack(si_obs_tr, si_rnn)))
            hi_p = self.p_hi_given_si[i].output
            # now we build the model for variational hi given si; it also
            # sees the residual between the input and the current guess
            grad_ll = self.x - si_obs_tr
            self.q_hi_given_x_si.append(\
                    q_hi_given_x_si.shared_param_clone(rng=rng, \
                    Xd=T.horizontal_stack(grad_ll, si_obs_tr, si_rnn)))
            hi_q = self.q_hi_given_x_si[i].output
            # make hi samples that can be switched between hi_p and hi_q
            self.hi.append( ((self.train_switch[0] * hi_q) + \
                    ((1.0 - self.train_switch[0]) * hi_p)) )
            # p_sip1_given_si_hi is conditioned on hi and the "rnn" part of si.
            self.p_sip1_given_si_hi.append( \
                    p_sip1_given_si_hi.shared_param_clone(rng=rng, \
                    Xd=T.horizontal_stack(self.hi[i], si_rnn)))
            # construct the update from si_obs/si_rnn to sip1_obs/sip1_rnn
            sip1_obs = si_obs + self.p_sip1_given_si_hi[i].output_mean
            sip1_rnn = si_rnn
            sip1_jnt = T.horizontal_stack(sip1_obs, sip1_rnn)
            # record the updated state of the generative process
            self.si.append(sip1_jnt)
        # check that input/output dimensions of our models agree
        self._check_model_shapes()

        ######################################################################
        # ALL SYMBOLIC VARS NEEDED FOR THE OBJECTIVE SHOULD NOW BE AVAILABLE #
        ######################################################################

        # shared var learning rate for generator and inferencer
        self.lr_1 = theano.shared(value=zero_ary, name='msm_lr_1')
        self.lr_2 = theano.shared(value=zero_ary, name='msm_lr_2')
        # shared var momentum parameters for generator and inferencer
        self.mom_1 = theano.shared(value=zero_ary, name='msm_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='msm_mom_2')
        self.it_count = theano.shared(value=zero_ary, name='msm_it_count')
        # init parameters for controlling learning dynamics
        self.set_sgd_params()
        # init shared var for weighting nll of data given posterior sample
        self.lam_nll = theano.shared(value=zero_ary, name='msm_lam_nll')
        self.set_lam_nll(lam_nll=1.0)
        # init shared var for weighting prior kld against reconstruction
        self.lam_kld_1 = theano.shared(value=zero_ary, name='msm_lam_kld_1')
        self.lam_kld_2 = theano.shared(value=zero_ary, name='msm_lam_kld_2')
        self.set_lam_kld(lam_kld_1=1.0, lam_kld_2=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='msm_lam_l2w')
        self.set_lam_l2w(1e-5)

        # Grab all of the "optimizable" parameters in "group 1"
        self.group_1_params = []
        self.group_1_params.extend(self.q_z_given_x.mlp_params)
        self.group_1_params.extend(self.p_s0_obs_given_z_obs.mlp_params)
        # Grab all of the "optimizable" parameters in "group 2"
        self.group_2_params = []
        for i in range(self.ir_steps):
            self.group_2_params.extend(self.q_hi_given_x_si[i].mlp_params)
            self.group_2_params.extend(self.p_hi_given_si[i].mlp_params)
            self.group_2_params.extend(self.p_sip1_given_si_hi[i].mlp_params)
        # Make a joint list of parameters group 1/2
        self.joint_params = self.group_1_params + self.group_2_params

        #################################
        # CONSTRUCT THE KLD-BASED COSTS #
        #################################
        self.kld_z, self.kld_hi_cond, self.kld_hi_glob = \
                self._construct_kld_costs()
        self.kld_cost = (self.lam_kld_1[0] * T.mean(self.kld_z)) + \
                (self.lam_kld_2[0] * (T.mean(self.kld_hi_cond) + \
                (self.kzg_weight[0] * T.mean(self.kld_hi_glob))))
        #################################
        # CONSTRUCT THE NLL-BASED COSTS #
        #################################
        self.nll_costs = self._construct_nll_costs()
        self.nll_cost = self.lam_nll[0] * T.mean(self.nll_costs)
        ########################################
        # CONSTRUCT THE REST OF THE JOINT COST #
        ########################################
        param_reg_cost = self._construct_reg_costs()
        self.reg_cost = self.lam_l2w[0] * param_reg_cost
        self.joint_cost = self.nll_cost + self.kld_cost + self.reg_cost

        # Get the gradient of the joint cost for all optimizable parameters.
        # A single T.grad call over the whole list builds the backward graph
        # once, instead of once per parameter.
        grad_list = T.grad(self.joint_cost, self.joint_params)
        self.joint_grads = OrderedDict()
        for p, g in zip(self.joint_params, grad_list):
            self.joint_grads[p] = g

        # Construct the updates for the generator and inferencer networks;
        # the two groups differ only in their learning rate
        self.group_1_updates = get_param_updates(params=self.group_1_params, \
                grads=self.joint_grads, alpha=self.lr_1, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        self.group_2_updates = get_param_updates(params=self.group_2_params, \
                grads=self.joint_grads, alpha=self.lr_2, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        # merge both groups; group 2 entries win on any shared key
        self.joint_updates = OrderedDict()
        for k in self.group_1_updates:
            self.joint_updates[k] = self.group_1_updates[k]
        for k in self.group_2_updates:
            self.joint_updates[k] = self.group_2_updates[k]

        # Construct a function for jointly training the generator/inferencer
        print("Compiling training function...")
        self.train_joint = self._construct_train_joint()
        self.compute_post_klds = self._construct_compute_post_klds()
        self.compute_fe_terms = self._construct_compute_fe_terms()
        self.sample_from_prior = self._construct_sample_from_prior()
        # make easy access points for some interesting parameters
        self.inf_1_weights = self.q_z_given_x.shared_layers[0].W
        self.gen_1_weights = self.p_s0_obs_given_z_obs.mu_layers[-1].W
        self.inf_2_weights = self.q_hi_given_x_si[0].shared_layers[0].W
        self.gen_2_weights = self.p_sip1_given_si_hi[0].mu_layers[-1].W
        self.gen_inf_weights = self.p_hi_given_si[0].shared_layers[0].W
        return