def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None,
             i_net=None, g_net=None, d_net=None, chain_len=None,
             data_dim=None, prior_dim=None, params=None):
    """
    Build the unrolled VAE chain, its discriminator, and the training graph.

    Parameters:
        rng: random source; must provide randint(). It seeds the RandStream
             and is passed on to every network that is built or cloned.
        Xd: symbolic var for inputting samples that initialize the chain
        Xc: symbolic var for controlling subsets of the state variables
        Xm: symbolic var for masking subsets of the state variables
        Xt: symbolic var for inputting samples from the target distribution
        i_net: inferencer, handed to OneStageModel as q_z_given_x
        g_net: generator, handed to OneStageModel as p_x_given_z
        d_net: discriminator; must provide shared_param_clone()
        chain_len: integer number of times to cycle the VAE loop
        data_dim: handed to OneStageModel as x_dim
        prior_dim: handed to OneStageModel as z_dim
        params: dict of options. 'x_type' ('bernoulli' or 'gaussian') and
                'lam_l2d' are required. Optional keys and their defaults:
                'cost_decay' (0.1), 'chain_type' ('walkout', or 'walkback'),
                'xt_transform' ('sigmoid', or 'none'), 'logvar_bound' (10).

    Raises:
        AssertionError: if 'chain_type', 'xt_transform' or 'x_type' holds a
            value other than the ones listed above.
        KeyError: if 'x_type' or 'lam_l2d' is missing from params.
    """
    self.rng = RandStream(rng.randint(100000))
    self.data_dim = data_dim
    self.prior_dim = prior_dim
    self.prior_mean = 0.0
    self.prior_logvar = 0.0
    self.params = {} if params is None else params

    # optional settings, with their defaults
    self.cost_decay = self.params.get('cost_decay', 0.1)
    self.chain_type = self.params.get('chain_type', 'walkout')
    assert self.chain_type in ('walkback', 'walkout')
    xt_transform_name = self.params.get('xt_transform', 'sigmoid')
    assert xt_transform_name in ('sigmoid', 'none')
    if xt_transform_name == 'sigmoid':
        self.xt_transform = lambda x: T.nnet.sigmoid(x)
    else:
        self.xt_transform = lambda x: x
    self.logvar_bound = self.params.get('logvar_bound', 10)
    #
    # x_type: this tells if we're using bernoulli or gaussian model for
    #         the observations
    #
    self.x_type = self.params['x_type']
    assert self.x_type in ('bernoulli', 'gaussian')

    # symbolic var for inputting samples for initializing the VAE chain
    self.Xd = Xd
    # symbolic var for masking subsets of the state variables
    self.Xm = Xm
    # symbolic var for controlling subsets of the state variables
    self.Xc = Xc
    # symbolic var for inputting samples from the target distribution
    self.Xt = Xt
    # integer number of times to cycle the VAE loop
    self.chain_len = chain_len

    # The discriminator input is Xt stacked on top of every chain output,
    # so rows [0, |Xt|) are data and the following chain_len * |Xd| rows
    # are generated samples.
    self.It = T.arange(self.Xt.shape[0])
    self.Id = T.arange(self.chain_len * self.Xd.shape[0]) + self.Xt.shape[0]

    # get a clone of the desired VAE, for easy access
    self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm,
                             p_x_given_z=g_net, q_z_given_x=i_net,
                             x_dim=self.data_dim, z_dim=self.prior_dim,
                             params=self.params)
    self.IN = self.OSM.q_z_given_x
    self.GN = self.OSM.p_x_given_z
    self.transform_x_to_z = self.OSM.transform_x_to_z
    self.transform_z_to_x = self.OSM.transform_z_to_x
    self.bounded_logvar = self.OSM.bounded_logvar

    # self-loop some clones of the main VAE into a chain.
    # ** All VAEs in the chain share the same Xc and Xm, which are the
    #    symbolic inputs for providing the observed portion of the input
    #    and a mask indicating which part of the input is "observed".
    #    These inputs are used for training "reconstruction" policies.
    self.IN_chain = []
    self.GN_chain = []
    self.Xg_chain = []
    _Xd = self.Xd
    print("Unrolling chain...")
    for i in range(self.chain_len):
        # create a VAE infer/generate pair with _Xd as input and with
        # masking variables shared by all VAEs in this chain
        _IN = self.IN.shared_param_clone(
            rng=rng, Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm),
            build_funcs=False)
        _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output,
                                         build_funcs=False)
        # the (transformed) generator mean feeds the next step of the chain
        _Xd = self.xt_transform(_GN.output_mean)
        self.IN_chain.append(_IN)
        self.GN_chain.append(_GN)
        self.Xg_chain.append(_Xd)
        print(" step {}...".format(i))

    # make a clone of the desired discriminator network, which will try
    # to discriminate between samples from the training data and samples
    # generated by the self-looped VAE chain.
    self.DN = d_net.shared_param_clone(
        rng=rng, Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

    zero_ary = np.zeros((1,)).astype(theano.config.floatX)
    # init shared var for weighting nll of data given posterior sample
    self.lam_chain_nll = theano.shared(value=zero_ary, name='vcg_lam_chain_nll')
    self.set_lam_chain_nll(lam_chain_nll=1.0)
    # init shared var for weighting posterior KL-div from prior
    self.lam_chain_kld = theano.shared(value=zero_ary, name='vcg_lam_chain_kld')
    self.set_lam_chain_kld(lam_chain_kld=1.0)
    # init shared var for controlling l2 regularization on params
    self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
    self.set_lam_l2w(lam_l2w=1e-4)
    # shared var learning rates for all networks
    self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
    self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
    self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
    # shared var momentum parameters for all networks
    self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
    self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
    self.it_count = theano.shared(value=zero_ary, name='vcg_it_count')
    # shared var weights for adversarial classification objective
    self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
    self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
    # init parameters for controlling learning dynamics
    self.set_all_sgd_params()
    self.set_disc_weights()  # init adversarial cost weights for GN/DN
    # set a shared var for regularizing the output of the discriminator.
    # Read from self.params (the normalized dict) like every other option.
    self.lam_l2d = theano.shared(
        value=(zero_ary + self.params['lam_l2d']), name='vcg_lam_l2d')

    # Grab the full set of "optimizable" parameters from the generator
    # and discriminator networks that we'll be working with. We need to
    # ignore parameters in the final layers of the proto-networks in the
    # discriminator network (a generalized pseudo-ensemble). We ignore them
    # because they are "bypassed" in favor of some binary classification
    # layers that are managed by this object.
    self.dn_params = []
    for pn in self.DN.proto_nets:
        for pnl in pn[0:-1]:
            self.dn_params.extend(pnl.params)
    self.in_params = [p for p in self.IN.mlp_params]
    self.in_params.append(self.OSM.output_logvar)
    self.gn_params = [p for p in self.GN.mlp_params]

    # Now construct a binary discriminator layer for each proto-net in the
    # discriminator network. And, add their params to optimization list.
    self._construct_disc_layers(rng)
    # Collect the joint list only now: building it before the disc layers
    # exist would leave their parameters out of it.
    self.joint_params = self.in_params + self.gn_params + self.dn_params
    self.disc_reg_cost = self.lam_l2d[0] * \
        T.sum([dl.act_l2_sum for dl in self.disc_layers])

    # Construct costs for the generator and discriminator networks based
    # on adversarial binary classification
    self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

    # first, build the cost to be optimized by the discriminator network,
    # in general this will be treated somewhat independently of the
    # optimization of the generator and inferencer networks.
    self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
        self.disc_reg_cost

    # construct costs relevant to the optimization of the generator and
    # inferencer networks
    self.chain_nll_cost = self.lam_chain_nll[0] * \
        self._construct_chain_nll_cost(cost_decay=self.cost_decay)
    self.chain_kld_cost = self.lam_chain_kld[0] * \
        self._construct_chain_kld_cost(cost_decay=self.cost_decay)
    self.other_reg_cost = self._construct_other_reg_cost()
    self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
        self.chain_kld_cost + self.other_reg_cost
    # compute total cost on the discriminator and VB generator/inferencer
    self.joint_cost = self.dn_cost + self.osm_cost

    # Get the gradients for all optimizable parameters: the discriminator
    # follows dn_cost, the inferencer and generator follow osm_cost.
    self.joint_grads = OrderedDict()
    print("Computing VCGLoop DN cost gradients...")
    grad_list = T.grad(self.dn_cost, self.dn_params, disconnected_inputs='warn')
    for p, g in zip(self.dn_params, grad_list):
        self.joint_grads[p] = g
    print("Computing VCGLoop IN cost gradients...")
    grad_list = T.grad(self.osm_cost, self.in_params, disconnected_inputs='warn')
    for p, g in zip(self.in_params, grad_list):
        self.joint_grads[p] = g
    print("Computing VCGLoop GN cost gradients...")
    grad_list = T.grad(self.osm_cost, self.gn_params, disconnected_inputs='warn')
    for p, g in zip(self.gn_params, grad_list):
        self.joint_grads[p] = g

    # construct the updates for the discriminator, generator and
    # inferencer networks. all networks share the same first/second
    # moment momentum and iteration count. the networks each have their
    # own learning rates, which lets you turn their learning on/off.
    self.dn_updates = get_param_updates(
        params=self.dn_params, grads=self.joint_grads, alpha=self.lr_dn,
        beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count,
        mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
    self.gn_updates = get_param_updates(
        params=self.gn_params, grads=self.joint_grads, alpha=self.lr_gn,
        beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count,
        mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
    self.in_updates = get_param_updates(
        params=self.in_params, grads=self.joint_grads, alpha=self.lr_in,
        beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count,
        mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)

    # bag up all the updates required for training (later entries win on
    # duplicate keys, in the order DN, GN, IN)
    self.joint_updates = OrderedDict()
    for updates in (self.dn_updates, self.gn_updates, self.in_updates):
        for k in updates:
            self.joint_updates[k] = updates[k]
    # construct an update for tracking the mean KL divergence of
    # approximate posteriors for this chain (exponential moving average)
    new_kld_mean = (0.98 * self.IN.kld_mean) + ((0.02 / self.chain_len) *
        sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
    self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

    # construct the function for training on training data
    print("Compiling VCGLoop theano functions....")
    self.train_joint = self._construct_train_joint()
    return
def pretrain_osm(lam_kld=0.0):
    """
    Pretrain a OneStageModel (inferencer + generator) on the TFD images.

    Builds two InfNet instances, wraps them in a OneStageModel and runs
    200000 minibatch updates. Costs are logged to
    RESULT_PATH + "pt_osm_results.txt"; sample/weight/diagnostic plots and
    parameter snapshots are written under RESULT_PATH as training proceeds.

    Parameters:
        lam_kld: weight handed to OSM.set_lam_kld as lam_kld_1, after being
                 multiplied by the warm-up factor (ramps 0 -> 1 over the
                 first 5000 updates).

    Note: the batch indices are drawn from the global numpy.random module
    (npr), not from the locally seeded `rng`, which is only passed to the
    network constructors.
    """
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    data_file = 'data/tfd_data_48x48.pkl'
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='unlabeled', fold='all')
    Xtr_unlabeled = dataset[0]
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='train', fold='all')
    Xtr_train = dataset[0]
    Xtr = np.vstack([Xtr_unlabeled, Xtr_train])
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='valid', fold='all')
    Xva = dataset[0]
    tr_samples = Xtr.shape[0]
    batch_size = 400
    batch_reps = 6
    carry_frac = 0.25
    carry_size = int(batch_size * carry_frac)
    reset_prob = 0.04

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0
    Xtr_mean = np.mean(Xtr, axis=0)

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    gn_params = {}
    shared_config = [PRIOR_DIM, 1500, 1500]
    top_config = [shared_config[-1], data_dim]
    gn_params['shared_config'] = shared_config
    gn_params['mu_config'] = top_config
    gn_params['sigma_config'] = top_config
    gn_params['activation'] = relu_actfun
    gn_params['init_scale'] = 1.4
    gn_params['lam_l2a'] = 0.0
    gn_params['vis_drop'] = 0.0
    gn_params['hid_drop'] = 0.0
    gn_params['bias_noise'] = 0.0
    gn_params['input_noise'] = 0.0
    # choose some parameters for the continuous inferencer
    in_params = {}
    shared_config = [data_dim, 1500, 1500]
    top_config = [shared_config[-1], PRIOR_DIM]
    in_params['shared_config'] = shared_config
    in_params['mu_config'] = top_config
    in_params['sigma_config'] = top_config
    in_params['activation'] = relu_actfun
    in_params['init_scale'] = 1.4
    in_params['lam_l2a'] = 0.0
    in_params['vis_drop'] = 0.0
    in_params['hid_drop'] = 0.0
    in_params['bias_noise'] = 0.0
    in_params['input_noise'] = 0.0
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    ######################################
    # LOAD AND RESTART FROM SAVED PARAMS #
    ######################################
    # gn_fname = RESULT_PATH+"pt_osm_params_b110000_GN.pkl"
    # in_fname = RESULT_PATH+"pt_osm_params_b110000_IN.pkl"
    # IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd, \
    #                            new_params=None)
    # GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd, \
    #                            new_params=None)
    # in_params = IN.params
    # gn_params = GN.params

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    # squash the data mean into [0.05, 0.95] so its logit stays finite
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    # Text mode: only str is written to this file, and 'wb' would make
    # every write raise TypeError on Python 3. The with-block guarantees the
    # log is closed even if training is interrupted.
    with open(RESULT_PATH + "pt_osm_results.txt", 'w') as out_file:
        # Set initial learning rate and basic SGD hyper parameters
        costs = [0. for _ in range(10)]
        learn_rate = 0.002
        carry_idx = None
        for i in range(200000):
            # warm-up factor applied to learning rate, momentum and lam_kld
            scale = min(1.0, float(i) / 5000.0)
            if ((i > 1) and ((i % 20000) == 0)):
                learn_rate = learn_rate * 0.8
            # NOTE(review): the original schedule had an extra branch
            # "elif (i < 10000): momentum = 0.7" placed after the
            # "i < 50000" test, so it could never run. The effective
            # schedule (0.5, then 0.9) is kept here; confirm whether a 0.7
            # stage with a different threshold was intended.
            momentum = 0.5 if (i < 50000) else 0.9
            if ((i == 0) or (npr.rand() < reset_prob)):
                # sample a fully random batch
                batch_idx = npr.randint(low=0, high=tr_samples,
                                        size=(batch_size,))
            else:
                # sample a partially random batch, which retains some
                # portion of the worst scoring examples from the previous
                # batch
                fresh_idx = npr.randint(low=0, high=tr_samples,
                                        size=(batch_size - carry_size,))
                batch_idx = np.concatenate((fresh_idx.ravel(),
                                            carry_idx.ravel()))
            # Train on the batch that was just assembled. Drawing a second,
            # unrelated index set here would make the carried-over rows
            # unused and rank carry_idx with costs of different examples.
            Xd_batch = Xtr.take(batch_idx, axis=0)
            Xc_batch = 0.0 * Xd_batch
            Xm_batch = 0.0 * Xd_batch
            # do a minibatch update of the model, and compute some costs
            OSM.set_sgd_params(lr_1=(scale*learn_rate),
                               mom_1=(scale*momentum), mom_2=0.98)
            OSM.set_lam_nll(1.0)
            OSM.set_lam_kld(lam_kld_1=scale*lam_kld, lam_kld_2=0.0,
                            lam_kld_c=50.0)
            result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
            batch_costs = result[4] + result[5]
            # presumably one cost per row of Xd_batch -- verify against
            # collect_obs_costs
            obs_costs = collect_obs_costs(batch_costs, batch_reps)
            # remember the highest-cost examples for the next batch
            carry_idx = batch_idx[np.argsort(-obs_costs)[0:carry_size]]
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 1000) == 0):
                # record and then reset the cost trackers
                costs = [(v / 1000.0) for v in costs]
                str_1 = "-- batch {0:d} --".format(i)
                str_2 = " joint_cost: {0:.4f}".format(costs[0])
                str_3 = " nll_cost : {0:.4f}".format(costs[1])
                str_4 = " kld_cost : {0:.4f}".format(costs[2])
                str_5 = " reg_cost : {0:.4f}".format(costs[3])
                costs = [0.0 for v in costs]
                # print out some diagnostic information
                joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
                print(joint_str)
                out_file.write(joint_str+"\n")
                out_file.flush()
            if ((i % 2000) == 0):
                Xva = row_shuffle(Xva)
                model_samps = OSM.sample_from_prior(500)
                file_name = RESULT_PATH+"pt_osm_samples_b{0:d}_XG.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=20)
                file_name = RESULT_PATH+"pt_osm_inf_weights_b{0:d}.png".format(i)
                utils.visualize_samples(
                    OSM.inf_weights.get_value(borrow=False).T,
                    file_name, num_rows=30)
                file_name = RESULT_PATH+"pt_osm_gen_weights_b{0:d}.png".format(i)
                utils.visualize_samples(
                    OSM.gen_weights.get_value(borrow=False),
                    file_name, num_rows=30)
                # compute information about free-energy on validation set
                file_name = RESULT_PATH+"pt_osm_free_energy_b{0:d}.png".format(i)
                fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                fe_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(fe_str)
                out_file.write(fe_str+"\n")
                utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                                   x_label='Posterior KLd',
                                   y_label='Negative Log-likelihood')
                # compute information about posterior KLds on validation set
                file_name = RESULT_PATH+"pt_osm_post_klds_b{0:d}.png".format(i)
                post_klds = OSM.compute_post_klds(Xva[0:2500])
                post_dim_klds = np.mean(post_klds, axis=0)
                utils.plot_stem(np.arange(post_dim_klds.shape[0]),
                                post_dim_klds, file_name)
            if ((i % 5000) == 0):
                # periodic, iteration-stamped parameter snapshots
                IN.save_to_file(f_name=RESULT_PATH+"pt_osm_params_b{0:d}_IN.pkl".format(i))
                GN.save_to_file(f_name=RESULT_PATH+"pt_osm_params_b{0:d}_GN.pkl".format(i))
    # final parameters, saved under un-stamped names
    IN.save_to_file(f_name=RESULT_PATH+"pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH+"pt_osm_params_GN.pkl")
    return
class VCGLoop(object):
    """
    Controller for training a self-looping VAE using guidance provided by a
    classifier. The classifier tries to discriminate between samples from
    the training data and samples generated by the looped VAE, while the VAE
    minimizes a variational generative model objective and also shifts mass
    away from regions where the classifier can discern that the generated
    data is denser than the training data.

    This class can also train "policies" for reconstructing partially masked
    inputs. A reconstruction policy can readily be trained to share the same
    parameters as a policy for generating transitions while self-looping.

    The generator must be an instance of the InfNet class implemented in
    "InfNet.py". The discriminator must be an instance of the PeaNet class,
    as implemented in "PeaNet.py". The inferencer must be an instance of the
    InfNet class implemented in "InfNet.py".

    Parameters:
        rng: numpy.random.RandomState (for reproducibility)
        Xd: symbolic var for providing points for starting the Markov Chain
        Xc: symbolic var for controlling subsets of the state variables
        Xm: symbolic var for providing masks to mix Xd with Xc
        Xt: symbolic var for providing samples from the target distribution
        i_net: The InfNet instance that will serve as the inferencer
        g_net: The InfNet instance that will serve as the generator
        d_net: The PeaNet instance that will serve as the discriminator
        chain_len: number of steps to unroll the VAE Markov Chain
        data_dim: dimension of the generated data
        prior_dim: dimension of the model prior
        params: a dict of parameters for controlling various costs
            x_type: can be "bernoulli" or "gaussian" (required)
            xt_transform: optional transform for gaussian means; "sigmoid"
                          (default) or "none"
            logvar_bound: optional bound on gaussian output logvar
                          (default 10)
            cost_decay: rate of decay for VAE costs in unrolled chain
                        (default 0.1)
            chain_type: can be 'walkout' (default) or 'walkback'
            lam_l2d: regularization on squared discriminator output
                     (required)
    """
    def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None,
                 i_net=None, g_net=None, d_net=None, chain_len=None,
                 data_dim=None, prior_dim=None, params=None):
        """
        Build the unrolled chain, the discriminator, all costs, gradients
        and updates, then compile the joint training function.

        See the class docstring for the meaning of each argument. Raises
        AssertionError for an unrecognized 'chain_type', 'xt_transform' or
        'x_type', and KeyError when 'x_type' or 'lam_l2d' is missing.
        """
        self.rng = RandStream(rng.randint(100000))
        self.data_dim = data_dim
        self.prior_dim = prior_dim
        self.prior_mean = 0.0
        self.prior_logvar = 0.0
        self.params = {} if params is None else params

        # optional settings, with their defaults
        self.cost_decay = self.params.get('cost_decay', 0.1)
        self.chain_type = self.params.get('chain_type', 'walkout')
        assert self.chain_type in ('walkback', 'walkout')
        xt_transform_name = self.params.get('xt_transform', 'sigmoid')
        assert xt_transform_name in ('sigmoid', 'none')
        if xt_transform_name == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x
        self.logvar_bound = self.params.get('logvar_bound', 10)
        #
        # x_type: this tells if we're using bernoulli or gaussian model for
        #         the observations
        #
        self.x_type = self.params['x_type']
        assert self.x_type in ('bernoulli', 'gaussian')

        # symbolic var for inputting samples for initializing the VAE chain
        self.Xd = Xd
        # symbolic var for masking subsets of the state variables
        self.Xm = Xm
        # symbolic var for controlling subsets of the state variables
        self.Xc = Xc
        # symbolic var for inputting samples from the target distribution
        self.Xt = Xt
        # integer number of times to cycle the VAE loop
        self.chain_len = chain_len

        # The discriminator input is Xt stacked on top of every chain
        # output, so rows [0, |Xt|) are data and the following
        # chain_len * |Xd| rows are generated samples.
        self.It = T.arange(self.Xt.shape[0])
        self.Id = T.arange(self.chain_len * self.Xd.shape[0]) + self.Xt.shape[0]

        # get a clone of the desired VAE, for easy access
        self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm,
                                 p_x_given_z=g_net, q_z_given_x=i_net,
                                 x_dim=self.data_dim, z_dim=self.prior_dim,
                                 params=self.params)
        self.IN = self.OSM.q_z_given_x
        self.GN = self.OSM.p_x_given_z
        self.transform_x_to_z = self.OSM.transform_x_to_z
        self.transform_z_to_x = self.OSM.transform_z_to_x
        self.bounded_logvar = self.OSM.bounded_logvar

        # self-loop some clones of the main VAE into a chain.
        # ** All VAEs in the chain share the same Xc and Xm, which are the
        #    symbolic inputs for providing the observed portion of the input
        #    and a mask indicating which part of the input is "observed".
        #    These inputs are used for training "reconstruction" policies.
        self.IN_chain = []
        self.GN_chain = []
        self.Xg_chain = []
        _Xd = self.Xd
        print("Unrolling chain...")
        for i in range(self.chain_len):
            # create a VAE infer/generate pair with _Xd as input and with
            # masking variables shared by all VAEs in this chain
            _IN = self.IN.shared_param_clone(
                rng=rng, Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm),
                build_funcs=False)
            _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output,
                                             build_funcs=False)
            # the (transformed) generator mean feeds the next chain step
            _Xd = self.xt_transform(_GN.output_mean)
            self.IN_chain.append(_IN)
            self.GN_chain.append(_GN)
            self.Xg_chain.append(_Xd)
            print(" step {}...".format(i))

        # make a clone of the desired discriminator network, which will try
        # to discriminate between samples from the training data and samples
        # generated by the self-looped VAE chain.
        self.DN = d_net.shared_param_clone(
            rng=rng, Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

        zero_ary = np.zeros((1,)).astype(theano.config.floatX)
        # init shared var for weighting nll of data given posterior sample
        self.lam_chain_nll = theano.shared(value=zero_ary, name='vcg_lam_chain_nll')
        self.set_lam_chain_nll(lam_chain_nll=1.0)
        # init shared var for weighting posterior KL-div from prior
        self.lam_chain_kld = theano.shared(value=zero_ary, name='vcg_lam_chain_kld')
        self.set_lam_chain_kld(lam_chain_kld=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
        self.set_lam_l2w(lam_l2w=1e-4)
        # shared var learning rates for all networks
        self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
        self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
        self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
        # shared var momentum parameters for all networks
        self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
        self.it_count = theano.shared(value=zero_ary, name='vcg_it_count')
        # shared var weights for adversarial classification objective
        self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
        self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
        # init parameters for controlling learning dynamics
        self.set_all_sgd_params()
        self.set_disc_weights()  # init adversarial cost weights for GN/DN
        # set a shared var for regularizing the output of the discriminator.
        # Read from self.params (the normalized dict) like every other
        # option.
        self.lam_l2d = theano.shared(
            value=(zero_ary + self.params['lam_l2d']), name='vcg_lam_l2d')

        # Grab the full set of "optimizable" parameters from the generator
        # and discriminator networks that we'll be working with. We need to
        # ignore parameters in the final layers of the proto-networks in the
        # discriminator network (a generalized pseudo-ensemble). We ignore
        # them because they are "bypassed" in favor of some binary
        # classification layers that are managed by this VCGLoop.
        self.dn_params = []
        for pn in self.DN.proto_nets:
            for pnl in pn[0:-1]:
                self.dn_params.extend(pnl.params)
        self.in_params = [p for p in self.IN.mlp_params]
        self.in_params.append(self.OSM.output_logvar)
        self.gn_params = [p for p in self.GN.mlp_params]

        # Now construct a binary discriminator layer for each proto-net in
        # the discriminator network. And, add their params to optimization
        # list.
        self._construct_disc_layers(rng)
        # Collect the joint list only now: _construct_disc_layers extends
        # self.dn_params, so building the list earlier would leave the
        # disc-layer parameters out of it.
        self.joint_params = self.in_params + self.gn_params + self.dn_params
        self.disc_reg_cost = self.lam_l2d[0] * \
            T.sum([dl.act_l2_sum for dl in self.disc_layers])

        # Construct costs for the generator and discriminator networks based
        # on adversarial binary classification
        self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

        # first, build the cost to be optimized by the discriminator
        # network, in general this will be treated somewhat independently of
        # the optimization of the generator and inferencer networks.
        self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
            self.disc_reg_cost

        # construct costs relevant to the optimization of the generator and
        # inferencer networks
        self.chain_nll_cost = self.lam_chain_nll[0] * \
            self._construct_chain_nll_cost(cost_decay=self.cost_decay)
        self.chain_kld_cost = self.lam_chain_kld[0] * \
            self._construct_chain_kld_cost(cost_decay=self.cost_decay)
        self.other_reg_cost = self._construct_other_reg_cost()
        self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
            self.chain_kld_cost + self.other_reg_cost
        # compute total cost on the discriminator and VB generator/inferencer
        self.joint_cost = self.dn_cost + self.osm_cost

        # Get the gradients for all optimizable parameters: the
        # discriminator follows dn_cost, the inferencer and generator follow
        # osm_cost.
        self.joint_grads = OrderedDict()
        print("Computing VCGLoop DN cost gradients...")
        grad_list = T.grad(self.dn_cost, self.dn_params, disconnected_inputs='warn')
        for p, g in zip(self.dn_params, grad_list):
            self.joint_grads[p] = g
        print("Computing VCGLoop IN cost gradients...")
        grad_list = T.grad(self.osm_cost, self.in_params, disconnected_inputs='warn')
        for p, g in zip(self.in_params, grad_list):
            self.joint_grads[p] = g
        print("Computing VCGLoop GN cost gradients...")
        grad_list = T.grad(self.osm_cost, self.gn_params, disconnected_inputs='warn')
        for p, g in zip(self.gn_params, grad_list):
            self.joint_grads[p] = g

        # construct the updates for the discriminator, generator and
        # inferencer networks. all networks share the same first/second
        # moment momentum and iteration count. the networks each have their
        # own learning rates, which lets you turn their learning on/off.
        self.dn_updates = get_param_updates(
            params=self.dn_params, grads=self.joint_grads, alpha=self.lr_dn,
            beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count,
            mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        self.gn_updates = get_param_updates(
            params=self.gn_params, grads=self.joint_grads, alpha=self.lr_gn,
            beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count,
            mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
        self.in_updates = get_param_updates(
            params=self.in_params, grads=self.joint_grads, alpha=self.lr_in,
            beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count,
            mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)

        # bag up all the updates required for training (later entries win
        # on duplicate keys, in the order DN, GN, IN)
        self.joint_updates = OrderedDict()
        for updates in (self.dn_updates, self.gn_updates, self.in_updates):
            for k in updates:
                self.joint_updates[k] = updates[k]
        # construct an update for tracking the mean KL divergence of
        # approximate posteriors for this chain (exponential moving average)
        new_kld_mean = (0.98 * self.IN.kld_mean) + ((0.02 / self.chain_len) *
            sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
        self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

        # construct the function for training on training data
        print("Compiling VCGLoop theano functions....")
        self.train_joint = self._construct_train_joint()
        return

    @staticmethod
    def _as_shared_value(value):
        """
        Wrap a scalar as a length-1 floatX array, the layout used by every
        scalar shared variable in this class. The cast happens after the
        addition so the result is floatX whatever dtype `value` carries.
        """
        return (np.zeros((1,)) + value).astype(theano.config.floatX)

    def set_dn_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the discriminator network.
        """
        self.lr_dn.set_value(self._as_shared_value(learn_rate))
        return

    def set_in_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the inferencer network.
        """
        self.lr_in.set_value(self._as_shared_value(learn_rate))
        return

    def set_gn_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the generator network.
        """
        self.lr_gn.set_value(self._as_shared_value(learn_rate))
        return

    def set_all_sgd_params(self, learn_rate=0.01, mom_1=0.9, mom_2=0.999):
        """
        Set learning rate and momentum parameter for all updates.
        """
        # set learning rates to the same value
        for lr_var in (self.lr_dn, self.lr_gn, self.lr_in):
            lr_var.set_value(self._as_shared_value(learn_rate))
        # set the first/second moment momentum parameters
        self.mom_1.set_value(self._as_shared_value(mom_1))
        self.mom_2.set_value(self._as_shared_value(mom_2))
        return

    def set_disc_weights(self, dweight_gn=1.0, dweight_dn=1.0):
        """
        Set weights for the adversarial classification cost.
        """
        self.dw_dn.set_value(self._as_shared_value(dweight_dn))
        self.dw_gn.set_value(self._as_shared_value(dweight_gn))
        return

    def set_lam_chain_nll(self, lam_chain_nll=1.0):
        """
        Set weight for controlling the influence of the data likelihood.
        """
        self.lam_chain_nll.set_value(self._as_shared_value(lam_chain_nll))
        return

    def set_lam_chain_kld(self, lam_chain_kld=1.0):
        """
        Set the strength of regularization on KL-divergence for continuous
        posterior variables. When set to 1.0, this reproduces the standard
        role of KL(posterior || prior) in variational learning.
        """
        self.lam_chain_kld.set_value(self._as_shared_value(lam_chain_kld))
        return

    def set_lam_l2w(self, lam_l2w=1e-3):
        """
        Set the relative strength of l2 regularization on network params.
        """
        self.lam_l2w.set_value(self._as_shared_value(lam_l2w))
        return

    def _construct_disc_layers(self, rng):
        """
        Construct binary discrimination layers for each spawn-net in the
        underlying discriminator pseudo-ensemble. All spawn-nets spawned
        from the same proto-net will use the same disc-layer parameters.

        Fills self.disc_layers and self.disc_outputs, and appends each new
        layer's params to self.dn_params.
        """
        self.disc_layers = []
        self.disc_outputs = []
        dn_init_scale = self.DN.init_scale
        for sn in self.DN.spawn_nets:
            # construct a "binary discriminator" layer to sit on top of each
            # spawn net in the discriminator pseudo-ensemble; it reads the
            # input of the spawn net's final layer, bypassing that layer
            sn_fl = sn[-1]
            # NOTE(review): the original also computed
            # dn_init_scale / sqrt(sn_fl.in_dim) here but never used it;
            # W_scale has always received the unscaled dn_init_scale.
            # Confirm whether the fan-in scaled value was intended.
            self.disc_layers.append(DiscLayer(rng=rng,
                                              input=sn_fl.noisy_input,
                                              in_dim=sn_fl.in_dim,
                                              W_scale=dn_init_scale))
            # capture the (linear) output of the DiscLayer, for possible reuse
            self.disc_outputs.append(self.disc_layers[-1].linear_output)
            # get the params of this DiscLayer, for convenient optimization
            self.dn_params.extend(self.disc_layers[-1].params)
        return

    def _construct_disc_costs(self):
        """
        Construct the generator and discriminator adversarial costs.

        Returns [dn_cost, gn_cost]: per-disc-layer costs summed over all
        layers and weighted by self.dw_dn / self.dw_gn respectively.
        """
        gn_costs = []
        dn_costs = []
        for dl_output in self.disc_outputs:
            # split the predictions into target-data rows and generated rows
            data_preds = dl_output.take(self.It, axis=0)
            noise_preds = dl_output.take(self.Id, axis=0)
            data_size = T.cast(self.It.size, 'floatX')
            noise_size = T.cast(self.Id.size, 'floatX')
            # compute the cost with respect to which we will be optimizing
            # the parameters of the discriminator network
            dnl_dn_cost = (logreg_loss(data_preds, 1.0) / data_size) + \
                (logreg_loss(noise_preds, -1.0) / noise_size)
            # compute the cost with respect to which we will be optimizing
            # the parameters of the generative model
            dnl_gn_cost = (hinge_loss(noise_preds, 0.0) +
                           hinge_sq_loss(noise_preds, 0.0)) / (2.0 * noise_size)
            dn_costs.append(dnl_dn_cost)
            gn_costs.append(dnl_gn_cost)
        dn_cost = self.dw_dn[0] * T.sum(dn_costs)
        gn_cost = self.dw_gn[0] * T.sum(gn_costs)
        return [dn_cost, gn_cost]

    def _log_prob_wrapper(self, x_true, x_apprx):
        """
        Wrap log-prob with switching for bernoulli/gaussian output types.

        Returns the negated log-probability (i.e. an NLL cost).
        """
        if self.x_type == 'bernoulli':
            ll_cost = log_prob_bernoulli(x_true, x_apprx)
        else:
            ll_cost = log_prob_gaussian2(x_true, x_apprx,
                                         log_vars=self.bounded_logvar)
        nll_cost = -ll_cost
        return nll_cost

    def _construct_chain_nll_cost(self, cost_decay=0.1):
        """
        Construct the negative log-likelihood part of cost to minimize.

        This is for operation in "free chain" mode, where a seed point is
        used to initialize a long(ish) running markov chain. Step i of the
        chain is weighted by cost_decay**i and the weighted sum is divided
        by the sum of the weights.
        """
        assert((cost_decay >= 0.0) and (cost_decay <= 1.0))
        nll_costs = []
        step_weights = []
        step_weight = 1.0
        for i in range(self.chain_len):
            if self.chain_type == 'walkback':
                # train with walkback roll-outs -- reconstruct initial point
                IN_i = self.IN_chain[0]
            else:
                # train with walkout roll-outs -- reconstruct previous point
                IN_i = self.IN_chain[i]
            x_true = IN_i.Xd
            x_apprx = self.Xg_chain[i]
            c = T.mean(self._log_prob_wrapper(x_true, x_apprx))
            nll_costs.append(step_weight * c)
            step_weights.append(step_weight)
            step_weight = step_weight * cost_decay
        nll_cost = sum(nll_costs) / sum(step_weights)
        return nll_cost

    def _construct_chain_kld_cost(self, cost_decay=0.1):
        """
        Construct the posterior KL-d from prior part of cost to minimize.

        This is for operation in "free chain" mode, where a seed point is
        used to initialize a long(ish) running markov chain. Step i of the
        chain is weighted by cost_decay**i and the weighted sum is divided
        by the sum of the weights.
        """
        assert((cost_decay >= 0.0) and (cost_decay <= 1.0))
        kld_costs = []
        step_weights = []
        step_weight = 1.0
        for IN_i in self.IN_chain[:self.chain_len]:
            # basic variational term on KL divergence between post and prior
            kld_i = gaussian_kld(IN_i.output_mean, IN_i.output_logvar,
                                 self.prior_mean, self.prior_logvar)
            kld_i_costs = T.sum(kld_i, axis=1)
            # sum and reweight the KLd cost for this step in the chain
            c = T.mean(kld_i_costs)
            kld_costs.append(step_weight * c)
            step_weights.append(step_weight)
            step_weight = step_weight * cost_decay
        kld_cost = sum(kld_costs) / sum(step_weights)
        return kld_cost

    def _construct_other_reg_cost(self):
        """
        Construct the cost for low-level basic regularization. E.g. for
        applying l2 regularization to the network parameters.

        Only the generator and inferencer params are penalized here.
        """
        gp_cost = sum([T.sum(par**2.0) for par in self.gn_params])
        ip_cost = sum([T.sum(par**2.0) for par in self.in_params])
        other_reg_cost = self.lam_l2w[0] * (gp_cost + ip_cost)
        return other_reg_cost

    def _construct_train_joint(self):
        """
        Construct theano function to train generator and discriminator
        jointly.

        The compiled function takes (xd, xc, xm, xt, batch_reps); xd, xc and
        xm are each repeated batch_reps times along axis 0 before being
        substituted for self.Xd, self.Xc and self.Xm, while xt is used as-is
        for self.Xt. Each call applies self.joint_updates.
        """
        # symbolic vars for passing input to training function
        xd = T.matrix()
        xc = T.matrix()
        xm = T.matrix()
        xt = T.matrix()
        batch_reps = T.lscalar()
        # collect outputs to return to caller. The chain nll/kld costs
        # appear twice (positions 1-2 and 3-4); the layout is kept as-is
        # because callers may index the result by position.
        outputs = [self.joint_cost, self.chain_nll_cost, self.chain_kld_cost,
                   self.chain_nll_cost, self.chain_kld_cost,
                   self.disc_cost_gn, self.disc_cost_dn, self.other_reg_cost]
        func = theano.function(
            inputs=[xd, xc, xm, xt, batch_reps],
            outputs=outputs, updates=self.joint_updates,
            givens={self.Xd: xd.repeat(batch_reps, axis=0),
                    self.Xc: xc.repeat(batch_reps, axis=0),
                    self.Xm: xm.repeat(batch_reps, axis=0),
                    self.Xt: xt})
        return func

    def sample_from_chain(self, X_d, X_c=None, X_m=None, loop_iters=5,
                          sigma_scale=None):
        """
        Sample for several rounds through the I<->G loop, initialized with
        the "data variable" samples in X_d.

        Delegates directly to the underlying OneStageModel.
        """
        result = self.OSM.sample_from_chain(X_d, X_c=X_c, X_m=X_m,
                                            loop_iters=loop_iters,
                                            sigma_scale=sigma_scale)
        return result

    def sample_from_prior(self, samp_count):
        """
        Draw independent samples from the model's prior, by delegating to
        the underlying OneStageModel.
        """
        Xs = self.OSM.sample_from_prior(samp_count)
        return Xs
def pretrain_osm(lam_kld=0.0):
    """
    Pretrain a OneStageModel (inferencer/generator pair) on the TFD 48x48
    data and periodically dump diagnostics and parameters under RESULT_PATH.

    Training data is the 'unlabeled' and 'train' TFD sets stacked together;
    the 'valid' set is used for the free-energy and posterior-KLd plots.

    Minibatches use a "hard example carry-over" scheme: unless the batch is
    reset (first iteration, or with probability reset_prob), carry_size of
    its rows are the highest-cost examples of the previous batch and the rest
    are drawn uniformly at random.

    Parameters:
        lam_kld: target weight for the KLd term; it is ramped in linearly
                 (together with learning rate and momentum) over the first
                 5000 updates
    Returns:
        None
    """
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    data_file = 'data/tfd_data_48x48.pkl'
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='unlabeled', fold='all')
    Xtr_unlabeled = dataset[0]
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='train', fold='all')
    Xtr_train = dataset[0]
    Xtr = np.vstack([Xtr_unlabeled, Xtr_train])
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='valid', fold='all')
    Xva = dataset[0]
    tr_samples = Xtr.shape[0]
    batch_size = 400
    batch_reps = 6
    carry_frac = 0.25
    carry_size = int(batch_size * carry_frac)
    reset_prob = 0.04

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0
    Xtr_mean = np.mean(Xtr, axis=0)

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    # generator: PRIOR_DIM -> 1500 -> 1500 -> data_dim
    gn_shared = [PRIOR_DIM, 1500, 1500]
    gn_top = [gn_shared[-1], data_dim]
    gn_params = {
        'shared_config': gn_shared,
        'mu_config': gn_top,
        'sigma_config': gn_top,
        'activation': relu_actfun,
        'init_scale': 1.4,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
    }
    # continuous inferencer: data_dim -> 1500 -> 1500 -> PRIOR_DIM
    in_shared = [data_dim, 1500, 1500]
    in_top = [in_shared[-1], PRIOR_DIM]
    in_params = {
        'shared_config': in_shared,
        'mu_config': in_top,
        'sigma_config': in_top,
        'activation': relu_actfun,
        'init_scale': 1.4,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
    }
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)
    # To restart from saved parameters instead, build IN/GN with
    # load_infnet_from_file(f_name=..., rng=rng, Xd=Xd, new_params=None)
    # and take in_params/gn_params from IN.params/GN.params.

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {
        'x_type': 'bernoulli',
        'xt_transform': 'sigmoid',
        'logvar_bound': LOGVAR_BOUND,
    }
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    # squash the data mean into [0.05, 0.95] so its logit stays finite
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    costs = [0. for i in range(10)]
    learn_rate = 0.002
    # The log receives text, so it is opened in text mode, and the context
    # manager guarantees it is closed even if training is interrupted.
    with open(RESULT_PATH + "pt_osm_results.txt", 'w') as out_file:
        for i in range(200000):
            scale = min(1.0, float(i) / 5000.0)
            if ((i > 1) and ((i % 20000) == 0)):
                learn_rate = learn_rate * 0.8
            # NOTE(review): the original schedule also had an
            # `elif (i < 10000)` branch (momentum 0.7) that could never fire
            # because it followed `i < 50000`. The effective schedule is kept
            # here; confirm the intended thresholds before changing it.
            momentum = 0.5 if (i < 50000) else 0.9
            if ((i == 0) or (npr.rand() < reset_prob)):
                # sample a fully random batch
                batch_idx = npr.randint(low=0, high=tr_samples,
                                        size=(batch_size,))
            else:
                # sample a partially random batch, which retains some portion
                # of the worst scoring examples from the previous batch
                fresh_idx = npr.randint(low=0, high=tr_samples,
                                        size=(batch_size - carry_size,))
                batch_idx = np.concatenate((fresh_idx.ravel(),
                                            carry_idx.ravel()))
            # Train on the batch selected above. The per-example costs
            # computed below are indexed by position in this batch, so
            # carry_idx is only meaningful if batch_idx is what was trained on.
            Xd_batch = Xtr.take(batch_idx, axis=0)
            Xc_batch = 0.0 * Xd_batch
            Xm_batch = 0.0 * Xd_batch
            # do a minibatch update of the model, and compute some costs
            OSM.set_sgd_params(lr_1=(scale * learn_rate),
                               mom_1=(scale * momentum), mom_2=0.98)
            OSM.set_lam_nll(1.0)
            OSM.set_lam_kld(lam_kld_1=scale * lam_kld, lam_kld_2=0.0,
                            lam_kld_c=50.0)
            result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
            batch_costs = result[4] + result[5]
            obs_costs = collect_obs_costs(batch_costs, batch_reps)
            # remember the highest-cost examples for the next batch
            carry_idx = batch_idx[np.argsort(-obs_costs)[0:carry_size]]
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 1000) == 0):
                # record and then reset the cost trackers
                costs = [(v / 1000.0) for v in costs]
                str_1 = "-- batch {0:d} --".format(i)
                str_2 = " joint_cost: {0:.4f}".format(costs[0])
                str_3 = " nll_cost : {0:.4f}".format(costs[1])
                str_4 = " kld_cost : {0:.4f}".format(costs[2])
                str_5 = " reg_cost : {0:.4f}".format(costs[3])
                costs = [0.0 for v in costs]
                # print out some diagnostic information
                joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
            if ((i % 2000) == 0):
                Xva = row_shuffle(Xva)
                model_samps = OSM.sample_from_prior(500)
                file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=20)
                file_name = RESULT_PATH + "pt_osm_inf_weights_b{0:d}.png".format(i)
                utils.visualize_samples(
                    OSM.inf_weights.get_value(borrow=False).T,
                    file_name, num_rows=30)
                file_name = RESULT_PATH + "pt_osm_gen_weights_b{0:d}.png".format(i)
                utils.visualize_samples(
                    OSM.gen_weights.get_value(borrow=False),
                    file_name, num_rows=30)
                # compute information about free-energy on validation set
                file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
                fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                fe_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(fe_str)
                out_file.write(fe_str + "\n")
                utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                                   x_label='Posterior KLd',
                                   y_label='Negative Log-likelihood')
                # compute information about posterior KLds on validation set
                file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
                post_klds = OSM.compute_post_klds(Xva[0:2500])
                post_dim_klds = np.mean(post_klds, axis=0)
                utils.plot_stem(np.arange(post_dim_klds.shape[0]),
                                post_dim_klds, file_name)
            if ((i % 5000) == 0):
                IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_IN.pkl".format(i))
                GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_GN.pkl".format(i))
    # final snapshot of the trained networks
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
def pretrain_osm(lam_kld=0.0):
    """
    Pretrain a OneStageModel (inferencer/generator pair) on MNIST and
    periodically dump diagnostics and parameters under RESULT_PATH.

    Parameters:
        lam_kld: final weight for the KLd term; the weight is interpolated
                 linearly from 1.0 to lam_kld over the first 10000 updates,
                 while the learning rate is ramped from 0 over the same span
    Returns:
        None
    """
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    dataset = 'data/mnist.pkl.gz'
    datasets = load_udm(dataset, zero_mean=False)
    Xtr = datasets[0][0]
    Xtr = Xtr.get_value(borrow=False)
    Xva = datasets[2][0]
    Xva = Xva.get_value(borrow=False)
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape), str(Xva.shape)))

    # get and set some basic dataset information
    Xtr_mean = np.mean(Xtr, axis=0)
    tr_samples = Xtr.shape[0]
    batch_size = 100
    batch_reps = 5

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    # generator: PRIOR_DIM -> 1000 -> 1000 -> data_dim
    gn_shared = [PRIOR_DIM, 1000, 1000]
    gn_top = [gn_shared[-1], data_dim]
    gn_params = {
        'shared_config': gn_shared,
        'mu_config': gn_top,
        'sigma_config': gn_top,
        'activation': relu_actfun,
        'init_scale': 1.4,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
    }
    # continuous inferencer: data_dim -> 1000 -> 1000 -> PRIOR_DIM
    in_shared = [data_dim, 1000, 1000]
    in_top = [in_shared[-1], PRIOR_DIM]
    in_params = {
        'shared_config': in_shared,
        'mu_config': in_top,
        'sigma_config': in_top,
        'activation': relu_actfun,
        'init_scale': 1.4,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
    }
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {
        'x_type': 'bernoulli',
        'xt_transform': 'sigmoid',
        'logvar_bound': LOGVAR_BOUND,
    }
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    # squash the data mean into [0.05, 0.95] so its logit stays finite
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    # The log receives text, so it is opened in text mode, and the context
    # manager guarantees it is closed even if training is interrupted.
    with open(RESULT_PATH + "pt_osm_results.txt", 'w') as out_file:
        for i in range(150000):
            scale = min(1.0, float(i) / 10000.0)
            if ((i > 1) and ((i % 20000) == 0)):
                learn_rate = learn_rate * 0.9
            # draw a uniformly random minibatch; no control/mask inputs
            tr_idx = npr.randint(low=0, high=tr_samples, size=(batch_size,))
            Xd_batch = Xtr.take(tr_idx, axis=0)
            Xc_batch = 0.0 * Xd_batch
            Xm_batch = 0.0 * Xd_batch
            # do a minibatch update of the model, and compute some costs
            OSM.set_sgd_params(lr_1=(scale * learn_rate), mom_1=0.5, mom_2=0.98)
            OSM.set_lam_nll(1.0)
            OSM.set_lam_kld(lam_kld_1=(1.0 + (scale * (lam_kld - 1.0))),
                            lam_kld_2=0.0)
            result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 1000) == 0):
                # record and then reset the cost trackers
                costs = [(v / 1000.0) for v in costs]
                str_1 = "-- batch {0:d} --".format(i)
                str_2 = " joint_cost: {0:.4f}".format(costs[0])
                str_3 = " nll_cost : {0:.4f}".format(costs[1])
                str_4 = " kld_cost : {0:.4f}".format(costs[2])
                str_5 = " reg_cost : {0:.4f}".format(costs[3])
                costs = [0.0 for v in costs]
                # print out some diagnostic information
                joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
            if ((i % 2000) == 0):
                Xva = row_shuffle(Xva)
                model_samps = OSM.sample_from_prior(500)
                file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=20)
                # compute information about free-energy on validation set
                file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
                fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                fe_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(fe_str)
                out_file.write(fe_str + "\n")
                utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                                   x_label='Posterior KLd',
                                   y_label='Negative Log-likelihood')
                # compute information about posterior KLds on validation set
                file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
                post_klds = OSM.compute_post_klds(Xva[0:2500])
                post_dim_klds = np.mean(post_klds, axis=0)
                utils.plot_stem(np.arange(post_dim_klds.shape[0]),
                                post_dim_klds, file_name)
            if ((i % 5000) == 0):
                IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_IN.pkl".format(i))
                GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_GN.pkl".format(i))
    # final snapshot of the trained networks
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
def test_gip_sigma_scale_tfd():
    """
    Smoke test: load a saved inferencer/generator pair, wrap them in a
    OneStageModel, and write a free-energy scatter plot, five chain-sample
    videos and one prior-sample image for the TFD data.
    """
    from LogPDFs import cross_validate_sigma
    # Simple test code, to check that everything is basically functional.
    print("TESTING...")

    # Initialize a source of randomness
    rng = np.random.RandomState(12345)

    # Load some data to train/validate/test with
    data_file = "data/tfd_data_48x48.pkl"

    def _first_of(which_set):
        # load_tfd returns a sequence; only its first entry is used here
        return load_tfd(tfd_pkl_name=data_file, which_set=which_set, fold="all")[0]

    Xtr = np.vstack([_first_of("unlabeled"), _first_of("train")])
    Xva = _first_of("test")
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape), str(Xva.shape)))
    tr_samples = Xtr.shape[0]

    # Symbolic inputs
    Xd = T.matrix(name="Xd")
    Xc = T.matrix(name="Xc")
    Xm = T.matrix(name="Xm")
    Xt = T.matrix(name="Xt")

    # Load inferencer and generator from saved parameters
    in_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_IN.pkl"
    gn_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_GN.pkl"
    IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd)
    GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd)
    # model dimensions are read back from the loaded inferencer
    x_dim = IN.shared_layers[0].in_dim
    z_dim = IN.mu_layers[-1].out_dim

    # construct a GIPair with the loaded InfNet and GenNet
    osm_params = {
        "x_type": "gaussian",
        "xt_transform": "sigmoid",
        "logvar_bound": LOGVAR_BOUND,
    }
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=x_dim, z_dim=z_dim, params=osm_params)

    # free-energy terms on a shuffled 5000-row slice of the held-out set
    Xva = row_shuffle(Xva)
    Xb = Xva[0:5000]
    fe_terms = OSM.compute_fe_terms(Xb, 20)
    utils.plot_scatter(fe_terms[1], fe_terms[0], "A_TFD_KLD_FREE_ENERGY.png",
                       x_label="Posterior KLd",
                       y_label="Negative Log-likelihood")
    # (The source also carried disabled diagnostics here: posterior-KLd stem
    # plots, compute_ll_bound summaries and compute_post_stats plots.)

    # draw many samples from the GIP: five videos, each a 3x3 grid of chains
    for i in range(5):
        tr_idx = npr.randint(low=0, high=tr_samples, size=(100,))
        Xd_batch = Xtr.take(tr_idx, axis=0)
        seeds = Xd_batch[0:10, :]
        # rows are built outermost, so chains are sampled in row-major order
        grid = [
            [group_chains(OSM.sample_from_chain(
                seeds, loop_iters=100, sigma_scale=1.0)["data samples"])
             for _col in range(3)]
            for _row in range(3)
        ]
        frames, block_im_dim = block_video(grid, (48, 48), (3, 3))
        to_video(frames, block_im_dim,
                 "A_TFD_KLD_CHAIN_VIDEO_{0:d}.avi".format(i), frame_rate=10)

    # NOTE(review): nesting was not recoverable from the collapsed source;
    # this step is placed after the loop because its file name does not
    # depend on i.
    prior_samps = OSM.sample_from_prior(20 * 20)
    utils.visualize_samples(prior_samps, "A_TFD_KLD_PRIOR_SAMPLE.png", num_rows=20)
    # (A disabled Parzen-window check using cross_validate_sigma on prior
    # samples followed here in the source.)
    return
def pretrain_osm(lam_kld=0.0):
    """
    Pretrain a OneStageModel (inferencer/generator pair) on MNIST and
    periodically dump diagnostics and parameters under RESULT_PATH.

    Parameters:
        lam_kld: final weight for the KLd term; the weight is interpolated
                 linearly from 1.0 to lam_kld over the first 10000 updates,
                 while the learning rate is ramped from 0 over the same span
    Returns:
        None
    """
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    dataset = 'data/mnist.pkl.gz'
    datasets = load_udm(dataset, zero_mean=False)
    Xtr = datasets[0][0]
    Xtr = Xtr.get_value(borrow=False)
    Xva = datasets[2][0]
    Xva = Xva.get_value(borrow=False)
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape), str(Xva.shape)))

    # get and set some basic dataset information
    Xtr_mean = np.mean(Xtr, axis=0)
    tr_samples = Xtr.shape[0]
    batch_size = 100
    batch_reps = 5

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    # generator: PRIOR_DIM -> 1000 -> 1000 -> data_dim
    gn_shared = [PRIOR_DIM, 1000, 1000]
    gn_top = [gn_shared[-1], data_dim]
    gn_params = {
        'shared_config': gn_shared,
        'mu_config': gn_top,
        'sigma_config': gn_top,
        'activation': relu_actfun,
        'init_scale': 1.4,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
    }
    # continuous inferencer: data_dim -> 1000 -> 1000 -> PRIOR_DIM
    in_shared = [data_dim, 1000, 1000]
    in_top = [in_shared[-1], PRIOR_DIM]
    in_params = {
        'shared_config': in_shared,
        'mu_config': in_top,
        'sigma_config': in_top,
        'activation': relu_actfun,
        'init_scale': 1.4,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
    }
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {
        'x_type': 'bernoulli',
        'xt_transform': 'sigmoid',
        'logvar_bound': LOGVAR_BOUND,
    }
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    # squash the data mean into [0.05, 0.95] so its logit stays finite
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    # The log receives text, so it is opened in text mode, and the context
    # manager guarantees it is closed even if training is interrupted.
    with open(RESULT_PATH + "pt_osm_results.txt", 'w') as out_file:
        for i in range(150000):
            scale = min(1.0, float(i) / 10000.0)
            if ((i > 1) and ((i % 20000) == 0)):
                learn_rate = learn_rate * 0.9
            # draw a uniformly random minibatch; no control/mask inputs
            tr_idx = npr.randint(low=0, high=tr_samples, size=(batch_size,))
            Xd_batch = Xtr.take(tr_idx, axis=0)
            Xc_batch = 0.0 * Xd_batch
            Xm_batch = 0.0 * Xd_batch
            # do a minibatch update of the model, and compute some costs
            OSM.set_sgd_params(lr_1=(scale * learn_rate), mom_1=0.5, mom_2=0.98)
            OSM.set_lam_nll(1.0)
            OSM.set_lam_kld(lam_kld_1=(1.0 + (scale * (lam_kld - 1.0))),
                            lam_kld_2=0.0)
            result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 1000) == 0):
                # record and then reset the cost trackers
                costs = [(v / 1000.0) for v in costs]
                str_1 = "-- batch {0:d} --".format(i)
                str_2 = " joint_cost: {0:.4f}".format(costs[0])
                str_3 = " nll_cost : {0:.4f}".format(costs[1])
                str_4 = " kld_cost : {0:.4f}".format(costs[2])
                str_5 = " reg_cost : {0:.4f}".format(costs[3])
                costs = [0.0 for v in costs]
                # print out some diagnostic information
                joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
            if ((i % 2000) == 0):
                Xva = row_shuffle(Xva)
                model_samps = OSM.sample_from_prior(500)
                file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=20)
                # compute information about free-energy on validation set
                file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
                fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                fe_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(fe_str)
                out_file.write(fe_str + "\n")
                utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                                   x_label='Posterior KLd',
                                   y_label='Negative Log-likelihood')
                # compute information about posterior KLds on validation set
                file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
                post_klds = OSM.compute_post_klds(Xva[0:2500])
                post_dim_klds = np.mean(post_klds, axis=0)
                utils.plot_stem(np.arange(post_dim_klds.shape[0]),
                                post_dim_klds, file_name)
            if ((i % 5000) == 0):
                IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_IN.pkl".format(i))
                GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_GN.pkl".format(i))
    # final snapshot of the trained networks
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None,
             i_net=None, g_net=None, d_net=None, chain_len=None,
             data_dim=None, prior_dim=None, params=None):
    """
    Unroll the inferencer/generator pair into a chain of chain_len steps,
    attach the discriminator to the chain outputs, assemble all costs,
    gradients and parameter updates, and compile self.train_joint.

    Parameters:
        rng: random source; rng.randint seeds this object's RandStream and
             rng itself is handed on to every sub-network that is built
        Xd, Xc, Xm, Xt: symbolic inputs (chain seed, control values, mask,
                        target-distribution samples)
        i_net, g_net, d_net: inferencer, generator and discriminator nets
        chain_len: integer number of times to cycle the VAE loop
        data_dim, prior_dim: passed to OneStageModel as x_dim and z_dim
        params: dict of options; 'x_type' and 'lam_l2d' are read
                unconditionally, while 'cost_decay', 'chain_type',
                'xt_transform' and 'logvar_bound' fall back to defaults
    """
    self.rng = RandStream(rng.randint(100000))
    self.data_dim = data_dim
    self.prior_dim = prior_dim
    self.prior_mean = 0.0
    self.prior_logvar = 0.0

    # ---- options, with defaults for anything the caller left out ----
    self.params = {} if params is None else params
    self.cost_decay = self.params.get('cost_decay', 0.1)
    self.chain_type = self.params.get('chain_type', 'walkout')
    assert self.chain_type in ('walkback', 'walkout')
    xt_kind = self.params.get('xt_transform', 'sigmoid')
    assert xt_kind in ('sigmoid', 'none')
    if xt_kind == 'sigmoid':
        self.xt_transform = lambda x: T.nnet.sigmoid(x)
    else:
        self.xt_transform = lambda x: x
    self.logvar_bound = self.params.get('logvar_bound', 10)
    # x_type selects a bernoulli or gaussian model for the observations
    self.x_type = self.params['x_type']
    assert self.x_type in ('bernoulli', 'gaussian')

    # ---- symbolic inputs ----
    self.Xd = Xd    # samples that initialize the VAE chain
    self.Xm = Xm    # mask over subsets of the state variables
    self.Xc = Xc    # control values for subsets of the state variables
    self.Xt = Xt    # samples from the target distribution
    self.chain_len = chain_len
    # row indices of the target samples, then of the generated samples,
    # within the stacked discriminator input built further down
    self.It = T.arange(self.Xt.shape[0])
    self.Id = T.arange(self.chain_len * self.Xd.shape[0]) + self.Xt.shape[0]

    # ---- base VAE, kept around for easy access to its pieces ----
    self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm,
                             p_x_given_z=g_net, q_z_given_x=i_net,
                             x_dim=self.data_dim, z_dim=self.prior_dim,
                             params=self.params)
    self.IN = self.OSM.q_z_given_x
    self.GN = self.OSM.p_x_given_z
    self.transform_x_to_z = self.OSM.transform_x_to_z
    self.transform_z_to_x = self.OSM.transform_z_to_x
    self.bounded_logvar = self.OSM.bounded_logvar

    # ---- unroll the chain ----
    # Every step is a shared-parameter clone of IN/GN. All steps see the
    # same Xc and Xm; each step's input is the transformed generator mean
    # of the previous step (the first step starts from Xd).
    self.IN_chain = []
    self.GN_chain = []
    self.Xg_chain = []
    step_input = self.Xd
    print("Unrolling chain...")
    for i in range(self.chain_len):
        masked_input = apply_mask(Xd=step_input, Xc=self.Xc, Xm=self.Xm)
        step_in = self.IN.shared_param_clone(rng=rng, Xd=masked_input,
                                             build_funcs=False)
        step_gn = self.GN.shared_param_clone(rng=rng, Xd=step_in.output,
                                             build_funcs=False)
        step_input = self.xt_transform(step_gn.output_mean)
        self.IN_chain.append(step_in)
        self.GN_chain.append(step_gn)
        self.Xg_chain.append(step_input)
        print(" step {}...".format(i))

    # ---- discriminator over [target samples; all chain outputs] ----
    self.DN = d_net.shared_param_clone(
        rng=rng, Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

    # ---- shared scalars controlling costs and learning dynamics ----
    zero_ary = np.zeros((1, )).astype(theano.config.floatX)
    # weight on nll of data given posterior sample
    self.lam_chain_nll = theano.shared(value=zero_ary, name='vcg_lam_chain_nll')
    self.set_lam_chain_nll(lam_chain_nll=1.0)
    # weight on posterior KL-div from prior
    self.lam_chain_kld = theano.shared(value=zero_ary, name='vcg_lam_chain_kld')
    self.set_lam_chain_kld(lam_chain_kld=1.0)
    # l2 regularization on params
    self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
    self.set_lam_l2w(lam_l2w=1e-4)
    # one learning rate per network
    self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
    self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
    self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
    # momentum parameters and iteration count, common to all networks
    self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
    self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
    self.it_count = theano.shared(value=zero_ary, name='vcg_it_count')
    # weights for the adversarial classification objective
    self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
    self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
    self.set_all_sgd_params()
    self.set_disc_weights()
    # regularization on the output of the discriminator
    self.lam_l2d = theano.shared(value=(zero_ary + params['lam_l2d']),
                                 name='vcg_lam_l2d')

    # ---- parameters to optimize ----
    # The final layer of each discriminator proto-net is left out: those
    # layers are bypassed in favor of the binary classification layers
    # that _construct_disc_layers sets up below.
    self.dn_params = [p
                      for proto_net in self.DN.proto_nets
                      for layer in proto_net[0:-1]
                      for p in layer.params]
    self.in_params = list(self.IN.mlp_params)
    self.in_params.append(self.OSM.output_logvar)
    self.gn_params = list(self.GN.mlp_params)
    self.joint_params = self.in_params + self.gn_params + self.dn_params

    # ---- costs ----
    self._construct_disc_layers(rng)
    self.disc_reg_cost = self.lam_l2d[0] * \
        T.sum([dl.act_l2_sum for dl in self.disc_layers])
    # adversarial binary classification costs for DN and GN
    self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()
    # discriminator cost, optimized only with respect to dn_params
    self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
        self.disc_reg_cost
    # generator/inferencer costs
    self.chain_nll_cost = self.lam_chain_nll[0] * \
        self._construct_chain_nll_cost(cost_decay=self.cost_decay)
    self.chain_kld_cost = self.lam_chain_kld[0] * \
        self._construct_chain_kld_cost(cost_decay=self.cost_decay)
    self.other_reg_cost = self._construct_other_reg_cost()
    self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
        self.chain_kld_cost + self.other_reg_cost
    # total cost on the discriminator and VB generator/inferencer
    self.joint_cost = self.dn_cost + self.osm_cost

    # ---- gradients: dn_cost for DN params, osm_cost for IN/GN params ----
    self.joint_grads = OrderedDict()
    grad_jobs = (
        ("Computing VCGLoop DN cost gradients...", self.dn_cost, self.dn_params),
        ("Computing VCGLoop IN cost gradients...", self.osm_cost, self.in_params),
        ("Computing VCGLoop GN cost gradients...", self.osm_cost, self.gn_params),
    )
    for banner, cost, wrt in grad_jobs:
        print(banner)
        grad_list = T.grad(cost, wrt, disconnected_inputs='warn')
        for p, g in zip(wrt, grad_list):
            self.joint_grads[p] = g

    # ---- updates ----
    # All networks share the momentum terms and iteration count; each has
    # its own learning rate, which lets its learning be turned on/off.
    shared_opts = dict(grads=self.joint_grads,
                       beta1=self.mom_1, beta2=self.mom_2,
                       it_count=self.it_count, mom2_init=1e-3,
                       smoothing=1e-8, max_grad_norm=10.0)
    self.dn_updates = get_param_updates(params=self.dn_params,
                                        alpha=self.lr_dn, **shared_opts)
    self.gn_updates = get_param_updates(params=self.gn_params,
                                        alpha=self.lr_gn, **shared_opts)
    self.in_updates = get_param_updates(params=self.in_params,
                                        alpha=self.lr_in, **shared_opts)
    # bag up all the updates required for training (DN, then GN, then IN)
    self.joint_updates = OrderedDict()
    for net_updates in (self.dn_updates, self.gn_updates, self.in_updates):
        for k in net_updates:
            self.joint_updates[k] = net_updates[k]
    # running mean (0.98 / 0.02 mix) of the chain-averaged posterior KLd
    chain_kld = sum([T.mean(step_in.kld_cost) for step_in in self.IN_chain])
    new_kld_mean = (0.98 * self.IN.kld_mean) + \
        ((0.02 / self.chain_len) * chain_kld)
    self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

    # construct the function for training on training data
    print("Compiling VCGLoop theano functions....")
    self.train_joint = self._construct_train_joint()
    return
class VCGLoop(object):
    """
    Controller for training a self-looping VAE using guidance provided by a
    classifier.

    The classifier tries to discriminate between samples from the training
    data and samples generated by the looped VAE, while the VAE minimizes a
    variational generative model objective and also shifts mass away from
    regions where the classifier can discern that the generated data is
    denser than the training data.

    This class can also train "policies" for reconstructing partially masked
    inputs. A reconstruction policy can readily be trained to share the same
    parameters as a policy for generating transitions while self-looping.

    The generator must be an instance of the InfNet class implemented in
    "InfNet.py". The discriminator must be an instance of the PeaNet class,
    as implemented in "PeaNet.py". The inferencer must be an instance of the
    InfNet class implemented in "InfNet.py".

    Parameters:
        rng: numpy.random.RandomState (for reproducibility)
        Xd: symbolic var for providing points for starting the Markov Chain
        Xc: symbolic var for providing the "observed" (control) values that
            are mixed into the chain state via apply_mask
        Xm: symbolic var for providing masks to mix Xd with Xc
        Xt: symbolic var for providing samples from the target distribution
        i_net: The InfNet instance that will serve as the inferencer
        g_net: The InfNet instance that will serve as the generator
        d_net: The PeaNet instance that will serve as the discriminator
        chain_len: number of steps to unroll the VAE Markov Chain
        data_dim: dimension of the generated data
        prior_dim: dimension of the model prior
        params: a dict of parameters for controlling various costs
            x_type: can be "bernoulli" or "gaussian" (required)
            lam_l2d: regularization on squared discriminator output
                     (required)
            xt_transform: optional transform for gaussian means; either
                          'sigmoid' (default) or 'none'
            logvar_bound: optional bound on gaussian output logvar
                          (default 10)
            cost_decay: rate of decay for VAE costs in unrolled chain
                        (default 0.1)
            chain_type: can be 'walkout' (default) or 'walkback'
    """
    def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None,
            i_net=None, g_net=None, d_net=None, chain_len=None,
            data_dim=None, prior_dim=None, params=None):
        # private random stream, seeded from the caller's rng
        self.rng = RandStream(rng.randint(100000))
        self.data_dim = data_dim
        self.prior_dim = prior_dim
        self.prior_mean = 0.0
        self.prior_logvar = 0.0
        self.params = {} if params is None else params
        # optional settings, with their defaults
        self.cost_decay = self.params.get('cost_decay', 0.1)
        self.chain_type = self.params.get('chain_type', 'walkout')
        assert((self.chain_type == 'walkback') or
               (self.chain_type == 'walkout'))
        xt_transform = self.params.get('xt_transform', 'sigmoid')
        assert((xt_transform == 'sigmoid') or (xt_transform == 'none'))
        if xt_transform == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x
        self.logvar_bound = self.params.get('logvar_bound', 10)
        #
        # x_type: this tells if we're using bernoulli or gaussian model for
        #         the observations (required key -- KeyError if missing)
        #
        self.x_type = self.params['x_type']
        assert((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))

        # symbolic var for inputting samples for initializing the VAE chain
        self.Xd = Xd
        # symbolic var for masking subsets of the state variables
        self.Xm = Xm
        # symbolic var for controlling subsets of the state variables
        self.Xc = Xc
        # symbolic var for inputting samples from the target distribution
        self.Xt = Xt
        # integer number of times to cycle the VAE loop
        self.chain_len = chain_len

        # The discriminator input is Xt stacked on top of every chain step's
        # output (see self.DN below), so rows [0, |Xt|) are data and the
        # following chain_len * |Xd| rows are generated samples.
        # symbolic vector of row indices for data inputs
        self.It = T.arange(self.Xt.shape[0])
        # symbolic vector of row indices for noise/generated inputs
        self.Id = T.arange(self.chain_len * self.Xd.shape[0]) + \
                self.Xt.shape[0]

        # get a clone of the desired VAE, for easy access
        self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm,
                p_x_given_z=g_net, q_z_given_x=i_net, x_dim=self.data_dim,
                z_dim=self.prior_dim, params=self.params)
        self.IN = self.OSM.q_z_given_x
        self.GN = self.OSM.p_x_given_z
        self.transform_x_to_z = self.OSM.transform_x_to_z
        self.transform_z_to_x = self.OSM.transform_z_to_x
        self.bounded_logvar = self.OSM.bounded_logvar

        # self-loop some clones of the main VAE into a chain.
        # ** All VAEs in the chain share the same Xc and Xm, which are the
        #    symbolic inputs for providing the observed portion of the input
        #    and a mask indicating which part of the input is "observed".
        #    These inputs are used for training "reconstruction" policies.
        self.IN_chain = []
        self.GN_chain = []
        self.Xg_chain = []
        _Xd = self.Xd
        print("Unrolling chain...")
        for i in range(self.chain_len):
            # create a VAE infer/generate pair with _Xd as input and with
            # masking variables shared by all VAEs in this chain
            _IN = self.IN.shared_param_clone(rng=rng,
                    Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm),
                    build_funcs=False)
            _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output,
                    build_funcs=False)
            # the (transformed) generator mean feeds the next chain step
            _Xd = self.xt_transform(_GN.output_mean)
            self.IN_chain.append(_IN)
            self.GN_chain.append(_GN)
            self.Xg_chain.append(_Xd)
            print(" step {}...".format(i))

        # make a clone of the desired discriminator network, which will try
        # to discriminate between samples from the training data and samples
        # generated by the self-looped VAE chain.
        self.DN = d_net.shared_param_clone(rng=rng,
                Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

        zero_ary = np.zeros((1,)).astype(theano.config.floatX)
        # init shared var for weighting nll of data given posterior sample
        self.lam_chain_nll = theano.shared(value=zero_ary,
                name='vcg_lam_chain_nll')
        self.set_lam_chain_nll(lam_chain_nll=1.0)
        # init shared var for weighting posterior KL-div from prior
        self.lam_chain_kld = theano.shared(value=zero_ary,
                name='vcg_lam_chain_kld')
        self.set_lam_chain_kld(lam_chain_kld=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
        self.set_lam_l2w(lam_l2w=1e-4)
        # shared var learning rates for all networks
        self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
        self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
        self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
        # shared var momentum parameters for all networks
        self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
        self.it_count = theano.shared(value=zero_ary, name='vcg_it_count')
        # shared var weights for adversarial classification objective
        self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
        self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
        # init parameters for controlling learning dynamics
        self.set_all_sgd_params()
        self.set_disc_weights()  # init adversarial cost weights for GN/DN
        # set a shared var for regularizing the output of the discriminator
        # (required key; cast explicitly so the shared var is always floatX)
        self.lam_l2d = theano.shared(
                value=self._as_param_ary(self.params['lam_l2d']),
                name='vcg_lam_l2d')

        # Grab the full set of "optimizable" parameters from the generator
        # and discriminator networks that we'll be working with. We need to
        # ignore parameters in the final layers of the proto-networks in the
        # discriminator network (a generalized pseudo-ensemble). We ignore
        # them because this class requires that they be "bypassed" in favor
        # of some binary classification layers that it manages itself.
        self.dn_params = []
        for pn in self.DN.proto_nets:
            for pnl in pn[0:-1]:
                self.dn_params.extend(pnl.params)
        self.in_params = list(self.IN.mlp_params)
        self.in_params.append(self.OSM.output_logvar)
        self.gn_params = list(self.GN.mlp_params)
        # NOTE: joint_params is assembled before the binary discriminator
        # layers are built, so it does not include their parameters.
        self.joint_params = self.in_params + self.gn_params + self.dn_params

        # Now construct a binary discriminator layer for each proto-net in
        # the discriminator network. And, add their params to optimization
        # list (self.dn_params).
        self._construct_disc_layers(rng)
        self.disc_reg_cost = self.lam_l2d[0] * \
                T.sum([dl.act_l2_sum for dl in self.disc_layers])

        # Construct costs for the generator and discriminator networks based
        # on adversarial binary classification
        self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

        # first, build the cost to be optimized by the discriminator network,
        # in general this will be treated somewhat independently of the
        # optimization of the generator and inferencer networks.
        self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
                self.disc_reg_cost

        # construct costs relevant to the optimization of the generator and
        # inferencer networks
        self.chain_nll_cost = self.lam_chain_nll[0] * \
                self._construct_chain_nll_cost(cost_decay=self.cost_decay)
        self.chain_kld_cost = self.lam_chain_kld[0] * \
                self._construct_chain_kld_cost(cost_decay=self.cost_decay)
        self.other_reg_cost = self._construct_other_reg_cost()
        self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
                self.chain_kld_cost + self.other_reg_cost
        # compute total cost on the discriminator and VB generator/inferencer
        self.joint_cost = self.dn_cost + self.osm_cost

        # Get the gradients for all optimizable parameters. Discriminator
        # params are differentiated w.r.t. dn_cost only; inferencer and
        # generator params w.r.t. osm_cost only.
        self.joint_grads = OrderedDict()
        grad_jobs = (("DN", self.dn_cost, self.dn_params),
                     ("IN", self.osm_cost, self.in_params),
                     ("GN", self.osm_cost, self.gn_params))
        for net_name, net_cost, net_params in grad_jobs:
            print("Computing VCGLoop {} cost gradients...".format(net_name))
            grad_list = T.grad(net_cost, net_params,
                    disconnected_inputs='warn')
            for p, g in zip(net_params, grad_list):
                self.joint_grads[p] = g

        # construct the updates for the discriminator, generator and
        # inferencer networks. all networks share the same first/second
        # moment momentum and iteration count. the networks each have their
        # own learning rates, which lets you turn their learning on/off.
        def _updates_for(net_params, learn_rate):
            return get_param_updates(params=net_params,
                    grads=self.joint_grads, alpha=learn_rate,
                    beta1=self.mom_1, beta2=self.mom_2,
                    it_count=self.it_count, mom2_init=1e-3,
                    smoothing=1e-8, max_grad_norm=10.0)
        self.dn_updates = _updates_for(self.dn_params, self.lr_dn)
        self.gn_updates = _updates_for(self.gn_params, self.lr_gn)
        self.in_updates = _updates_for(self.in_params, self.lr_in)

        # bag up all the updates required for training
        self.joint_updates = OrderedDict()
        for net_updates in (self.dn_updates, self.gn_updates,
                            self.in_updates):
            for k in net_updates:
                self.joint_updates[k] = net_updates[k]
        # construct an update for tracking the mean KL divergence of
        # approximate posteriors for this chain (exponential moving average
        # with decay 0.98 over the per-step mean KLd)
        new_kld_mean = (0.98 * self.IN.kld_mean) + \
                ((0.02 / self.chain_len) *
                 sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
        self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

        # construct the function for training on training data
        print("Compiling VCGLoop theano functions....")
        self.train_joint = self._construct_train_joint()

    @staticmethod
    def _as_param_ary(value):
        """
        Wrap a scalar as a length-1 numpy array with dtype floatX.

        The cast is applied after the addition, so the result is floatX
        regardless of the dtype of `value`.
        """
        return (np.zeros((1,)) + value).astype(theano.config.floatX)

    def set_dn_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the discriminator network.
        """
        self.lr_dn.set_value(self._as_param_ary(learn_rate))

    def set_in_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the inferencer network.
        """
        self.lr_in.set_value(self._as_param_ary(learn_rate))

    def set_gn_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the generator network.
        """
        self.lr_gn.set_value(self._as_param_ary(learn_rate))

    def set_all_sgd_params(self, learn_rate=0.01, mom_1=0.9, mom_2=0.999):
        """
        Set learning rate and momentum parameter for all updates.
        """
        # set learning rates to the same value
        self.lr_dn.set_value(self._as_param_ary(learn_rate))
        self.lr_gn.set_value(self._as_param_ary(learn_rate))
        self.lr_in.set_value(self._as_param_ary(learn_rate))
        # set the first/second moment momentum parameters
        self.mom_1.set_value(self._as_param_ary(mom_1))
        self.mom_2.set_value(self._as_param_ary(mom_2))

    def set_disc_weights(self, dweight_gn=1.0, dweight_dn=1.0):
        """
        Set weights for the adversarial classification cost.
        """
        self.dw_dn.set_value(self._as_param_ary(dweight_dn))
        self.dw_gn.set_value(self._as_param_ary(dweight_gn))

    def set_lam_chain_nll(self, lam_chain_nll=1.0):
        """
        Set weight for controlling the influence of the data likelihood.
        """
        self.lam_chain_nll.set_value(self._as_param_ary(lam_chain_nll))

    def set_lam_chain_kld(self, lam_chain_kld=1.0):
        """
        Set the strength of regularization on KL-divergence for continuous
        posterior variables. When set to 1.0, this reproduces the standard
        role of KL(posterior || prior) in variational learning.
        """
        self.lam_chain_kld.set_value(self._as_param_ary(lam_chain_kld))

    def set_lam_l2w(self, lam_l2w=1e-3):
        """
        Set the relative strength of l2 regularization on network params.
        """
        self.lam_l2w.set_value(self._as_param_ary(lam_l2w))

    def _construct_disc_layers(self, rng):
        """
        Construct binary discrimination layers for each spawn-net in the
        underlying discriminator pseudo-ensemble. All spawn-nets spawned from
        the same proto-net will use the same disc-layer parameters.

        Populates self.disc_layers and self.disc_outputs, and appends the
        new layers' params to self.dn_params.
        """
        self.disc_layers = []
        self.disc_outputs = []
        dn_init_scale = self.DN.init_scale
        for sn in self.DN.spawn_nets:
            # construct a "binary discriminator" layer to sit on top of each
            # spawn net in the discriminator pseudo-ensemble; it reads the
            # input of the spawn net's final layer, bypassing that layer
            sn_fl = sn[-1]
            # NOTE(review): this code previously also computed
            # dn_init_scale / sqrt(sn_fl.in_dim) but never used it; the
            # unscaled value has always been what is passed as W_scale.
            # Confirm whether fan-in scaling was intended.
            disc_layer = DiscLayer(rng=rng, input=sn_fl.noisy_input,
                    in_dim=sn_fl.in_dim, W_scale=dn_init_scale)
            self.disc_layers.append(disc_layer)
            # capture the (linear) output of the DiscLayer, for possible reuse
            self.disc_outputs.append(disc_layer.linear_output)
            # get the params of this DiscLayer, for convenient optimization
            self.dn_params.extend(disc_layer.params)

    def _construct_disc_costs(self):
        """
        Construct the generator and discriminator adversarial costs.

        Returns a two-element list [dn_cost, gn_cost], each summed over all
        discriminator outputs and scaled by self.dw_dn / self.dw_gn.
        """
        gn_costs = []
        dn_costs = []
        data_size = T.cast(self.It.size, 'floatX')
        noise_size = T.cast(self.Id.size, 'floatX')
        for dl_output in self.disc_outputs:
            data_preds = dl_output.take(self.It, axis=0)
            noise_preds = dl_output.take(self.Id, axis=0)
            # compute the cost with respect to which we will be optimizing
            # the parameters of the discriminator network
            dnl_dn_cost = (logreg_loss(data_preds, 1.0) / data_size) + \
                          (logreg_loss(noise_preds, -1.0) / noise_size)
            # compute the cost with respect to which we will be optimizing
            # the parameters of the generative model
            dnl_gn_cost = (hinge_loss(noise_preds, 0.0) +
                           hinge_sq_loss(noise_preds, 0.0)) / \
                          (2.0 * noise_size)
            dn_costs.append(dnl_dn_cost)
            gn_costs.append(dnl_gn_cost)
        dn_cost = self.dw_dn[0] * T.sum(dn_costs)
        gn_cost = self.dw_gn[0] * T.sum(gn_costs)
        return [dn_cost, gn_cost]

    def _log_prob_wrapper(self, x_true, x_apprx):
        """
        Wrap log-prob with switching for bernoulli/gaussian output types.

        Returns the negated log-probability (i.e. an NLL cost).
        """
        if self.x_type == 'bernoulli':
            ll_cost = log_prob_bernoulli(x_true, x_apprx)
        else:
            ll_cost = log_prob_gaussian2(x_true, x_apprx,
                    log_vars=self.bounded_logvar)
        return -ll_cost

    def _construct_chain_nll_cost(self, cost_decay=0.1):
        """
        Construct the negative log-likelihood part of cost to minimize.

        This is for operation in "free chain" mode, where a seed point is
        used to initialize a long(ish) running markov chain. Step i of the
        chain is weighted by cost_decay**i, and the result is normalized by
        the sum of the step weights.
        """
        assert((cost_decay >= 0.0) and (cost_decay <= 1.0))
        nll_costs = []
        step_weights = []
        step_weight = 1.0
        for i in range(self.chain_len):
            if self.chain_type == 'walkback':
                # train with walkback roll-outs -- reconstruct initial point
                IN_i = self.IN_chain[0]
            else:
                # train with walkout roll-outs -- reconstruct previous point
                IN_i = self.IN_chain[i]
            x_true = IN_i.Xd
            x_apprx = self.Xg_chain[i]
            c = T.mean(self._log_prob_wrapper(x_true, x_apprx))
            nll_costs.append(step_weight * c)
            step_weights.append(step_weight)
            step_weight = step_weight * cost_decay
        nll_cost = sum(nll_costs) / sum(step_weights)
        return nll_cost

    def _construct_chain_kld_cost(self, cost_decay=0.1):
        """
        Construct the posterior KL-d from prior part of cost to minimize.

        This is for operation in "free chain" mode, where a seed point is
        used to initialize a long(ish) running markov chain. Step i of the
        chain is weighted by cost_decay**i, and the result is normalized by
        the sum of the step weights.
        """
        assert((cost_decay >= 0.0) and (cost_decay <= 1.0))
        kld_costs = []
        step_weights = []
        step_weight = 1.0
        for i in range(self.chain_len):
            IN_i = self.IN_chain[i]
            # basic variational term on KL divergence between post and prior
            kld_i = gaussian_kld(IN_i.output_mean, IN_i.output_logvar,
                    self.prior_mean, self.prior_logvar)
            kld_i_costs = T.sum(kld_i, axis=1)
            # sum and reweight the KLd cost for this step in the chain
            c = T.mean(kld_i_costs)
            kld_costs.append(step_weight * c)
            step_weights.append(step_weight)
            step_weight = step_weight * cost_decay
        kld_cost = sum(kld_costs) / sum(step_weights)
        return kld_cost

    def _construct_other_reg_cost(self):
        """
        Construct the cost for low-level basic regularization. E.g. for
        applying l2 regularization to the network parameters.

        Only generator and inferencer params are penalized here.
        """
        gp_cost = sum([T.sum(par**2.0) for par in self.gn_params])
        ip_cost = sum([T.sum(par**2.0) for par in self.in_params])
        other_reg_cost = self.lam_l2w[0] * (gp_cost + ip_cost)
        return other_reg_cost

    def _construct_train_joint(self):
        """
        Construct theano function to train generator and discriminator
        jointly.

        The compiled function takes (xd, xc, xm, xt, batch_reps); xd, xc and
        xm are each repeated batch_reps times along axis 0 before being fed
        to the model, while xt is used as-is.
        """
        # symbolic vars for passing input to training function
        xd = T.matrix()
        xc = T.matrix()
        xm = T.matrix()
        xt = T.matrix()
        batch_reps = T.lscalar()
        # collect outputs to return to caller
        # NOTE(review): chain_nll_cost / chain_kld_cost appear twice. The
        # positions are part of the returned layout, so they are left as-is;
        # confirm against callers before de-duplicating.
        outputs = [self.joint_cost, self.chain_nll_cost, self.chain_kld_cost,
                self.chain_nll_cost, self.chain_kld_cost, self.disc_cost_gn,
                self.disc_cost_dn, self.other_reg_cost]
        func = theano.function(inputs=[xd, xc, xm, xt, batch_reps],
                outputs=outputs, updates=self.joint_updates,
                givens={self.Xd: xd.repeat(batch_reps, axis=0),
                        self.Xc: xc.repeat(batch_reps, axis=0),
                        self.Xm: xm.repeat(batch_reps, axis=0),
                        self.Xt: xt})
        return func

    def sample_from_chain(self, X_d, X_c=None, X_m=None, loop_iters=5,
            sigma_scale=None):
        """
        Sample for several rounds through the I<->G loop, initialized with
        the "data variable" samples in X_d.

        Delegates directly to the underlying OneStageModel.
        """
        result = self.OSM.sample_from_chain(X_d, X_c=X_c, X_m=X_m,
                loop_iters=loop_iters, sigma_scale=sigma_scale)
        return result

    def sample_from_prior(self, samp_count):
        """
        Draw independent samples from the model's prior, using the gaussian
        continuous prior of the underlying GenNet.

        Delegates directly to the underlying OneStageModel.
        """
        Xs = self.OSM.sample_from_prior(samp_count)
        return Xs
def test_one_stage_model():
    """
    Smoke-test OneStageModel by training a conv VAE on binarized MNIST.

    Builds a generator (p_x_given_z) and an inferencer (q_z_given_x) as
    HydraNets, wraps them in a OneStageModel, and runs minibatch updates.
    Averaged costs and a validation bound are printed and appended to
    "OSM_TEST_RESULTS.txt"; prior samples are periodically written to
    "OSM_SAMPLES_b<batch>.png".
    """
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    # train on train+validation, and validate on the test split
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    tr_samples = Xtr.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the OneStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 64
    x_type = 'bernoulli'
    xin_sym = T.matrix('xin_sym')

    ###############
    # p_x_given_z #
    ###############
    params = {}
    shared_config = [
        {'layer_type': 'fc',
         'in_chans': z_dim,
         'out_chans': 256,
         'activation': relu_actfun,
         'apply_bn': True},
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': 7*7*128,
         'activation': relu_actfun,
         'apply_bn': True,
         'shape_func_out': lambda x: T.reshape(x, (-1, 128, 7, 7))},
        {'layer_type': 'conv',
         'in_chans': 128,   # in shape:  (batch, 128, 7, 7)
         'out_chans': 64,   # out shape: (batch, 64, 14, 14)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': True},
    ]
    # two output heads with identical configuration
    output_config = [
        {'layer_type': 'conv',
         'in_chans': 64,    # in shape:  (batch, 64, 14, 14)
         'out_chans': 1,    # out shape: (batch, 1, 28, 28)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': False,
         'shape_func_out': lambda x: T.flatten(x, 2)},
        {'layer_type': 'conv',
         'in_chans': 64,
         'out_chans': 1,
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': False,
         'shape_func_out': lambda x: T.flatten(x, 2)},
    ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_z = HydraNet(rng=rng, Xd=xin_sym,
            params=params, shared_param_dicts=None)
    p_x_given_z.init_biases(0.0)

    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = [
        {'layer_type': 'conv',
         'in_chans': 1,     # in shape:  (batch, 784)
         'out_chans': 64,   # out shape: (batch, 64, 14, 14)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'double',
         'apply_bn': True,
         'shape_func_in': lambda x: T.reshape(x, (-1, 1, 28, 28))},
        {'layer_type': 'conv',
         'in_chans': 64,    # in shape:  (batch, 64, 14, 14)
         'out_chans': 128,  # out shape: (batch, 128, 7, 7)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'double',
         'apply_bn': True,
         'shape_func_out': lambda x: T.flatten(x, 2)},
        {'layer_type': 'fc',
         'in_chans': 128*7*7,
         'out_chans': 256,
         'activation': relu_actfun,
         'apply_bn': True},
    ]
    # two output heads with identical configuration
    output_config = [
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': z_dim,
         'activation': relu_actfun,
         'apply_bn': False},
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': z_dim,
         'activation': relu_actfun,
         'apply_bn': False},
    ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_x = HydraNet(rng=rng, Xd=xin_sym,
            params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.0)

    ##############################################################
    # Define parameters for the OneStageModel, and initialize it #
    ##############################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = x_type
    osm_params['obs_transform'] = 'sigmoid'
    # NOTE(review): other call sites in this file construct OneStageModel
    # with Xd=/Xc=/Xm= rather than x_in= -- confirm which signature the
    # imported class actually accepts.
    OSM = OneStageModel(rng=rng, x_in=xin_sym,
            x_dim=x_dim, z_dim=z_dim,
            p_x_given_z=p_x_given_z,
            q_z_given_x=q_z_given_x,
            params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("OSM_TEST")
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    momentum = 0.9
    # start past the end of the data so the first iteration triggers a
    # shuffle and resets the batch window to the start
    batch_idx = np.arange(batch_size) + tr_samples
    # text mode: everything written below is str (binary mode raises
    # TypeError on Python 3); the context manager guarantees the log is
    # closed even if training raises
    with open(log_name, 'w') as out_file:
        for i in range(500000):
            # warm-up factor applied to both learning rate and momentum
            scale = min(0.5, ((i+1) / 5000.0))
            if (((i + 1) % 10000) == 0):
                learn_rate = learn_rate * 0.95
            # get the indices of training samples for this batch update
            batch_idx += batch_size
            if (np.max(batch_idx) >= tr_samples):
                # we finished an "epoch", so we rejumble the training set
                Xtr = row_shuffle(Xtr)
                batch_idx = np.arange(batch_size)
            Xb = to_fX(Xtr.take(batch_idx, axis=0))
            # set sgd and objective function hyperparams for this update
            OSM.set_sgd_params(lr=scale*learn_rate,
                    mom_1=(scale*momentum), mom_2=0.98)
            OSM.set_lam_nll(lam_nll=1.0)
            OSM.set_lam_kld(lam_kld=1.0)
            OSM.set_lam_l2w(1e-5)
            # perform a minibatch update and record the cost for this batch
            result = OSM.train_joint(Xb, batch_reps)
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 500) == 0):
                # NOTE: the very first report (i == 0) covers one batch but
                # is still divided by 500, so it under-reports the costs.
                costs = [(v / 500.0) for v in costs]
                str1 = "-- batch {0:d} --".format(i)
                str2 = " joint_cost: {0:.4f}".format(costs[0])
                str3 = " nll_cost : {0:.4f}".format(costs[1])
                str4 = " kld_cost : {0:.4f}".format(costs[2])
                str5 = " reg_cost : {0:.4f}".format(costs[3])
                joint_str = "\n".join([str1, str2, str3, str4, str5])
                print(joint_str)
                out_file.write(joint_str+"\n")
                out_file.flush()
                costs = [0.0 for v in costs]
            if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
                # draw some independent random samples from the model
                samp_count = 300
                model_samps = OSM.sample_from_prior(samp_count)
                file_name = "OSM_SAMPLES_b{0:d}.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=15)
                # compute free energy estimate for validation samples
                Xva = row_shuffle(Xva)
                fe_terms = OSM.compute_fe_terms(Xva[0:5000], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                out_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(out_str)
                out_file.write(out_str+"\n")
                out_file.flush()
    return
def test_gip_sigma_scale_mnist():
    """
    Load a pre-trained inferencer/generator pair for MNIST and visualize it.

    Writes to the working directory: a stem plot of per-dimension posterior
    KLd, a scatter plot of NLL vs. posterior KLd on validation data, five
    videos of sampled Markov chains, and a grid of prior samples.
    """
    # only referenced by the disabled Parzen-estimator code at the bottom of
    # this function; kept so that code still works when re-enabled
    from LogPDFs import cross_validate_sigma
    # Simple test code, to check that everything is basically functional.
    print("TESTING...")

    # Initialize a source of randomness
    rng = np.random.RandomState(12345)

    # Load some data to train/validate/test with
    dataset = 'data/mnist.pkl.gz'
    datasets = load_udm(dataset, zero_mean=False)
    Xtr = datasets[0][0]
    Xtr = Xtr.get_value(borrow=False)
    Xva = datasets[2][0]
    Xva = Xva.get_value(borrow=False)
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape),str(Xva.shape)))

    # get and set some basic dataset information
    tr_samples = Xtr.shape[0]

    # Symbolic inputs
    Xd = T.matrix(name='Xd')
    Xc = T.matrix(name='Xc')
    Xm = T.matrix(name='Xm')

    # Load inferencer and generator from saved parameters
    gn_fname = "MNIST_WALKOUT_TEST_MAX_KLD/pt_walk_params_b70000_GN.pkl"
    in_fname = "MNIST_WALKOUT_TEST_MAX_KLD/pt_walk_params_b70000_IN.pkl"
    IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd)
    GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd)
    # data/latent dimensions are read off the loaded inferencer
    x_dim = IN.shared_layers[0].in_dim
    z_dim = IN.mu_layers[-1].out_dim
    # construct a OneStageModel with the loaded inferencer and generator
    osm_params = {}
    osm_params['x_type'] = 'gaussian'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
            p_x_given_z=GN, q_z_given_x=IN,
            x_dim=x_dim, z_dim=z_dim, params=osm_params)

    # compute variational likelihood bound and its sub-components
    Xva = row_shuffle(Xva)
    Xb = Xva[0:5000]
    file_name = "A_MNIST_POST_KLDS.png"
    post_klds = OSM.compute_post_klds(Xb)
    post_dim_klds = np.mean(post_klds, axis=0)
    utils.plot_stem(np.arange(post_dim_klds.shape[0]), post_dim_klds,
            file_name)
    # compute information about free-energy on validation set
    file_name = "A_MNIST_FREE_ENERGY.png"
    fe_terms = OSM.compute_fe_terms(Xb, 20)
    utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
            x_label='Posterior KLd', y_label='Negative Log-likelihood')

    # bound_results = OSM.compute_ll_bound(Xva)
    # ll_bounds = bound_results[0]
    # post_klds = bound_results[1]
    # log_likelihoods = bound_results[2]
    # max_lls = bound_results[3]
    # print("mean ll bound: {0:.4f}".format(np.mean(ll_bounds)))
    # print("mean posterior KLd: {0:.4f}".format(np.mean(post_klds)))
    # print("mean log-likelihood: {0:.4f}".format(np.mean(log_likelihoods)))
    # print("mean max log-likelihood: {0:.4f}".format(np.mean(max_lls)))
    # print("min ll bound: {0:.4f}".format(np.min(ll_bounds)))
    # print("max posterior KLd: {0:.4f}".format(np.max(post_klds)))
    # print("min log-likelihood: {0:.4f}".format(np.min(log_likelihoods)))
    # print("min max log-likelihood: {0:.4f}".format(np.min(max_lls)))
    # # compute some information about the approximate posteriors
    # post_stats = OSM.compute_post_stats(Xva, 0.0*Xva, 0.0*Xva)
    # all_post_klds = np.sort(post_stats[0].ravel()) # post KLds for each obs and dim
    # obs_post_klds = np.sort(post_stats[1]) # summed post KLds for each obs
    # post_dim_klds = post_stats[2] # average post KLds for each post dim
    # post_dim_vars = post_stats[3] # average squared mean for each post dim
    # utils.plot_line(np.arange(all_post_klds.shape[0]), all_post_klds, "AAA_ALL_POST_KLDS.png")
    # utils.plot_line(np.arange(obs_post_klds.shape[0]), obs_post_klds, "AAA_OBS_POST_KLDS.png")
    # utils.plot_stem(np.arange(post_dim_klds.shape[0]), post_dim_klds, "AAA_POST_DIM_KLDS.png")
    # utils.plot_stem(np.arange(post_dim_vars.shape[0]), post_dim_vars, "AAA_POST_DIM_VARS.png")

    # draw many samples from the model: five videos, each a 3x3 grid of
    # independently sampled chains seeded from the same ten training points
    for i in range(5):
        # NOTE: indices come from the global numpy RNG (npr), not from the
        # seeded `rng` above, so the chosen seed points vary between runs
        tr_idx = npr.randint(low=0,high=tr_samples,size=(100,))
        Xd_batch = Xtr.take(tr_idx, axis=0)
        Xs = []
        for row in range(3):
            Xs.append([])
            for _ in range(3):
                sample_lists = OSM.sample_from_chain(Xd_batch[0:10,:],
                        loop_iters=100, sigma_scale=1.0)
                Xs[row].append(group_chains(sample_lists['data samples']))
        Xs, block_im_dim = block_video(Xs, (28,28), (3,3))
        to_video(Xs, block_im_dim, "A_MNIST_KLD_CHAIN_VIDEO_{0:d}.avi".format(i), frame_rate=10)
        # (disabled; refers to a `GIP` object that is not defined here)
        #sample_lists = GIP.sample_from_chain(Xd_batch[0,:].reshape((1,data_dim)), loop_iters=300, \
        #        sigma_scale=1.0)
        #Xs = np.vstack(sample_lists["data samples"])
        #file_name = "TFD_TEST_{0:d}.png".format(i)
        #utils.visualize_samples(Xs, file_name, num_rows=15)

    # draw a 20x20 grid of independent samples from the prior
    file_name = "A_MNIST_KLD_PRIOR_SAMPLE.png"
    Xs = OSM.sample_from_prior(20*20)
    utils.visualize_samples(Xs, file_name, num_rows=20)

    # # test Parzen density estimator built from prior samples
    # Xs = OSM.sample_from_prior(10000)
    # [best_sigma, best_ll, best_lls] = \
    #         cross_validate_sigma(Xs, Xva, [0.12, 0.14, 0.15, 0.16, 0.18], 20)
    # sort_idx = np.argsort(best_lls)
    # sort_idx = sort_idx[0:400]
    # utils.plot_line(np.arange(sort_idx.shape[0]), best_lls[sort_idx], "A_MNIST_BEST_LLS_1.png")
    # utils.visualize_samples(Xva[sort_idx], "A_MNIST_BAD_DIGITS_1.png", num_rows=20)
    # ##########
    # # AGAIN! #
    # ##########
    # Xs = OSM.sample_from_prior(10000)
    # tr_idx = npr.randint(low=0,high=tr_samples,size=(5000,))
    # Xva = Xtr.take(tr_idx, axis=0)
    # [best_sigma, best_ll, best_lls] = \
    #         cross_validate_sigma(Xs, Xva, [0.12, 0.14, 0.15, 0.16, 0.18], 20)
    # sort_idx = np.argsort(best_lls)
    # sort_idx = sort_idx[0:400]
    # utils.plot_line(np.arange(sort_idx.shape[0]), best_lls[sort_idx], "A_MNIST_BEST_LLS_2.png")
    # utils.visualize_samples(Xva[sort_idx], "A_MNIST_BAD_DIGITS_2.png", num_rows=20)
    return
