Example #1
def pretrain_osm(lam_kld=0.0):
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    data_file = 'data/tfd_data_48x48.pkl'
    dataset = load_tfd(tfd_pkl_name=data_file,
                       which_set='unlabeled',
                       fold='all')
    Xtr_unlabeled = dataset[0]
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='train', fold='all')
    Xtr_train = dataset[0]
    Xtr = np.vstack([Xtr_unlabeled, Xtr_train])
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='valid', fold='all')
    Xva = dataset[0]
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 400
    batch_reps = 6
    carry_frac = 0.25
    carry_size = int(batch_size * carry_frac)
    reset_prob = 0.04

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0
    Xtr_mean = np.mean(Xtr, axis=0)

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    gn_params = {}
    shared_config = [PRIOR_DIM, 1500, 1500]
    top_config = [shared_config[-1], data_dim]
    gn_params['shared_config'] = shared_config
    gn_params['mu_config'] = top_config
    gn_params['sigma_config'] = top_config
    gn_params['activation'] = relu_actfun
    gn_params['init_scale'] = 1.4
    gn_params['lam_l2a'] = 0.0
    gn_params['vis_drop'] = 0.0
    gn_params['hid_drop'] = 0.0
    gn_params['bias_noise'] = 0.0
    gn_params['input_noise'] = 0.0
    # choose some parameters for the continuous inferencer
    in_params = {}
    shared_config = [data_dim, 1500, 1500]
    top_config = [shared_config[-1], PRIOR_DIM]
    in_params['shared_config'] = shared_config
    in_params['mu_config'] = top_config
    in_params['sigma_config'] = top_config
    in_params['activation'] = relu_actfun
    in_params['init_scale'] = 1.4
    in_params['lam_l2a'] = 0.0
    in_params['vis_drop'] = 0.0
    in_params['hid_drop'] = 0.0
    in_params['bias_noise'] = 0.0
    in_params['input_noise'] = 0.0
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma, \
            params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma, \
            params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    ######################################
    # LOAD AND RESTART FROM SAVED PARAMS #
    ######################################
    # gn_fname = RESULT_PATH+"pt_osm_params_b110000_GN.pkl"
    # in_fname = RESULT_PATH+"pt_osm_params_b110000_IN.pkl"
    # IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd, \
    #         new_params=None)
    # GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd, \
    #         new_params=None)
    # in_params = IN.params
    # gn_params = GN.params

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm, \
            p_x_given_z=GN, q_z_given_x=IN, \
            x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
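    # note: the 0.9 / 0.05 squashing above keeps every per-pixel mean inside
    # [0.05, 0.95], so the logit stays finite: a mean of 0.0 maps to
    # log(0.05 / 0.95) ~= -2.94 rather than -inf, and a mean of 1.0 maps to ~= +2.94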
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    out_file = open(RESULT_PATH + "pt_osm_results.txt", 'w')
    # Set initial learning rate and basic SGD hyper parameters
    obs_costs = np.zeros((batch_size, ))
    costs = [0. for i in range(10)]
    learn_rate = 0.002
    for i in range(200000):
        scale = min(1.0, float(i) / 5000.0)
        if ((i > 1) and ((i % 20000) == 0)):
            learn_rate = learn_rate * 0.8
        # ramp momentum up as training progresses
        if (i < 10000):
            momentum = 0.5
        elif (i < 50000):
            momentum = 0.7
        else:
            momentum = 0.9
        if ((i == 0) or (npr.rand() < reset_prob)):
            # sample a fully random batch
            batch_idx = npr.randint(low=0,
                                    high=tr_samples,
                                    size=(batch_size, ))
        else:
            # sample a partially random batch, which retains some portion of
            # the worst scoring examples from the previous batch
            fresh_idx = npr.randint(low=0,
                                    high=tr_samples,
                                    size=(batch_size - carry_size, ))
            batch_idx = np.concatenate((fresh_idx.ravel(), carry_idx.ravel()))
        # assemble the minibatch from the (partially carried-over) index set
        Xd_batch = Xtr.take(batch_idx, axis=0)
        Xc_batch = 0.0 * Xd_batch
        Xm_batch = 0.0 * Xd_batch
        # do a minibatch update of the model, and compute some costs
        OSM.set_sgd_params(lr_1=(scale*learn_rate), \
                mom_1=(scale*momentum), mom_2=0.98)
        OSM.set_lam_nll(1.0)
        OSM.set_lam_kld(lam_kld_1=scale * lam_kld,
                        lam_kld_2=0.0,
                        lam_kld_c=50.0)
        result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
        batch_costs = result[4] + result[5]
        obs_costs = collect_obs_costs(batch_costs, batch_reps)
        carry_idx = batch_idx[np.argsort(-obs_costs)[0:carry_size]]
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 1000) == 0):
            # record and then reset the cost trackers
            costs = [(v / 1000.0) for v in costs]
            str_1 = "-- batch {0:d} --".format(i)
            str_2 = "    joint_cost: {0:.4f}".format(costs[0])
            str_3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str_4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str_5 = "    reg_cost  : {0:.4f}".format(costs[3])
            costs = [0.0 for v in costs]
            # print out some diagnostic information
            joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
        if ((i % 2000) == 0):
            Xva = row_shuffle(Xva)
            model_samps = OSM.sample_from_prior(500)
            file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=20)
            file_name = RESULT_PATH + "pt_osm_inf_weights_b{0:d}.png".format(i)
            utils.visualize_samples(OSM.inf_weights.get_value(borrow=False).T, \
                    file_name, num_rows=30)
            file_name = RESULT_PATH + "pt_osm_gen_weights_b{0:d}.png".format(i)
            utils.visualize_samples(OSM.gen_weights.get_value(borrow=False), \
                    file_name, num_rows=30)
            # compute information about free-energy on validation set
            file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
            fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            fe_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(fe_str)
            out_file.write(fe_str + "\n")
            utils.plot_scatter(fe_terms[1], fe_terms[0], file_name, \
                    x_label='Posterior KLd', y_label='Negative Log-likelihood')
            # compute information about posterior KLds on validation set
            file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
            post_klds = OSM.compute_post_klds(Xva[0:2500])
            post_dim_klds = np.mean(post_klds, axis=0)
            utils.plot_stem(np.arange(post_dim_klds.shape[0]), post_dim_klds, \
                    file_name)
        if ((i % 5000) == 0):
            IN.save_to_file(f_name=RESULT_PATH +
                            "pt_osm_params_b{0:d}_IN.pkl".format(i))
            GN.save_to_file(f_name=RESULT_PATH +
                            "pt_osm_params_b{0:d}_GN.pkl".format(i))
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
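
# A minimal driver sketch, assuming this function lives in a module where
# RESULT_PATH and PRIOR_DIM are defined and the TFD pickle under data/ exists:
if __name__ == "__main__":
    pretrain_osm(lam_kld=1.0)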
Example #2
def pretrain_osm(lam_kld=0.0):
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    dataset = 'data/mnist.pkl.gz'
    datasets = load_udm(dataset, zero_mean=False)
    Xtr = datasets[0][0]
    Xtr = Xtr.get_value(borrow=False)
    Xva = datasets[2][0]
    Xva = Xva.get_value(borrow=False)
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape),
                                                      str(Xva.shape)))

    # get and set some basic dataset information
    Xtr_mean = np.mean(Xtr, axis=0)
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 100
    batch_reps = 5

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    gn_params = {}
    shared_config = [PRIOR_DIM, 1000, 1000]
    top_config = [shared_config[-1], data_dim]
    gn_params['shared_config'] = shared_config
    gn_params['mu_config'] = top_config
    gn_params['sigma_config'] = top_config
    gn_params['activation'] = relu_actfun
    gn_params['init_scale'] = 1.4
    gn_params['lam_l2a'] = 0.0
    gn_params['vis_drop'] = 0.0
    gn_params['hid_drop'] = 0.0
    gn_params['bias_noise'] = 0.0
    gn_params['input_noise'] = 0.0
    # choose some parameters for the continuous inferencer
    in_params = {}
    shared_config = [data_dim, 1000, 1000]
    top_config = [shared_config[-1], PRIOR_DIM]
    in_params['shared_config'] = shared_config
    in_params['mu_config'] = top_config
    in_params['sigma_config'] = top_config
    in_params['activation'] = relu_actfun
    in_params['init_scale'] = 1.4
    in_params['lam_l2a'] = 0.0
    in_params['vis_drop'] = 0.0
    in_params['hid_drop'] = 0.0
    in_params['bias_noise'] = 0.0
    in_params['input_noise'] = 0.0
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma, \
            params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma, \
            params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm, \
            p_x_given_z=GN, q_z_given_x=IN, \
            x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    out_file = open(RESULT_PATH + "pt_osm_results.txt", 'w')
    # Set initial learning rate and basic SGD hyper parameters
    obs_costs = np.zeros((batch_size, ))
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    for i in range(150000):
        scale = min(1.0, float(i) / 10000.0)
        if ((i > 1) and ((i % 20000) == 0)):
            learn_rate = learn_rate * 0.9
        # do a minibatch update of the model, and compute some costs
        tr_idx = npr.randint(low=0, high=tr_samples, size=(batch_size, ))
        Xd_batch = Xtr.take(tr_idx, axis=0)
        Xc_batch = 0.0 * Xd_batch
        Xm_batch = 0.0 * Xd_batch
        # do a minibatch update of the model, and compute some costs
        OSM.set_sgd_params(lr_1=(scale * learn_rate), mom_1=0.5, mom_2=0.98)
        OSM.set_lam_nll(1.0)
        OSM.set_lam_kld(lam_kld_1=(1.0 + (scale * (lam_kld - 1.0))),
                        lam_kld_2=0.0)
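        # note: scale ramps from 0 to 1 over the first 10000 batches, so
        # lam_kld_1 anneals linearly from 1.0 up to the target lam_kld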
        result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 1000) == 0):
            # record and then reset the cost trackers
            costs = [(v / 1000.0) for v in costs]
            str_1 = "-- batch {0:d} --".format(i)
            str_2 = "    joint_cost: {0:.4f}".format(costs[0])
            str_3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str_4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str_5 = "    reg_cost  : {0:.4f}".format(costs[3])
            costs = [0.0 for v in costs]
            # print out some diagnostic information
            joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
        if ((i % 2000) == 0):
            Xva = row_shuffle(Xva)
            model_samps = OSM.sample_from_prior(500)
            file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=20)
            # compute information about free-energy on validation set
            file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
            fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            fe_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(fe_str)
            out_file.write(fe_str + "\n")
            utils.plot_scatter(fe_terms[1], fe_terms[0], file_name, \
                    x_label='Posterior KLd', y_label='Negative Log-likelihood')
            # compute information about posterior KLds on validation set
            file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
            post_klds = OSM.compute_post_klds(Xva[0:2500])
            post_dim_klds = np.mean(post_klds, axis=0)
            utils.plot_stem(np.arange(post_dim_klds.shape[0]), post_dim_klds, \
                    file_name)
        if ((i % 5000) == 0):
            IN.save_to_file(f_name=RESULT_PATH +
                            "pt_osm_params_b{0:d}_IN.pkl".format(i))
            GN.save_to_file(f_name=RESULT_PATH +
                            "pt_osm_params_b{0:d}_GN.pkl".format(i))
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
Example #3
    def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None, \
                 i_net=None, g_net=None, d_net=None, chain_len=None, \
                 data_dim=None, prior_dim=None, params=None):
        # record the random stream, basic dimensions, and prior parameters
        self.rng = RandStream(rng.randint(100000))
        self.data_dim = data_dim
        self.prior_dim = prior_dim
        self.prior_mean = 0.0
        self.prior_logvar = 0.0
        if params is None:
            self.params = {}
        else:
            self.params = params
        if 'cost_decay' in self.params:
            self.cost_decay = self.params['cost_decay']
        else:
            self.cost_decay = 0.1
        if 'chain_type' in self.params:
            assert((self.params['chain_type'] == 'walkback') or \
                (self.params['chain_type'] == 'walkout'))
            self.chain_type = self.params['chain_type']
        else:
            self.chain_type = 'walkout'
        if 'xt_transform' in self.params:
            assert((self.params['xt_transform'] == 'sigmoid') or \
                    (self.params['xt_transform'] == 'none'))
            if self.params['xt_transform'] == 'sigmoid':
                self.xt_transform = lambda x: T.nnet.sigmoid(x)
            else:
                self.xt_transform = lambda x: x
        else:
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        if 'logvar_bound' in self.params:
            self.logvar_bound = self.params['logvar_bound']
        else:
            self.logvar_bound = 10
        #
        # x_type: this tells if we're using bernoulli or gaussian model for
        #         the observations
        #
        self.x_type = self.params['x_type']
        assert ((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))

        # symbolic var for inputting samples for initializing the VAE chain
        self.Xd = Xd
        # symbolic var for masking subsets of the state variables
        self.Xm = Xm
        # symbolic var for controlling subsets of the state variables
        self.Xc = Xc
        # symbolic var for inputting samples from the target distribution
        self.Xt = Xt
        # integer number of times to cycle the VAE loop
        self.chain_len = chain_len

        # symbolic vector of indices for the real data inputs
        self.It = T.arange(self.Xt.shape[0])
        # symbolic vector of indices for noise/generated inputs
        self.Id = T.arange(
            self.chain_len * self.Xd.shape[0]) + self.Xt.shape[0]
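        # note: these index rows of the discriminator input constructed below
        # by vertically stacking Xt (real data) over the chain samples, so It
        # selects the real rows and Id selects the generated rows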

        # get a clone of the desired VAE, for easy access
        self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm, \
                p_x_given_z=g_net, q_z_given_x=i_net, x_dim=self.data_dim, \
                z_dim=self.prior_dim, params=self.params)
        self.IN = self.OSM.q_z_given_x
        self.GN = self.OSM.p_x_given_z
        self.transform_x_to_z = self.OSM.transform_x_to_z
        self.transform_z_to_x = self.OSM.transform_z_to_x
        self.bounded_logvar = self.OSM.bounded_logvar
        # self-loop some clones of the main VAE into a chain.
        # ** All VAEs in the chain share the same Xc and Xm, which are the
        #    symbolic inputs for providing the observed portion of the input
        #    and a mask indicating which part of the input is "observed".
        #    These inputs are used for training "reconstruction" policies.
        self.IN_chain = []
        self.GN_chain = []
        self.Xg_chain = []
        _Xd = self.Xd
        print("Unrolling chain...")
        for i in range(self.chain_len):
            # create a VAE infer/generate pair with _Xd as input and with
            # masking variables shared by all VAEs in this chain
            _IN = self.IN.shared_param_clone(rng=rng, \
                    Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm), \
                    build_funcs=False)
            _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output, \
                    build_funcs=False)
            _Xd = self.xt_transform(_GN.output_mean)
            self.IN_chain.append(_IN)
            self.GN_chain.append(_GN)
            self.Xg_chain.append(_Xd)
            print("    step {}...".format(i))

        # make a clone of the desired discriminator network, which will try
        # to discriminate between samples from the training data and samples
        # generated by the self-looped VAE chain.
        self.DN = d_net.shared_param_clone(rng=rng, \
                Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

        zero_ary = np.zeros((1, )).astype(theano.config.floatX)
        # init shared var for weighting nll of data given posterior sample
        self.lam_chain_nll = theano.shared(value=zero_ary,
                                           name='vcg_lam_chain_nll')
        self.set_lam_chain_nll(lam_chain_nll=1.0)
        # init shared var for weighting posterior KL-div from prior
        self.lam_chain_kld = theano.shared(value=zero_ary,
                                           name='vcg_lam_chain_kld')
        self.set_lam_chain_kld(lam_chain_kld=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
        self.set_lam_l2w(lam_l2w=1e-4)
        # shared var learning rates for all networks
        self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
        self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
        self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
        # shared var momentum parameters for all networks
        self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
        self.it_count = theano.shared(value=zero_ary, name='vcg_it_count')
        # shared var weights for adversarial classification objective
        self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
        self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
        # init parameters for controlling learning dynamics
        self.set_all_sgd_params()

        self.set_disc_weights()  # init adversarial cost weights for GN/DN
        # set a shared var for regularizing the output of the discriminator
        self.lam_l2d = theano.shared(value=(zero_ary + params['lam_l2d']), \
                name='vcg_lam_l2d')

        # Grab the full set of "optimizable" parameters from the generator
        # and discriminator networks that we'll be working with. We need to
        # ignore parameters in the final layers of the proto-networks in the
        # discriminator network (a generalized pseudo-ensemble). We ignore them
        # because the VCGLoop requires that they be "bypassed" in favor of some
        # binary classification layers that will be managed by this VCGLoop.
        self.dn_params = []
        for pn in self.DN.proto_nets:
            for pnl in pn[0:-1]:
                self.dn_params.extend(pnl.params)
        self.in_params = [p for p in self.IN.mlp_params]
        self.in_params.append(self.OSM.output_logvar)
        self.gn_params = [p for p in self.GN.mlp_params]
        self.joint_params = self.in_params + self.gn_params + self.dn_params

        # Now construct a binary discriminator layer for each proto-net in the
        # discriminator network. And, add their params to optimization list.
        self._construct_disc_layers(rng)
        self.disc_reg_cost = self.lam_l2d[0] * \
                T.sum([dl.act_l2_sum for dl in self.disc_layers])

        # Construct costs for the generator and discriminator networks based
        # on adversarial binary classification
        self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

        # First, build the cost to be optimized by the discriminator network.
        # In general this will be treated somewhat independently of the
        # optimization of the generator and inferencer networks.
        self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
                self.disc_reg_cost

        # construct costs relevant to the optimization of the generator and
        # discriminator networks
        self.chain_nll_cost = self.lam_chain_nll[0] * \
                self._construct_chain_nll_cost(cost_decay=self.cost_decay)
        self.chain_kld_cost = self.lam_chain_kld[0] * \
                self._construct_chain_kld_cost(cost_decay=self.cost_decay)
        self.other_reg_cost = self._construct_other_reg_cost()
        self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
                self.chain_kld_cost + self.other_reg_cost
        # compute total cost on the discriminator and VB generator/inferencer
        self.joint_cost = self.dn_cost + self.osm_cost

        # Get the gradient of the joint cost for all optimizable parameters
        self.joint_grads = OrderedDict()
        print("Computing VCGLoop DN cost gradients...")
        grad_list = T.grad(self.dn_cost,
                           self.dn_params,
                           disconnected_inputs='warn')
        for i, p in enumerate(self.dn_params):
            self.joint_grads[p] = grad_list[i]
        print("Computing VCGLoop IN cost gradients...")
        grad_list = T.grad(self.osm_cost,
                           self.in_params,
                           disconnected_inputs='warn')
        for i, p in enumerate(self.in_params):
            self.joint_grads[p] = grad_list[i]
        print("Computing VCGLoop GN cost gradients...")
        grad_list = T.grad(self.osm_cost,
                           self.gn_params,
                           disconnected_inputs='warn')
        for i, p in enumerate(self.gn_params):
            self.joint_grads[p] = grad_list[i]

        # construct the updates for the discriminator, generator and
        # inferencer networks. all networks share the same first/second
        # moment momentum and iteration count. the networks each have their
        # own learning rates, which lets you turn their learning on/off.
        self.dn_updates = get_param_updates(params=self.dn_params, \
                grads=self.joint_grads, alpha=self.lr_dn, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        self.gn_updates = get_param_updates(params=self.gn_params, \
                grads=self.joint_grads, alpha=self.lr_gn, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        self.in_updates = get_param_updates(params=self.in_params, \
                grads=self.joint_grads, alpha=self.lr_in, \
                beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
                mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)

        # bag up all the updates required for training
        self.joint_updates = OrderedDict()
        for k in self.dn_updates:
            self.joint_updates[k] = self.dn_updates[k]
        for k in self.gn_updates:
            self.joint_updates[k] = self.gn_updates[k]
        for k in self.in_updates:
            self.joint_updates[k] = self.in_updates[k]
        # construct an update for tracking the mean KL divergence of
        # approximate posteriors for this chain
        new_kld_mean = (0.98 * self.IN.kld_mean) + ((0.02 / self.chain_len) * \
            sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
        self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

        # construct the function for training on training data
        print("Compiling VCGLoop theano functions....")
        self.train_joint = self._construct_train_joint()
        return
Example #4
def test_gip_sigma_scale_tfd():
    from LogPDFs import cross_validate_sigma
    # Simple test code, to check that everything is basically functional.
    print("TESTING...")

    # Initialize a source of randomness
    rng = np.random.RandomState(12345)

    # Load some data to train/validate/test with
    data_file = 'data/tfd_data_48x48.pkl'
    dataset = load_tfd(tfd_pkl_name=data_file,
                       which_set='unlabeled',
                       fold='all')
    Xtr_unlabeled = dataset[0]
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='train', fold='all')
    Xtr_train = dataset[0]
    Xtr = np.vstack([Xtr_unlabeled, Xtr_train])
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='test', fold='all')
    Xva = dataset[0]
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape),
                                                      str(Xva.shape)))

    # get and set some basic dataset information
    tr_samples = Xtr.shape[0]
    data_dim = Xtr.shape[1]
    batch_size = 100

    # Symbolic inputs
    Xd = T.matrix(name='Xd')
    Xc = T.matrix(name='Xc')
    Xm = T.matrix(name='Xm')
    Xt = T.matrix(name='Xt')

    # Load inferencer and generator from saved parameters
    gn_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_GN.pkl"
    in_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_IN.pkl"
    IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd)
    GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd)
    x_dim = IN.shared_layers[0].in_dim
    z_dim = IN.mu_layers[-1].out_dim
    # construct a GIPair with the loaded InfNet and GenNet
    osm_params = {}
    osm_params['x_type'] = 'gaussian'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm, \
            p_x_given_z=GN, q_z_given_x=IN, \
            x_dim=x_dim, z_dim=z_dim, params=osm_params)

    # compute variational likelihood bound and its sub-components
    Xva = row_shuffle(Xva)
    Xb = Xva[0:5000]
    # file_name = "A_TFD_POST_KLDS.png"
    # post_klds = OSM.compute_post_klds(Xb)
    # post_dim_klds = np.mean(post_klds, axis=0)
    # utils.plot_stem(np.arange(post_dim_klds.shape[0]), post_dim_klds, \
    #         file_name)
    # compute information about free-energy on validation set
    file_name = "A_TFD_KLD_FREE_ENERGY.png"
    fe_terms = OSM.compute_fe_terms(Xb, 20)
    utils.plot_scatter(fe_terms[1], fe_terms[0], file_name, \
            x_label='Posterior KLd', y_label='Negative Log-likelihood')

    # bound_results = OSM.compute_ll_bound(Xva)
    # ll_bounds = bound_results[0]
    # post_klds = bound_results[1]
    # log_likelihoods = bound_results[2]
    # max_lls = bound_results[3]
    # print("mean ll bound: {0:.4f}".format(np.mean(ll_bounds)))
    # print("mean posterior KLd: {0:.4f}".format(np.mean(post_klds)))
    # print("mean log-likelihood: {0:.4f}".format(np.mean(log_likelihoods)))
    # print("mean max log-likelihood: {0:.4f}".format(np.mean(max_lls)))
    # print("min ll bound: {0:.4f}".format(np.min(ll_bounds)))
    # print("max posterior KLd: {0:.4f}".format(np.max(post_klds)))
    # print("min log-likelihood: {0:.4f}".format(np.min(log_likelihoods)))
    # print("min max log-likelihood: {0:.4f}".format(np.min(max_lls)))
    # # compute some information about the approximate posteriors
    # post_stats = OSM.compute_post_stats(Xva, 0.0*Xva, 0.0*Xva)
    # all_post_klds = np.sort(post_stats[0].ravel()) # post KLds for each obs and dim
    # obs_post_klds = np.sort(post_stats[1]) # summed post KLds for each obs
    # post_dim_klds = post_stats[2] # average post KLds for each post dim
    # post_dim_vars = post_stats[3] # average squared mean for each post dim
    # utils.plot_line(np.arange(all_post_klds.shape[0]), all_post_klds, "AAA_ALL_POST_KLDS.png")
    # utils.plot_line(np.arange(obs_post_klds.shape[0]), obs_post_klds, "AAA_OBS_POST_KLDS.png")
    # utils.plot_stem(np.arange(post_dim_klds.shape[0]), post_dim_klds, "AAA_POST_DIM_KLDS.png")
    # utils.plot_stem(np.arange(post_dim_vars.shape[0]), post_dim_vars, "AAA_POST_DIM_VARS.png")

    # draw many samples from the GIP
    for i in range(5):
        tr_idx = npr.randint(low=0, high=tr_samples, size=(100, ))
        Xd_batch = Xtr.take(tr_idx, axis=0)
        Xs = []
        for row in range(3):
            Xs.append([])
            for col in range(3):
                sample_lists = OSM.sample_from_chain(Xd_batch[0:10,:], loop_iters=100, \
                        sigma_scale=1.0)
                Xs[row].append(group_chains(sample_lists['data samples']))
        Xs, block_im_dim = block_video(Xs, (48, 48), (3, 3))
        to_video(Xs,
                 block_im_dim,
                 "A_TFD_KLD_CHAIN_VIDEO_{0:d}.avi".format(i),
                 frame_rate=10)
        #sample_lists = GIP.sample_from_chain(Xd_batch[0,:].reshape((1,data_dim)), loop_iters=300, \
        #        sigma_scale=1.0)
        #Xs = np.vstack(sample_lists["data samples"])
        #file_name = "TFD_TEST_{0:d}.png".format(i)
        #utils.visualize_samples(Xs, file_name, num_rows=15)
    file_name = "A_TFD_KLD_PRIOR_SAMPLE.png"
    Xs = OSM.sample_from_prior(20 * 20)
    utils.visualize_samples(Xs, file_name, num_rows=20)

    # test Parzen density estimator built from prior samples
    # Xs = OSM.sample_from_prior(10000)
    # [best_sigma, best_ll, best_lls] = \
    #         cross_validate_sigma(Xs, Xva, [0.09, 0.095, 0.1, 0.105, 0.11], 10)
    # sort_idx = np.argsort(best_lls)
    # sort_idx = sort_idx[0:400]
    # utils.plot_line(np.arange(sort_idx.shape[0]), best_lls[sort_idx], "A_TFD_BEST_LLS_1.png")
    # utils.visualize_samples(Xva[sort_idx], "A_TFD_BAD_FACES_1.png", num_rows=20)
    return
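
# A generic sketch of the Parzen-window estimator that the commented-out block
# above cross-validates (plain numpy; the actual LogPDFs implementation may
# differ): given prior samples and a kernel width sigma, it returns per-query
# log-density under an isotropic-Gaussian kernel mixture.
def parzen_log_density(samples, queries, sigma):
    d = samples.shape[1]
    # squared distance from every query point to every kernel center
    sq_dists = ((queries[:, None, :] - samples[None, :, :]) ** 2).sum(axis=2)
    log_kernels = -sq_dists / (2.0 * sigma ** 2)
    # numerically stable log-mean-exp over kernel centers
    m = log_kernels.max(axis=1, keepdims=True)
    log_mean = m[:, 0] + np.log(np.exp(log_kernels - m).mean(axis=1))
    # subtract the Gaussian normalizing constant
    return log_mean - 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)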
Example #5
def test_svhn(occ_dim=15, drop_prob=0.0):
    RESULT_PATH = "IMP_SVHN_VAE/"
    #########################################
    # Format the result tag more thoroughly #
    #########################################
    dp_int = int(100.0 * drop_prob)
    result_tag = "{}VAE_OD{}_DP{}".format(RESULT_PATH, occ_dim, dp_int)

    ##########################
    # Get some training data #
    ##########################
    tr_file = 'data/svhn_train_gray.pkl'
    te_file = 'data/svhn_test_gray.pkl'
    ex_file = 'data/svhn_extra_gray.pkl'
    data = load_svhn_gray(tr_file, te_file, ex_file=ex_file, ex_count=200000)
    Xtr = to_fX( shift_and_scale_into_01(np.vstack([data['Xtr'], data['Xex']])) )
    Xva = to_fX( shift_and_scale_into_01(data['Xte']) )
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 250
    all_pix_mean = np.mean(np.mean(Xtr, axis=1))
    data_mean = to_fX( all_pix_mean * np.ones((Xtr.shape[1],)) )

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    obs_dim = Xtr.shape[1]
    z_dim = 100
    imp_steps = 15 # we'll check for the best step count (found oracularly)
    init_scale = 1.0
    # rng and batch_reps are not defined in this snippet; assumed values so
    # the example is self-contained (other examples use the same seed)
    rng = np.random.RandomState(1234)
    batch_reps = 1

    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')
    x_mask_sym = T.matrix('x_mask_sym')

    #################
    # p_zi_given_xi #
    #################
    params = {}
    shared_config = [obs_dim, 1000, 1000]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_zi_given_xi = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_zi_given_xi.init_biases(0.2)
    ###################
    # p_xip1_given_zi #
    ###################
    params = {}
    shared_config = [z_dim, 1000, 1000]
    output_config = [obs_dim, obs_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_xip1_given_zi = HydraNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    p_xip1_given_zi.init_biases(0.2)
    ###################
    # q_zi_given_x_xi #
    ###################
    params = {}
    shared_config = [(obs_dim + obs_dim), 1000, 1000]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_zi_given_x_xi = InfNet(rng=rng, Xd=x_in_sym, \
            params=params, shared_param_dicts=None)
    q_zi_given_x_xi.init_biases(0.2)


    ###########################################################
    # Define parameters for the GPSImputer, and initialize it #
    ###########################################################
    print("Building the GPSImputer...")
    gpsi_params = {}
    gpsi_params['obs_dim'] = obs_dim
    gpsi_params['z_dim'] = z_dim
    gpsi_params['imp_steps'] = imp_steps
    gpsi_params['step_type'] = 'jump'
    gpsi_params['x_type'] = 'bernoulli'
    gpsi_params['obs_transform'] = 'sigmoid'
    gpsi_params['use_osm_mode'] = True
    GPSI = GPSImputer(rng=rng, 
            x_in=x_in_sym, x_out=x_out_sym, x_mask=x_mask_sym, \
            p_zi_given_xi=p_zi_given_xi, \
            p_xip1_given_zi=p_xip1_given_zi, \
            q_zi_given_x_xi=q_zi_given_x_xi, \
            params=gpsi_params, \
            shared_param_dicts=None)
    #########################################################################
    # Define parameters for the underlying OneStageModel, and initialize it #
    #########################################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng, \
            x_in=x_in_sym, \
            p_x_given_z=p_xip1_given_zi, \
            q_z_given_x=p_zi_given_xi, \
            x_dim=obs_dim, z_dim=z_dim, \
            params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format(result_tag)
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(200005):
        scale = min(1.0, ((i+1) / 5000.0))
        if (((i + 1) % 15000) == 0):
            learn_rate = learn_rate * 0.92
        if (i > 10000):
            momentum = 0.90
        else:
            momentum = 0.50
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        # set sgd and objective function hyperparams for this update
        OSM.set_sgd_params(lr=scale*learn_rate, \
                           mom_1=scale*momentum, mom_2=0.99)
        OSM.set_lam_nll(lam_nll=1.0)
        OSM.set_lam_kld(lam_kld_1=1.0, lam_kld_2=0.0)
        OSM.set_lam_l2w(1e-4)
        # perform a minibatch update and record the cost for this batch
        xb = to_fX( Xtr.take(batch_idx, axis=0) )
        result = OSM.train_joint(xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result)-1)]
        if ((i % 250) == 0):
            costs = [(v / 250.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if ((i % 1000) == 0):
            Xva = row_shuffle(Xva)
            # record an estimate of performance on the test set
            xi, xo, xm = construct_masked_data(Xva[0:5000], drop_prob=drop_prob, \
                                               occ_dim=occ_dim, data_mean=data_mean)
            step_nll, step_kld = GPSI.compute_per_step_cost(xi, xo, xm, sample_count=10)
            min_nll = np.min(step_nll)
            str1 = "    va_nll_bound : {}".format(min_nll)
            str2 = "    va_nll_min  : {}".format(min_nll)
            str3 = "    va_nll_final : {}".format(step_nll[-1])
            joint_str = "\n".join([str1, str2, str3])
            print(joint_str)
            out_file.write(joint_str+"\n")
            out_file.flush()
        if ((i % 10000) == 0):
            # Get some validation samples for evaluating model performance
            xb = to_fX( Xva[0:100] )
            xi, xo, xm = construct_masked_data(xb, drop_prob=drop_prob, \
                                    occ_dim=occ_dim, data_mean=data_mean)
            xi = np.repeat(xi, 2, axis=0)
            xo = np.repeat(xo, 2, axis=0)
            xm = np.repeat(xm, 2, axis=0)
            # draw some sample imputations from the model
            samp_count = xi.shape[0]
            _, model_samps = GPSI.sample_imputer(xi, xo, xm, use_guide_policy=False)
            seq_len = len(model_samps)
            seq_samps = np.zeros((seq_len*samp_count, model_samps[0].shape[1]))
            idx = 0
            for s1 in range(samp_count):
                for s2 in range(seq_len):
                    seq_samps[idx] = model_samps[s2][s1]
                    idx += 1
            file_name = "{}_samples_ng_b{}.png".format(result_tag, i)
            utils.visualize_samples(seq_samps, file_name, num_rows=20)
            # get visualizations of policy parameters
            file_name = "{}_gen_gen_weights_b{}.png".format(result_tag, i)
            W = GPSI.gen_gen_weights.get_value(borrow=False)
            utils.visualize_samples(W[:,:obs_dim], file_name, num_rows=20)
            file_name = "{}_gen_inf_weights_b{}.png".format(result_tag, i)
            W = GPSI.gen_inf_weights.get_value(borrow=False).T
            utils.visualize_samples(W[:,:obs_dim], file_name, num_rows=20)
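
# A minimal driver sketch, assuming the SVHN pickles under data/ exist and the
# IMP_SVHN_VAE/ result directory has been created:
if __name__ == "__main__":
    test_svhn(occ_dim=15, drop_prob=0.0)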
Example #6
def test_one_stage_model():
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    #del Xte
    tr_samples = Xtr.shape[0]
    va_samples = Xva.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the OneStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 64
    x_type = 'bernoulli'
    xin_sym = T.matrix('xin_sym')

    ###############
    # p_x_given_z #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'fc',
       'in_chans': z_dim,
       'out_chans': 256,
       'activation': relu_actfun,
       'apply_bn': True}, \
      {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': 7*7*128,
       'activation': relu_actfun,
       'apply_bn': True,
       'shape_func_out': lambda x: T.reshape(x, (-1, 128, 7, 7))}, \
      {'layer_type': 'conv',
       'in_chans': 128, # in shape:  (batch, 128, 7, 7)
       'out_chans': 64, # out shape: (batch, 64, 14, 14)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'conv',
       'in_chans': 64, # in shape:  (batch, 64, 14, 14)
       'out_chans': 1, # out shape: (batch, 1, 28, 28)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': False,
       'shape_func_out': lambda x: T.flatten(x, 2)}, \
      {'layer_type': 'conv',
       'in_chans': 64,
       'out_chans': 1,
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'half',
       'apply_bn': False,
       'shape_func_out': lambda x: T.flatten(x, 2)} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_z = HydraNet(rng=rng, Xd=xin_sym, \
            params=params, shared_param_dicts=None)
    p_x_given_z.init_biases(0.0)
    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = \
    [ {'layer_type': 'conv',
       'in_chans': 1,   # input arrives as (batch, 784); reshaped to (batch, 1, 28, 28) below
       'out_chans': 64, # out shape: (batch, 64, 14, 14)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'double',
       'apply_bn': True,
       'shape_func_in': lambda x: T.reshape(x, (-1, 1, 28, 28))}, \
      {'layer_type': 'conv',
       'in_chans': 64,   # in shape:  (batch, 64, 14, 14)
       'out_chans': 128, # out shape: (batch, 128, 7, 7)
       'activation': relu_actfun,
       'filt_dim': 5,
       'conv_stride': 'double',
       'apply_bn': True,
       'shape_func_out': lambda x: T.flatten(x, 2)}, \
      {'layer_type': 'fc',
       'in_chans': 128*7*7,
       'out_chans': 256,
       'activation': relu_actfun,
       'apply_bn': True} ]
    output_config = \
    [ {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': z_dim,
       'activation': relu_actfun,
       'apply_bn': False}, \
      {'layer_type': 'fc',
       'in_chans': 256,
       'out_chans': z_dim,
       'activation': relu_actfun,
       'apply_bn': False} ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_x = HydraNet(rng=rng, Xd=xin_sym, \
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.0)

    ##############################################################
    # Define parameters for the TwoStageModel, and initialize it #
    ##############################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = x_type
    osm_params['obs_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng,
                        x_in=xin_sym,
                        x_dim=x_dim,
                        z_dim=z_dim,
                        p_x_given_z=p_x_given_z,
                        q_z_given_x=q_z_given_x,
                        params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("OSM_TEST")
    out_file = open(log_name, 'w')
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    momentum = 0.9
    batch_idx = np.arange(batch_size) + tr_samples
    for i in range(500000):
        scale = min(0.5, ((i + 1) / 5000.0))
        if (((i + 1) % 10000) == 0):
            learn_rate = learn_rate * 0.95
        # get the indices of training samples for this batch update
        batch_idx += batch_size
        if (np.max(batch_idx) >= tr_samples):
            # we finished an "epoch", so we rejumble the training set
            Xtr = row_shuffle(Xtr)
            batch_idx = np.arange(batch_size)
        Xb = to_fX(Xtr.take(batch_idx, axis=0))
        #Xb = binarize_data(Xtr.take(batch_idx, axis=0))
        # set sgd and objective function hyperparams for this update
        OSM.set_sgd_params(lr=scale*learn_rate, \
                           mom_1=(scale*momentum), mom_2=0.98)
        OSM.set_lam_nll(lam_nll=1.0)
        OSM.set_lam_kld(lam_kld=1.0)
        OSM.set_lam_l2w(1e-5)
        # perform a minibatch update and record the cost for this batch
        result = OSM.train_joint(Xb, batch_reps)
        costs = [(costs[j] + result[j]) for j in range(len(result))]
        if ((i % 500) == 0):
            costs = [(v / 500.0) for v in costs]
            str1 = "-- batch {0:d} --".format(i)
            str2 = "    joint_cost: {0:.4f}".format(costs[0])
            str3 = "    nll_cost  : {0:.4f}".format(costs[1])
            str4 = "    kld_cost  : {0:.4f}".format(costs[2])
            str5 = "    reg_cost  : {0:.4f}".format(costs[3])
            joint_str = "\n".join([str1, str2, str3, str4, str5])
            print(joint_str)
            out_file.write(joint_str + "\n")
            out_file.flush()
            costs = [0.0 for v in costs]
        if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
            # draw some independent random samples from the model
            samp_count = 300
            model_samps = OSM.sample_from_prior(samp_count)
            file_name = "OSM_SAMPLES_b{0:d}.png".format(i)
            utils.visualize_samples(model_samps, file_name, num_rows=15)
            # compute free energy estimate for validation samples
            Xva = row_shuffle(Xva)
            fe_terms = OSM.compute_fe_terms(Xva[0:5000], 20)
            fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
            out_str = "    nll_bound : {0:.4f}".format(fe_mean)
            print(out_str)
            out_file.write(out_str + "\n")
            out_file.flush()
    return
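
# A minimal driver sketch, assuming binarized MNIST is available under ./data/:
if __name__ == "__main__":
    test_one_stage_model()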