def test_gip_sigma_scale_tfd():
    """
    Smoke-test a pre-trained inferencer/generator pair on the TFD faces data.

    Loads the networks stored under "TFD_WALKOUT_TEST_KLD/", wraps them in a
    OneStageModel, and then writes three kinds of artifacts to the current
    working directory:
      * a scatter plot of the two free-energy terms on 5000 shuffled rows
        of the TFD 'test' split,
      * five videos, each showing a 3x3 grid of sampled Markov chains that
        were seeded with training points,
      * a 20x20 grid of samples drawn from the model prior.

    Takes no arguments and returns None.
    """
    print("TESTING...")

    # Initialize a source of randomness
    rng = np.random.RandomState(12345)

    # Load the data: unlabeled + labeled-train rows are stacked to form the
    # training set, and the 'test' split is used for validation.
    data_file = 'data/tfd_data_48x48.pkl'
    Xtr_unlabeled = load_tfd(tfd_pkl_name=data_file, which_set='unlabeled',
                             fold='all')[0]
    Xtr_train = load_tfd(tfd_pkl_name=data_file, which_set='train',
                         fold='all')[0]
    Xtr = np.vstack([Xtr_unlabeled, Xtr_train])
    Xva = load_tfd(tfd_pkl_name=data_file, which_set='test', fold='all')[0]
    tr_samples = Xtr.shape[0]
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape), str(Xva.shape)))

    # Symbolic inputs
    Xd = T.matrix(name='Xd')
    Xc = T.matrix(name='Xc')
    Xm = T.matrix(name='Xm')

    # Load inferencer and generator from saved parameters. The inferencer is
    # loaded first so that `rng` is consumed in the same order as before.
    gn_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_GN.pkl"
    in_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_IN.pkl"
    IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd)
    GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd)
    x_dim = IN.shared_layers[0].in_dim
    z_dim = IN.mu_layers[-1].out_dim

    # Wrap the loaded networks in a OneStageModel
    osm_params = {
        'x_type': 'gaussian',
        'xt_transform': 'sigmoid',
        'logvar_bound': LOGVAR_BOUND,
    }
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=x_dim, z_dim=z_dim, params=osm_params)

    # compute information about free-energy on (a shuffled subset of) the
    # validation set
    Xva = row_shuffle(Xva)
    Xb = Xva[0:5000]
    file_name = "A_TFD_KLD_FREE_ENERGY.png"
    fe_terms = OSM.compute_fe_terms(Xb, 20)
    utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                       x_label='Posterior KLd',
                       y_label='Negative Log-likelihood')

    # draw many chains from the model; each video tiles 3x3 independent runs
    # that all start from the same 10 seed points
    for i in range(5):
        tr_idx = npr.randint(low=0, high=tr_samples, size=(100, ))
        Xd_batch = Xtr.take(tr_idx, axis=0)
        Xs = []
        for _row in range(3):
            row_chains = []
            for _col in range(3):
                sample_lists = OSM.sample_from_chain(
                    Xd_batch[0:10, :], loop_iters=100, sigma_scale=1.0)
                row_chains.append(group_chains(sample_lists['data samples']))
            Xs.append(row_chains)
        Xs, block_im_dim = block_video(Xs, (48, 48), (3, 3))
        to_video(Xs, block_im_dim,
                 "A_TFD_KLD_CHAIN_VIDEO_{0:d}.avi".format(i), frame_rate=10)

    # draw independent samples from the prior
    file_name = "A_TFD_KLD_PRIOR_SAMPLE.png"
    Xs = OSM.sample_from_prior(20 * 20)
    utils.visualize_samples(Xs, file_name, num_rows=20)
    return
def test_svhn(occ_dim=15, drop_prob=0.0):
    """
    Train a OneStageModel (VAE) on grayscale SVHN and track imputation quality.

    Three networks are built (p_zi_given_xi, p_xip1_given_zi and
    q_zi_given_x_xi). All three are handed to a GPSImputer; the first two are
    also handed to the OneStageModel, and only the OneStageModel is trained
    here. The GPSImputer is used for periodic evaluation and visualization.

    Parameters:
        occ_dim: forwarded to construct_masked_data as `occ_dim`
        drop_prob: forwarded to construct_masked_data as `drop_prob`

    Training costs and validation numbers are printed and appended to
    "<result_tag>_RESULTS.txt"; sample/weight images are written next to it.
    Returns None.

    NOTE(review): `rng` and `batch_reps` are used below but are not defined
    inside this function -- presumably they are module-level names; confirm,
    otherwise this function raises NameError.
    """
    RESULT_PATH = "IMP_SVHN_VAE/"

    #########################################
    # Format the result tag more thoroughly #
    #########################################
    dp_int = int(100.0 * drop_prob)
    result_tag = "{}VAE_OD{}_DP{}".format(RESULT_PATH, occ_dim, dp_int)

    ##########################
    # Get some training data #
    ##########################
    tr_file = 'data/svhn_train_gray.pkl'
    te_file = 'data/svhn_test_gray.pkl'
    ex_file = 'data/svhn_extra_gray.pkl'
    data = load_svhn_gray(tr_file, te_file, ex_file=ex_file, ex_count=200000)
    Xtr = to_fX(shift_and_scale_into_01(np.vstack([data['Xtr'], data['Xex']])))
    Xva = to_fX(shift_and_scale_into_01(data['Xte']))
    tr_samples = Xtr.shape[0]
    batch_size = 250
    # a constant vector holding the mean over all training pixels
    all_pix_mean = np.mean(np.mean(Xtr, axis=1))
    data_mean = to_fX(all_pix_mean * np.ones((Xtr.shape[1],)))

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    obs_dim = Xtr.shape[1]
    z_dim = 100
    imp_steps = 15  # we'll check for the best step count (found oracularly)
    init_scale = 1.0
    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')
    x_mask_sym = T.matrix('x_mask_sym')

    # settings shared by all three networks; each network gets its own copy
    common_params = {
        'activation': relu_actfun,
        'init_scale': init_scale,
        'lam_l2a': 0.0,
        'vis_drop': 0.0,
        'hid_drop': 0.0,
        'bias_noise': 0.0,
        'input_noise': 0.0,
        'build_theano_funcs': False,
    }

    #################
    # p_zi_given_xi #
    #################
    shared_config = [obs_dim, 1000, 1000]
    top_config = [shared_config[-1], z_dim]
    params = dict(common_params)
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    p_zi_given_xi = InfNet(rng=rng, Xd=x_in_sym,
                           params=params, shared_param_dicts=None)
    p_zi_given_xi.init_biases(0.2)

    ###################
    # p_xip1_given_zi #
    ###################
    shared_config = [z_dim, 1000, 1000]
    output_config = [obs_dim, obs_dim]
    params = dict(common_params)
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    p_xip1_given_zi = HydraNet(rng=rng, Xd=x_in_sym,
                               params=params, shared_param_dicts=None)
    p_xip1_given_zi.init_biases(0.2)

    ###################
    # q_zi_given_x_xi #
    ###################
    shared_config = [(obs_dim + obs_dim), 1000, 1000]
    top_config = [shared_config[-1], z_dim]
    params = dict(common_params)
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    q_zi_given_x_xi = InfNet(rng=rng, Xd=x_in_sym,
                             params=params, shared_param_dicts=None)
    q_zi_given_x_xi.init_biases(0.2)

    ###########################################################
    # Define parameters for the GPSImputer, and initialize it #
    ###########################################################
    print("Building the GPSImputer...")
    gpsi_params = {
        'obs_dim': obs_dim,
        'z_dim': z_dim,
        'imp_steps': imp_steps,
        'step_type': 'jump',
        'x_type': 'bernoulli',
        'obs_transform': 'sigmoid',
        'use_osm_mode': True,
    }
    GPSI = GPSImputer(rng=rng,
                      x_in=x_in_sym, x_out=x_out_sym, x_mask=x_mask_sym,
                      p_zi_given_xi=p_zi_given_xi,
                      p_xip1_given_zi=p_xip1_given_zi,
                      q_zi_given_x_xi=q_zi_given_x_xi,
                      params=gpsi_params,
                      shared_param_dicts=None)

    #########################################################################
    # Define parameters for the underlying OneStageModel, and initialize it #
    #########################################################################
    print("Building the OneStageModel...")
    osm_params = {
        'x_type': 'bernoulli',
        'xt_transform': 'sigmoid',
    }
    OSM = OneStageModel(rng=rng,
                        x_in=x_in_sym,
                        p_x_given_z=p_xip1_given_zi,
                        q_z_given_x=p_zi_given_xi,
                        x_dim=obs_dim, z_dim=z_dim,
                        params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format(result_tag)
    costs = [0. for i in range(10)]
    learn_rate = 0.0002
    momentum = 0.5
    # start past the end of the data so the first iteration triggers a shuffle
    batch_idx = np.arange(batch_size) + tr_samples
    # The log receives text, so it is opened in text mode, and the context
    # manager guarantees the handle is closed even if training raises.
    with open(log_name, 'w') as out_file:
        for i in range(200005):
            # warm-up factor: ramps linearly to 1.0 over the first 5000 updates
            scale = min(1.0, ((i + 1) / 5000.0))
            if (((i + 1) % 15000) == 0):
                learn_rate = learn_rate * 0.92
            if (i > 10000):
                momentum = 0.90
            else:
                momentum = 0.50
            # get the indices of training samples for this batch update
            batch_idx += batch_size
            if (np.max(batch_idx) >= tr_samples):
                # we finished an "epoch", so we rejumble the training set
                Xtr = row_shuffle(Xtr)
                batch_idx = np.arange(batch_size)
            # set sgd and objective function hyperparams for this update
            OSM.set_sgd_params(lr=scale * learn_rate,
                               mom_1=scale * momentum, mom_2=0.99)
            OSM.set_lam_nll(lam_nll=1.0)
            OSM.set_lam_kld(lam_kld_1=1.0, lam_kld_2=0.0)
            OSM.set_lam_l2w(1e-4)
            # perform a minibatch update and record the cost for this batch
            xb = to_fX(Xtr.take(batch_idx, axis=0))
            result = OSM.train_joint(xb, batch_reps)
            # the last entry of `result` is not accumulated
            costs = [(costs[j] + result[j]) for j in range(len(result) - 1)]
            if ((i % 250) == 0):
                costs = [(v / 250.0) for v in costs]
                str1 = "-- batch {0:d} --".format(i)
                str2 = " joint_cost: {0:.4f}".format(costs[0])
                str3 = " nll_cost : {0:.4f}".format(costs[1])
                str4 = " kld_cost : {0:.4f}".format(costs[2])
                str5 = " reg_cost : {0:.4f}".format(costs[3])
                joint_str = "\n".join([str1, str2, str3, str4, str5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
                costs = [0.0 for v in costs]
            if ((i % 1000) == 0):
                Xva = row_shuffle(Xva)
                # record an estimate of performance on the test set
                xi, xo, xm = construct_masked_data(
                    Xva[0:5000], drop_prob=drop_prob,
                    occ_dim=occ_dim, data_mean=data_mean)
                step_nll, step_kld = GPSI.compute_per_step_cost(
                    xi, xo, xm, sample_count=10)
                min_nll = np.min(step_nll)
                str1 = " va_nll_bound : {}".format(min_nll)
                str2 = " va_nll_min : {}".format(min_nll)
                str3 = " va_nll_final : {}".format(step_nll[-1])
                joint_str = "\n".join([str1, str2, str3])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
            if ((i % 10000) == 0):
                # Get some validation samples for evaluating model performance
                xb = to_fX(Xva[0:100])
                xi, xo, xm = construct_masked_data(
                    xb, drop_prob=drop_prob,
                    occ_dim=occ_dim, data_mean=data_mean)
                xi = np.repeat(xi, 2, axis=0)
                xo = np.repeat(xo, 2, axis=0)
                xm = np.repeat(xm, 2, axis=0)
                # draw some sample imputations from the model
                samp_count = xi.shape[0]
                _, model_samps = GPSI.sample_imputer(
                    xi, xo, xm, use_guide_policy=False)
                seq_len = len(model_samps)
                seq_samps = np.zeros((seq_len * samp_count,
                                      model_samps[0].shape[1]))
                # lay the samples out so each input's full step sequence
                # occupies consecutive rows
                idx = 0
                for s1 in range(samp_count):
                    for s2 in range(seq_len):
                        seq_samps[idx] = model_samps[s2][s1]
                        idx += 1
                file_name = "{}_samples_ng_b{}.png".format(result_tag, i)
                utils.visualize_samples(seq_samps, file_name, num_rows=20)
                # get visualizations of policy parameters
                file_name = "{}_gen_gen_weights_b{}.png".format(result_tag, i)
                W = GPSI.gen_gen_weights.get_value(borrow=False)
                utils.visualize_samples(W[:, :obs_dim], file_name, num_rows=20)
                file_name = "{}_gen_inf_weights_b{}.png".format(result_tag, i)
                W = GPSI.gen_inf_weights.get_value(borrow=False).T
                utils.visualize_samples(W[:, :obs_dim], file_name, num_rows=20)
def test_one_stage_model():
    """
    Train a convolutional OneStageModel (VAE) on binarized MNIST.

    The original train and validation splits are stacked for training and
    the test split is used for validation. Running costs are printed and
    appended to "OSM_TEST_RESULTS.txt"; prior-sample images are written to
    "OSM_SAMPLES_b<iter>.png" together with a free-energy estimate on 5000
    shuffled validation rows.

    Takes no arguments and returns None.
    """
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    tr_samples = Xtr.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the OneStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 64
    x_type = 'bernoulli'
    xin_sym = T.matrix('xin_sym')

    ###############
    # p_x_given_z #
    ###############
    params = {}
    shared_config = [
        {'layer_type': 'fc',
         'in_chans': z_dim,
         'out_chans': 256,
         'activation': relu_actfun,
         'apply_bn': True},
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': 7*7*128,
         'activation': relu_actfun,
         'apply_bn': True,
         'shape_func_out': lambda x: T.reshape(x, (-1, 128, 7, 7))},
        {'layer_type': 'conv',
         'in_chans': 128,   # in shape:  (batch, 128, 7, 7)
         'out_chans': 64,   # out shape: (batch, 64, 14, 14)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': True},
    ]
    output_config = [
        {'layer_type': 'conv',
         'in_chans': 64,    # in shape:  (batch, 64, 14, 14)
         'out_chans': 1,    # out shape: (batch, 1, 28, 28)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': False,
         'shape_func_out': lambda x: T.flatten(x, 2)},
        {'layer_type': 'conv',
         'in_chans': 64,
         'out_chans': 1,
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': False,
         'shape_func_out': lambda x: T.flatten(x, 2)},
    ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_z = HydraNet(rng=rng, Xd=xin_sym,
                           params=params, shared_param_dicts=None)
    p_x_given_z.init_biases(0.0)

    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = [
        {'layer_type': 'conv',
         'in_chans': 1,     # in shape:  (batch, 784)
         'out_chans': 64,   # out shape: (batch, 64, 14, 14)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'double',
         'apply_bn': True,
         'shape_func_in': lambda x: T.reshape(x, (-1, 1, 28, 28))},
        {'layer_type': 'conv',
         'in_chans': 64,    # in shape:  (batch, 64, 14, 14)
         'out_chans': 128,  # out shape: (batch, 128, 7, 7)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'double',
         'apply_bn': True,
         'shape_func_out': lambda x: T.flatten(x, 2)},
        {'layer_type': 'fc',
         'in_chans': 128*7*7,
         'out_chans': 256,
         'activation': relu_actfun,
         'apply_bn': True},
    ]
    output_config = [
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': z_dim,
         'activation': relu_actfun,
         'apply_bn': False},
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': z_dim,
         'activation': relu_actfun,
         'apply_bn': False},
    ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_x = HydraNet(rng=rng, Xd=xin_sym,
                           params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.0)

    ##############################################################
    # Define parameters for the OneStageModel, and initialize it #
    ##############################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = x_type
    osm_params['obs_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng, x_in=xin_sym,
                        x_dim=x_dim, z_dim=z_dim,
                        p_x_given_z=p_x_given_z,
                        q_z_given_x=q_z_given_x,
                        params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("OSM_TEST")
    costs = [0. for i in range(10)]
    learn_rate = 0.0005
    momentum = 0.9
    # start past the end of the data so the first iteration triggers a shuffle
    batch_idx = np.arange(batch_size) + tr_samples
    # The log receives text, so it is opened in text mode, and the context
    # manager guarantees the handle is closed even if training raises.
    with open(log_name, 'w') as out_file:
        for i in range(500000):
            # warm-up factor: ramps linearly, capped at 0.5
            scale = min(0.5, ((i + 1) / 5000.0))
            if (((i + 1) % 10000) == 0):
                learn_rate = learn_rate * 0.95
            # get the indices of training samples for this batch update
            batch_idx += batch_size
            if (np.max(batch_idx) >= tr_samples):
                # we finished an "epoch", so we rejumble the training set
                Xtr = row_shuffle(Xtr)
                batch_idx = np.arange(batch_size)
            Xb = to_fX(Xtr.take(batch_idx, axis=0))
            # set sgd and objective function hyperparams for this update
            OSM.set_sgd_params(lr=scale * learn_rate,
                               mom_1=(scale * momentum), mom_2=0.98)
            OSM.set_lam_nll(lam_nll=1.0)
            OSM.set_lam_kld(lam_kld=1.0)
            OSM.set_lam_l2w(1e-5)
            # perform a minibatch update and record the cost for this batch
            result = OSM.train_joint(Xb, batch_reps)
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 500) == 0):
                costs = [(v / 500.0) for v in costs]
                str1 = "-- batch {0:d} --".format(i)
                str2 = " joint_cost: {0:.4f}".format(costs[0])
                str3 = " nll_cost : {0:.4f}".format(costs[1])
                str4 = " kld_cost : {0:.4f}".format(costs[2])
                str5 = " reg_cost : {0:.4f}".format(costs[3])
                joint_str = "\n".join([str1, str2, str3, str4, str5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
                costs = [0.0 for v in costs]
            if (((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0))):
                # draw some independent random samples from the model
                samp_count = 300
                model_samps = OSM.sample_from_prior(samp_count)
                file_name = "OSM_SAMPLES_b{0:d}.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=15)
                # compute free energy estimate for validation samples
                Xva = row_shuffle(Xva)
                fe_terms = OSM.compute_fe_terms(Xva[0:5000], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                out_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(out_str)
                out_file.write(out_str + "\n")
                out_file.flush()
    return
class VCGLoop(object):
    """
    Controller for training a self-looping VAE using guidance provided by a
    classifier.

    The VAE (inferencer + generator) is unrolled into a Markov chain with
    theano.scan. A discriminator receives target-distribution samples stacked
    on top of every chain output and is trained to tell them apart, while the
    VAE minimizes a variational objective (chain NLL + KLd) plus an
    adversarial cost computed from the discriminator's outputs on the
    generated samples.

    Requirements on the networks, as used by this class:
        * the inferencer must provide apply(x, do_samples=False) returning
          (mean, logvar), and an `mlp_params` list.
        * the generator must provide apply(z) returning a sequence whose last
          element is the (pre-transform) observation mean, and an
          `mlp_params` list.
        * the discriminator must provide shared_param_clone(rng=, Xd=), and
          `proto_nets`, `spawn_nets` and `init_scale` attributes.

    Parameters:
        rng: numpy.random.RandomState (for reproducibility)
        x_d: symbolic var for providing points for starting the Markov Chain
        x_t: symbolic var for providing samples from the target distribution
        i_net: the network that will serve as the inferencer
        g_net: the network that will serve as the generator
        d_net: the network that will serve as the discriminator
        chain_len: number of steps to unroll the VAE Markov Chain
        data_dim: dimension of the generated data
        z_dim: dimension of the model prior
        params: a dict of parameters for controlling various costs
            x_type: can be "bernoulli" or "gaussian" (required)
            lam_l2d: regularization on squared discriminator output (required)
            xt_transform: 'sigmoid' or 'none', applied to the generator's
                          output mean (default: sigmoid)
            logvar_bound: optional bound on gaussian output logvar (default 10)
            cost_decay: rate of decay for VAE costs in unrolled chain
                        (default 0.1)
            chain_type: can be 'walkout' or 'walkback' (default 'walkout')
    """
    def __init__(self, rng=None, x_d=None, x_t=None,
                 i_net=None, g_net=None, d_net=None,
                 chain_len=None, data_dim=None, z_dim=None,
                 params=None):
        # private symbolic random stream, seeded from the caller's rng
        self.rng = RandStream(rng.randint(100000))
        self.data_dim = data_dim
        self.z_dim = z_dim
        # the prior over z is a zero-mean, unit-variance gaussian
        self.p_z_mean = 0.0
        self.p_z_logvar = 0.0
        if params is None:
            self.params = {}
        else:
            self.params = params
        self.cost_decay = self.params.get('cost_decay', 0.1)
        if 'chain_type' in self.params:
            assert((self.params['chain_type'] == 'walkback') or
                   (self.params['chain_type'] == 'walkout'))
            self.chain_type = self.params['chain_type']
        else:
            self.chain_type = 'walkout'
        if 'xt_transform' in self.params:
            assert((self.params['xt_transform'] == 'sigmoid') or
                   (self.params['xt_transform'] == 'none'))
            if self.params['xt_transform'] == 'sigmoid':
                self.xt_transform = lambda x: T.nnet.sigmoid(x)
            else:
                self.xt_transform = lambda x: x
        else:
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        self.logvar_bound = self.params.get('logvar_bound', 10)
        #
        # x_type: this tells if we're using bernoulli or gaussian model for
        #         the observations
        #
        self.x_type = self.params['x_type']
        assert((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))

        # grab symbolic input variables
        self.x_d = x_d             # initial input for starting the chain
        self.x_t = x_t             # samples from target distribution
        self.z_zmuv = T.tensor3()  # ZMUV gaussian samples for use in scan
        # get the number of steps for chain unrolling
        self.chain_len = chain_len

        # symbolic vector of indices for inputs from target distribution
        self.It = T.arange(self.x_t.shape[0])
        # symbolic vector of indices for noise/generated inputs; these rows
        # come after the target rows in the discriminator's stacked input
        self.Id = T.arange(self.chain_len * self.x_d.shape[0]) + \
            self.x_t.shape[0]

        # wrap the given networks in a OneStageModel, for easy access
        self.OSM = OneStageModel(rng=rng, x_in=self.x_d,
                                 p_x_given_z=g_net, q_z_given_x=i_net,
                                 x_dim=self.data_dim, z_dim=self.z_dim,
                                 params=self.params)
        self.IN = self.OSM.q_z_given_x
        self.GN = self.OSM.p_x_given_z
        self.transform_x_to_z = self.OSM.transform_x_to_z
        self.transform_z_to_x = self.OSM.transform_z_to_x
        self.bounded_logvar = self.OSM.bounded_logvar

        ##################################################
        # Setup the iterative generation loop using scan #
        ##################################################
        def chain_step_func(zi_zmuv, xim1):
            # get mean and logvar of z samples for this step
            zi_mean, zi_logvar = self.IN.apply(xim1, do_samples=False)
            # transform ZMUV samples to get desired samples
            zi = (T.exp(0.5 * zi_logvar) * zi_zmuv) + zi_mean
            # get the next generated xi (pre-transformation)
            outputs = self.GN.apply(zi)
            xti = outputs[-1]
            # apply the observation "mean" transform
            xgi = self.xt_transform(xti)
            # compute NLL for this step: 'walkout' scores every step against
            # the chain's seed points, 'walkback' against the previous state
            if self.chain_type == 'walkout':
                x_true = self.x_d
            else:
                x_true = xim1
            nlli = self._log_prob(x_true, xgi).flatten()
            kldi = T.sum(gaussian_kld(zi_mean, zi_logvar,
                                      self.p_z_mean, self.p_z_logvar), axis=1)
            return xgi, nlli, kldi

        # apply the scan op; only the generated x is fed back recurrently
        init_values = [self.x_d, None, None]
        self.scan_results, self.scan_updates = \
            theano.scan(chain_step_func, outputs_info=init_values,
                        sequences=self.z_zmuv)
        # get the outputs of the scan op
        self.xgi = self.scan_results[0]
        self.nlli = self.scan_results[1]
        self.kldi = self.scan_results[2]
        self.xgi_list = [self.xgi[i] for i in range(self.chain_len)]

        # make a clone of the desired discriminator network, which will try
        # to discriminate between samples from the training data and samples
        # generated by the self-looped VAE chain.
        self.DN = d_net.shared_param_clone(
            rng=rng, Xd=T.vertical_stack(self.x_t, *self.xgi_list))

        zero_ary = np.zeros((1,)).astype(theano.config.floatX)
        # init shared var for weighting nll of data given posterior sample
        self.lam_chain_nll = theano.shared(value=zero_ary,
                                           name='vcg_lam_chain_nll')
        self.set_lam_chain_nll(lam_chain_nll=1.0)
        # init shared var for weighting posterior KL-div from prior
        self.lam_chain_kld = theano.shared(value=zero_ary,
                                           name='vcg_lam_chain_kld')
        self.set_lam_chain_kld(lam_chain_kld=1.0)
        # init shared var for controlling l2 regularization on params
        self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
        self.set_lam_l2w(lam_l2w=1e-4)
        # shared var learning rates for all networks
        self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
        self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
        self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
        # shared var momentum parameters for all networks
        self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
        self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
        # shared var weights for adversarial classification objective
        self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
        self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
        # init parameters for controlling learning dynamics
        self.set_all_sgd_params()
        # init adversarial cost weights for GN/DN
        self.set_disc_weights()
        # set a shared var for regularizing the output of the discriminator
        self.lam_l2d = theano.shared(
            value=(zero_ary + self.params['lam_l2d']), name='vcg_lam_l2d')

        # Grab the full set of "optimizable" parameters from the generator
        # and discriminator networks that we'll be working with. We need to
        # ignore parameters in the final layers of the proto-networks in the
        # discriminator network (a generalized pseudo-ensemble). We ignore them
        # because the VCGLoop requires that they be "bypassed" in favor of some
        # binary classification layers that will be managed by this VCGLoop.
        self.dn_params = []
        for pn in self.DN.proto_nets:
            for pnl in pn[0:-1]:
                self.dn_params.extend(pnl.params)
        self.in_params = [p for p in self.IN.mlp_params]
        self.gn_params = [p for p in self.GN.mlp_params]
        # NOTE(review): joint_params is built before the DiscLayer params are
        # appended to dn_params below, so it does not include them -- confirm
        # this is intended.
        self.joint_params = self.in_params + self.gn_params + self.dn_params

        # Now construct a binary discriminator layer for each proto-net in the
        # discriminator network. And, add their params to optimization list.
        self._construct_disc_layers(rng)
        self.disc_reg_cost = self.lam_l2d[0] * \
            T.sum([dl.act_l2_sum for dl in self.disc_layers])

        # Construct costs for the generator and discriminator networks based
        # on adversarial binary classification
        self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

        # first, build the cost to be optimized by the discriminator network,
        # in general this will be treated somewhat independently of the
        # optimization of the generator and inferencer networks.
        self.dn_cost = self.disc_cost_dn + self.disc_reg_cost

        # construct costs relevant to the optimization of the generator and
        # inferencer networks
        self.chain_nll_cost = self.lam_chain_nll[0] * \
            self._construct_chain_nll_cost(cost_decay=self.cost_decay)
        self.chain_kld_cost = self.lam_chain_kld[0] * \
            self._construct_chain_kld_cost(cost_decay=self.cost_decay)
        self.other_reg_cost = self._construct_other_reg_cost()
        self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
            self.chain_kld_cost + self.other_reg_cost
        # compute total cost on the discriminator and VB generator/inferencer
        self.joint_cost = self.dn_cost + self.osm_cost

        print("Computing VCGLoop joint_grad...")
        # grab the gradients for all parameters to optimize; discriminator
        # params follow dn_cost, inferencer/generator params follow osm_cost
        self.joint_grads = OrderedDict()
        for p in self.dn_params:
            self.joint_grads[p] = T.grad(self.dn_cost, p)
        for p in self.in_params:
            self.joint_grads[p] = T.grad(self.osm_cost, p)
        for p in self.gn_params:
            self.joint_grads[p] = T.grad(self.osm_cost, p)

        # construct the updates for the discriminator, generator and
        # inferencer networks. all networks share the same first/second
        # moment momentum and iteration count. the networks each have their
        # own learning rates, which lets you turn their learning on/off.
        self.dn_updates = get_adam_updates(
            params=self.dn_params,
            grads=self.joint_grads, alpha=self.lr_dn,
            beta1=self.mom_1, beta2=self.mom_2,
            mom2_init=1e-3, smoothing=1e-4, max_grad_norm=10.0)
        self.in_updates = get_adam_updates(
            params=self.in_params,
            grads=self.joint_grads, alpha=self.lr_in,
            beta1=self.mom_1, beta2=self.mom_2,
            mom2_init=1e-3, smoothing=1e-4, max_grad_norm=10.0)
        self.gn_updates = get_adam_updates(
            params=self.gn_params,
            grads=self.joint_grads, alpha=self.lr_gn,
            beta1=self.mom_1, beta2=self.mom_2,
            mom2_init=1e-3, smoothing=1e-4, max_grad_norm=10.0)

        # bag up all the updates required for training
        self.joint_updates = OrderedDict()
        for k in self.dn_updates:
            self.joint_updates[k] = self.dn_updates[k]
        for k in self.in_updates:
            self.joint_updates[k] = self.in_updates[k]
        for k in self.gn_updates:
            self.joint_updates[k] = self.gn_updates[k]

        print("Compiling VCGLoop train_joint...")
        # construct the function for training on training data
        self.train_joint = self._construct_train_joint()
        return

    @staticmethod
    def _set_shared_scalar(shared_var, value):
        """
        Store `value` in a length-1 shared variable, cast to floatX.
        """
        new_val = np.zeros((1,)) + value
        shared_var.set_value(new_val.astype(theano.config.floatX))
        return

    def set_dn_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the discriminator network.
        """
        self._set_shared_scalar(self.lr_dn, learn_rate)
        return

    def set_in_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the inferencer network.
        """
        self._set_shared_scalar(self.lr_in, learn_rate)
        return

    def set_gn_sgd_params(self, learn_rate=0.01):
        """
        Set learning rate for the generator network.
        """
        self._set_shared_scalar(self.lr_gn, learn_rate)
        return

    def set_all_sgd_params(self, learn_rate=0.01, mom_1=0.9, mom_2=0.999):
        """
        Set learning rate and momentum parameter for all updates.
        """
        # set learning rates to the same value
        self._set_shared_scalar(self.lr_dn, learn_rate)
        self._set_shared_scalar(self.lr_gn, learn_rate)
        self._set_shared_scalar(self.lr_in, learn_rate)
        # set the first/second moment momentum parameters
        self._set_shared_scalar(self.mom_1, mom_1)
        self._set_shared_scalar(self.mom_2, mom_2)
        return

    def set_disc_weights(self, dweight_gn=1.0, dweight_dn=1.0):
        """
        Set weights for the adversarial classification cost.
        """
        self._set_shared_scalar(self.dw_dn, dweight_dn)
        self._set_shared_scalar(self.dw_gn, dweight_gn)
        return

    def set_lam_chain_nll(self, lam_chain_nll=1.0):
        """
        Set weight for controlling the influence of the data likelihood.
        """
        self._set_shared_scalar(self.lam_chain_nll, lam_chain_nll)
        return

    def set_lam_chain_kld(self, lam_chain_kld=1.0):
        """
        Set the strength of regularization on KL-divergence for continuous
        posterior variables. When set to 1.0, this reproduces the standard
        role of KL(posterior || prior) in variational learning.
        """
        self._set_shared_scalar(self.lam_chain_kld, lam_chain_kld)
        return

    def set_lam_l2w(self, lam_l2w=1e-3):
        """
        Set the relative strength of l2 regularization on network params.
        """
        self._set_shared_scalar(self.lam_l2w, lam_l2w)
        return

    def _construct_zmuv_samples(self, xi, br):
        """
        Construct the necessary (symbolic) samples for computing through this
        VCGLoop for input (symbolic) matrix xi, with each row of xi repeated
        br times.

        Returns a symbolic tensor of standard-normal samples with shape
        (chain_len, xi.shape[0] * br, z_dim).
        """
        z_zmuv = self.rng.normal(
            size=(self.chain_len, xi.shape[0]*br, self.z_dim),
            avg=0.0, std=1.0, dtype=theano.config.floatX)
        return z_zmuv

    def _construct_disc_layers(self, rng):
        """
        Construct binary discrimination layers for each spawn-net in the
        underlying discriminator pseudo-ensemble.

        Fills self.disc_layers and self.disc_outputs (one entry per spawn-net)
        and appends each new layer's params to self.dn_params.
        """
        self.disc_layers = []
        self.disc_outputs = []
        dn_init_scale = self.DN.init_scale
        for sn in self.DN.spawn_nets:
            # construct a "binary discriminator" layer to sit on top of each
            # spawn net in the discriminator pseudo-ensemble; it reads the
            # input of the spawn net's final layer, which it bypasses
            sn_fl = sn[-1]
            self.disc_layers.append(DiscLayer(
                rng=rng,
                input=sn_fl.noisy_input, in_dim=sn_fl.in_dim,
                W_scale=dn_init_scale))
            # capture the (linear) output of the DiscLayer, for possible reuse
            self.disc_outputs.append(self.disc_layers[-1].linear_output)
            # get the params of this DiscLayer, for convenient optimization
            self.dn_params.extend(self.disc_layers[-1].params)
        return

    def _construct_disc_costs(self):
        """
        Construct the generator and discriminator adversarial costs.

        Returns [dn_cost, gn_cost]: the discriminator's cost weighted by
        dw_dn and the generator's cost weighted by dw_gn, each summed over
        all discriminator output layers.
        """
        gn_costs = []
        dn_costs = []
        for dl_output in self.disc_outputs:
            # rows indexed by It came from x_t, rows indexed by Id came from
            # the generated chain
            data_preds = dl_output.take(self.It, axis=0)
            noise_preds = dl_output.take(self.Id, axis=0)
            # compute the cost with respect to which we will be optimizing
            # the parameters of the discriminator network
            data_size = T.cast(self.It.size, 'floatX')
            noise_size = T.cast(self.Id.size, 'floatX')
            dnl_dn_cost = (logreg_loss(data_preds, 1.0) / data_size) + \
                          (logreg_loss(noise_preds, -1.0) / noise_size)
            # compute the cost with respect to which we will be optimizing
            # the parameters of the generative model
            dnl_gn_cost = (hinge_loss(noise_preds, 0.0) +
                           hinge_sq_loss(noise_preds, 0.0)) / \
                          (2.0 * noise_size)
            dn_costs.append(dnl_dn_cost)
            gn_costs.append(dnl_gn_cost)
        dn_cost = self.dw_dn[0] * T.sum(dn_costs)
        gn_cost = self.dw_gn[0] * T.sum(gn_costs)
        return [dn_cost, gn_cost]

    def _log_prob(self, x_true, x_apprx):
        """
        Wrap log-prob with switching for bernoulli/gaussian output types.

        Note that the value returned is the *negative* log-probability.
        """
        if self.x_type == 'bernoulli':
            ll_cost = log_prob_bernoulli(x_true, x_apprx)
        else:
            ll_cost = log_prob_gaussian2(x_true, x_apprx,
                                         log_vars=self.bounded_logvar)
        nll_cost = -ll_cost
        return nll_cost

    def _construct_chain_nll_cost(self, cost_decay=0.1):
        """
        Construct the negative log-likelihood part of cost to minimize.

        This is for operation in "free chain" mode, where a seed point is used
        to initialize a long(ish) running markov chain. Step i of the chain
        is weighted by cost_decay**i, and the result is normalized by the sum
        of the weights. cost_decay must lie strictly between 0 and 1.
        """
        assert((cost_decay > 0.0) and (cost_decay < 1.0))
        nll_costs = []
        step_weight = 1.0
        step_weights = []
        step_decay = cost_decay
        for i in range(self.chain_len):
            c = T.mean(self.nlli[i])
            nll_costs.append(step_weight * c)
            step_weights.append(step_weight)
            step_weight = step_weight * step_decay
        nll_cost = sum(nll_costs) / sum(step_weights)
        return nll_cost

    def _construct_chain_kld_cost(self, cost_decay=0.1):
        """
        Construct the posterior KLd from prior part of cost to minimize.

        This is for operation in "free chain" mode, where a seed point is used
        to initialize a long(ish) running markov chain. Step i of the chain
        is weighted by cost_decay**i, and the result is normalized by the sum
        of the weights. cost_decay must lie strictly between 0 and 1.
        """
        assert((cost_decay > 0.0) and (cost_decay < 1.0))
        kld_costs = []
        step_weight = 1.0
        step_weights = []
        step_decay = cost_decay
        for i in range(self.chain_len):
            # average and reweight the KLd cost for this step in the chain
            c = T.mean(self.kldi[i])
            kld_costs.append(step_weight * c)
            step_weights.append(step_weight)
            step_weight = step_weight * step_decay
        kld_cost = sum(kld_costs) / sum(step_weights)
        return kld_cost

    def _construct_other_reg_cost(self):
        """
        Construct the cost for low-level basic regularization. E.g. for
        applying l2 regularization to the network parameters.

        Only the generator and inferencer parameters are penalized here.
        """
        gp_cost = sum([T.sum(par**2.0) for par in self.gn_params])
        ip_cost = sum([T.sum(par**2.0) for par in self.in_params])
        other_reg_cost = self.lam_l2w[0] * (gp_cost + ip_cost)
        return other_reg_cost

    def _construct_train_joint(self):
        """
        Construct theano function to train generator and discriminator jointly.

        The compiled function takes (xd, xt, br): chain seed points, target
        samples, and the number of times each seed row is repeated. It
        applies self.joint_updates and returns, in order: joint cost, chain
        NLL cost, chain KLd cost, generator adversarial cost, discriminator
        adversarial cost, and parameter regularization cost.
        """
        # symbolic vars for passing input to training function
        xd = T.matrix()
        xt = T.matrix()
        br = T.lscalar()
        zzmuv = self._construct_zmuv_samples(xd, br)
        # collect outputs to return to caller
        outputs = [self.joint_cost, self.chain_nll_cost, self.chain_kld_cost,
                   self.disc_cost_gn, self.disc_cost_dn, self.other_reg_cost]
        func = theano.function(
            inputs=[xd, xt, br],
            outputs=outputs, updates=self.joint_updates,
            givens={self.x_d: xd.repeat(br, axis=0),
                    self.x_t: xt,
                    self.z_zmuv: zzmuv})
        return func

    def sample_from_chain(self, X_d, X_c=None, X_m=None, loop_iters=5,
                          sigma_scale=None):
        """
        Sample for several rounds through the I<->G loop, initialized with the
        "data variable" samples in X_d.

        All arguments are forwarded unchanged to the wrapped OneStageModel's
        sample_from_chain, whose result is returned.
        """
        result = self.OSM.sample_from_chain(X_d, X_c=X_c, X_m=X_m,
                                            loop_iters=loop_iters,
                                            sigma_scale=sigma_scale)
        return result

    def sample_from_prior(self, samp_count):
        """
        Draw independent samples from the model's prior.

        Delegates to the wrapped OneStageModel's sample_from_prior.
        """
        Xs = self.OSM.sample_from_prior(samp_count)
        return Xs
def __init__(self, rng=None, x_d=None, x_t=None,
             i_net=None, g_net=None, d_net=None,
             chain_len=None, data_dim=None, z_dim=None,
             params=None):
    """
    Build the unrolled chain, its costs, its updates, and train_joint.

    rng       -- random generator; rng.randint seeds this object's own
                 RandStream, and rng itself is handed on to the sub-models
    x_d       -- symbolic input used to start the chain
    x_t       -- symbolic input holding samples from the target distribution
    i_net     -- network passed to OneStageModel as q_z_given_x
    g_net     -- network passed to OneStageModel as p_x_given_z
    d_net     -- discriminator network; a shared-param clone of it is fed
                 x_t stacked on top of every chain output
    chain_len -- number of chain steps to unroll
    data_dim  -- passed to OneStageModel as x_dim
    z_dim     -- passed to OneStageModel as z_dim
    params    -- dict of options. Required: 'x_type' ('bernoulli' or
                 'gaussian') and 'lam_l2d'. Optional: 'cost_decay' (0.1),
                 'chain_type' ('walkout' or 'walkback', default 'walkout'),
                 'xt_transform' ('sigmoid' or 'none', default 'sigmoid'),
                 'logvar_bound' (10).
    """
    self.rng = RandStream(rng.randint(100000))
    self.data_dim = data_dim
    self.z_dim = z_dim
    self.p_z_mean = 0.0
    self.p_z_logvar = 0.0

    # ---- read options, falling back to defaults where allowed ----
    self.params = {} if params is None else params
    self.cost_decay = self.params.get('cost_decay', 0.1)
    chain_type = self.params.get('chain_type', 'walkout')
    assert chain_type in ('walkback', 'walkout')
    self.chain_type = chain_type
    xt_name = self.params.get('xt_transform', 'sigmoid')
    assert xt_name in ('sigmoid', 'none')
    if xt_name == 'sigmoid':
        self.xt_transform = lambda x: T.nnet.sigmoid(x)
    else:
        self.xt_transform = lambda x: x
    self.logvar_bound = self.params.get('logvar_bound', 10)
    # x_type selects the bernoulli or gaussian observation model
    self.x_type = self.params['x_type']
    assert self.x_type in ('bernoulli', 'gaussian')

    # ---- symbolic inputs ----
    self.x_d = x_d  # initial input for starting the chain
    self.x_t = x_t  # samples from target distribution
    self.z_zmuv = T.tensor3()  # ZMUV gaussian samples for use in scan
    self.chain_len = chain_len
    # row indices of the target-distribution inputs ...
    self.It = T.arange(self.x_t.shape[0])
    # ... and of the chain-generated inputs, which are stacked after them
    self.Id = T.arange(self.chain_len * self.x_d.shape[0]) + \
            self.x_t.shape[0]

    # ---- the underlying one-stage VAE and handles into it ----
    self.OSM = OneStageModel(rng=rng, x_in=self.x_d,
                             p_x_given_z=g_net, q_z_given_x=i_net,
                             x_dim=self.data_dim, z_dim=self.z_dim,
                             params=self.params)
    self.IN = self.OSM.q_z_given_x
    self.GN = self.OSM.p_x_given_z
    self.transform_x_to_z = self.OSM.transform_x_to_z
    self.transform_z_to_x = self.OSM.transform_z_to_x
    self.bounded_logvar = self.OSM.bounded_logvar

    # ---- self-loop the VAE into a multi-step chain using scan ----
    def chain_step_func(zi_zmuv, xim1):
        # posterior mean and log-variance given the previous chain state
        z_mean, z_logvar = self.IN.apply(xim1, do_samples=False)
        # shift/scale the ZMUV noise into a sample from that posterior
        z_samp = (T.exp(0.5 * z_logvar) * zi_zmuv) + z_mean
        # generate the next state (pre-transformation), then transform it
        gen_outputs = self.GN.apply(z_samp)
        x_next = self.xt_transform(gen_outputs[-1])
        # 'walkout' scores against the chain's seed, 'walkback' against
        # the previous chain state
        if self.chain_type == 'walkout':
            x_ref = self.x_d
        else:
            x_ref = xim1
        step_nll = self._log_prob(x_ref, x_next).flatten()
        step_kld = T.sum(gaussian_kld(z_mean, z_logvar,
                                      self.p_z_mean, self.p_z_logvar),
                         axis=1)
        return x_next, step_nll, step_kld

    # only the first output (the chain state) is fed back between steps
    init_values = [self.x_d, None, None]
    self.scan_results, self.scan_updates = theano.scan(
            chain_step_func, outputs_info=init_values,
            sequences=self.z_zmuv)
    self.xgi = self.scan_results[0]
    self.nlli = self.scan_results[1]
    self.kldi = self.scan_results[2]
    self.xgi_list = [self.xgi[i] for i in range(self.chain_len)]

    # ---- discriminator clone: target samples on top, chain output below
    self.DN = d_net.shared_param_clone(rng=rng,
            Xd=T.vertical_stack(self.x_t, *self.xgi_list))

    # ---- shared variables controlling the objective and optimizer ----
    zero_ary = np.zeros((1,)).astype(theano.config.floatX)
    # weight on nll of data given posterior sample
    self.lam_chain_nll = theano.shared(value=zero_ary,
                                       name='vcg_lam_chain_nll')
    self.set_lam_chain_nll(lam_chain_nll=1.0)
    # weight on posterior KL-div from prior
    self.lam_chain_kld = theano.shared(value=zero_ary,
                                       name='vcg_lam_chain_kld')
    self.set_lam_chain_kld(lam_chain_kld=1.0)
    # weight on l2 regularization of params
    self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
    self.set_lam_l2w(lam_l2w=1e-4)
    # learning rates, one per network
    self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
    self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
    self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
    # momentum parameters, shared by all networks
    self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
    self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
    # weights for the adversarial classification objective
    self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
    self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
    # fill the optimizer and adversarial-weight variables with defaults
    self.set_all_sgd_params()
    self.set_disc_weights()
    # weight for regularizing the output of the discriminator
    self.lam_l2d = theano.shared(value=(zero_ary + params['lam_l2d']),
                                 name='vcg_lam_l2d')

    # ---- collect the optimizable parameters ----
    # The final layer of each discriminator proto-net is left out: it is
    # bypassed in favor of the binary classification layers built below.
    self.dn_params = [p for pn in self.DN.proto_nets
                      for pnl in pn[0:-1]
                      for p in pnl.params]
    self.in_params = list(self.IN.mlp_params)
    self.gn_params = list(self.GN.mlp_params)
    # snapshot taken before the disc-layer params are appended
    self.joint_params = self.in_params + self.gn_params + self.dn_params

    # build the binary discriminator layers; this also appends their
    # params to self.dn_params
    self._construct_disc_layers(rng)
    self.disc_reg_cost = self.lam_l2d[0] * \
            T.sum([dl.act_l2_sum for dl in self.disc_layers])

    # ---- costs ----
    # adversarial binary classification costs for both sides
    self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()
    # cost optimized by the discriminator network
    self.dn_cost = self.disc_cost_dn + self.disc_reg_cost
    # costs optimized by the generator and inferencer networks
    self.chain_nll_cost = self.lam_chain_nll[0] * \
            self._construct_chain_nll_cost(cost_decay=self.cost_decay)
    self.chain_kld_cost = self.lam_chain_kld[0] * \
            self._construct_chain_kld_cost(cost_decay=self.cost_decay)
    self.other_reg_cost = self._construct_other_reg_cost()
    self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
            self.chain_kld_cost + self.other_reg_cost
    # total cost over discriminator and generator/inferencer
    self.joint_cost = self.dn_cost + self.osm_cost

    # ---- gradients: each parameter group against its own cost ----
    print("Computing VCGLoop joint_grad...")
    self.joint_grads = OrderedDict()
    grad_plan = ((self.dn_cost, self.dn_params),
                 (self.osm_cost, self.in_params),
                 (self.osm_cost, self.gn_params))
    for cost, group in grad_plan:
        for p in group:
            self.joint_grads[p] = T.grad(cost, p)

    # ---- updates ----
    # All networks share the momentum settings; each has its own learning
    # rate, which lets its learning be turned on/off independently.
    adam_kwargs = dict(grads=self.joint_grads,
                       beta1=self.mom_1, beta2=self.mom_2,
                       mom2_init=1e-3, smoothing=1e-4, max_grad_norm=10.0)
    self.dn_updates = get_adam_updates(params=self.dn_params,
                                       alpha=self.lr_dn, **adam_kwargs)
    self.in_updates = get_adam_updates(params=self.in_params,
                                       alpha=self.lr_in, **adam_kwargs)
    self.gn_updates = get_adam_updates(params=self.gn_params,
                                       alpha=self.lr_gn, **adam_kwargs)
    # bag up all the updates required for training
    self.joint_updates = OrderedDict()
    for updates in (self.dn_updates, self.in_updates, self.gn_updates):
        for k in updates:
            self.joint_updates[k] = updates[k]

    # ---- compile the training function ----
    print("Compiling VCGLoop train_joint...")
    self.train_joint = self._construct_train_joint()
    return