def pretrain_osm(lam_kld=0.0):
    """Pretrain a OneStageModel on the TFD 48x48 data and checkpoint it.

    Builds a generator network (PRIOR_DIM -> 1500 -> 1500 -> data_dim) and an
    inference network (data_dim -> 1500 -> 1500 -> PRIOR_DIM), wraps them in
    a OneStageModel with Bernoulli observations, and runs 200000 minibatch
    updates.  After the first update, each batch normally re-uses the
    ``carry_size`` highest-cost examples of the previous batch; with
    probability ``reset_prob`` a fully random batch is drawn instead.

    Side effects (all paths are prefixed with RESULT_PATH): a text log
    "pt_osm_results.txt"; sample, weight, free-energy and posterior-KLd
    images every 2000 updates; pickled IN / GN parameters every 5000 updates
    and once more when training finishes.

    Args:
        lam_kld: target value passed as ``lam_kld_1`` to
            ``OSM.set_lam_kld``.  It is multiplied by a factor that grows
            linearly from 0 to 1 over the first 5000 updates.

    Returns:
        None.
    """
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate with: training data is the union of
    # the 'unlabeled' and 'train' sets, validation data is the 'valid' set.
    data_file = 'data/tfd_data_48x48.pkl'
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='unlabeled', fold='all')
    Xtr_unlabeled = dataset[0]
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='train', fold='all')
    Xtr_train = dataset[0]
    Xtr = np.vstack([Xtr_unlabeled, Xtr_train])
    dataset = load_tfd(tfd_pkl_name=data_file, which_set='valid', fold='all')
    Xva = dataset[0]
    tr_samples = Xtr.shape[0]
    batch_size = 400
    batch_reps = 6
    # fraction (and count) of each batch carried over from the previous one
    carry_frac = 0.25
    carry_size = int(batch_size * carry_frac)
    # probability of discarding the carried examples and starting fresh
    reset_prob = 0.04

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0
    Xtr_mean = np.mean(Xtr, axis=0)

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    # parameters for the generator network (latent -> data)
    gn_params = {}
    shared_config = [PRIOR_DIM, 1500, 1500]
    top_config = [shared_config[-1], data_dim]
    gn_params['shared_config'] = shared_config
    gn_params['mu_config'] = top_config
    gn_params['sigma_config'] = top_config
    gn_params['activation'] = relu_actfun
    gn_params['init_scale'] = 1.4
    gn_params['lam_l2a'] = 0.0
    gn_params['vis_drop'] = 0.0
    gn_params['hid_drop'] = 0.0
    gn_params['bias_noise'] = 0.0
    gn_params['input_noise'] = 0.0
    # choose some parameters for the continuous inferencer (data -> latent)
    in_params = {}
    shared_config = [data_dim, 1500, 1500]
    top_config = [shared_config[-1], PRIOR_DIM]
    in_params['shared_config'] = shared_config
    in_params['mu_config'] = top_config
    in_params['sigma_config'] = top_config
    in_params['activation'] = relu_actfun
    in_params['init_scale'] = 1.4
    in_params['lam_l2a'] = 0.0
    in_params['vis_drop'] = 0.0
    in_params['hid_drop'] = 0.0
    in_params['bias_noise'] = 0.0
    in_params['input_noise'] = 0.0
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    ######################################
    # LOAD AND RESTART FROM SAVED PARAMS #
    ######################################
    # To restart from a checkpoint, replace the two networks built above:
    # gn_fname = RESULT_PATH+"pt_osm_params_b110000_GN.pkl"
    # in_fname = RESULT_PATH+"pt_osm_params_b110000_IN.pkl"
    # IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd, \
    #                            new_params=None)
    # GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd, \
    #                            new_params=None)
    # in_params = IN.params
    # gn_params = GN.params

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    # squash the data mean into [0.05, 0.95] so that its logit stays finite
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    # The log receives str objects, so it is opened in text mode ('wb' makes
    # every write fail on Python 3); the with-block guarantees it is closed.
    with open(RESULT_PATH + "pt_osm_results.txt", 'w') as out_file:
        # Set initial learning rate and basic SGD hyper parameters
        costs = [0.0] * 10
        learn_rate = 0.002
        for i in range(200000):
            # warm-up factor applied to learning rate, momentum and lam_kld
            scale = min(1.0, float(i) / 5000.0)
            if ((i > 1) and ((i % 20000) == 0)):
                learn_rate = learn_rate * 0.8
            # NOTE(review): the original schedule had an `elif (i < 10000)`
            # branch (momentum 0.7) that could never run because it followed
            # `if (i < 50000)`.  The dead branch is removed; the effective
            # schedule (0.5 before update 50000, 0.9 after) is unchanged.
            momentum = 0.5 if (i < 50000) else 0.9
            if ((i == 0) or (npr.rand() < reset_prob)):
                # sample a fully random batch
                batch_idx = npr.randint(low=0, high=tr_samples,
                                        size=(batch_size, ))
            else:
                # sample a partially random batch, which retains some portion
                # of the worst scoring examples from the previous batch
                fresh_idx = npr.randint(low=0, high=tr_samples,
                                        size=(batch_size - carry_size, ))
                batch_idx = np.concatenate((fresh_idx.ravel(),
                                            carry_idx.ravel()))
            # Train on the batch selected above.  The original drew a second,
            # unrelated index set here, which silently disabled the carry-over
            # and ranked `batch_idx` by costs measured on different examples.
            Xd_batch = Xtr.take(batch_idx, axis=0)
            Xc_batch = 0.0 * Xd_batch
            Xm_batch = 0.0 * Xd_batch
            # do a minibatch update of the model, and compute some costs
            OSM.set_sgd_params(lr_1=(scale * learn_rate),
                               mom_1=(scale * momentum), mom_2=0.98)
            OSM.set_lam_nll(1.0)
            OSM.set_lam_kld(lam_kld_1=scale * lam_kld, lam_kld_2=0.0,
                            lam_kld_c=50.0)
            result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
            # rank this batch's examples by cost and remember the worst ones
            batch_costs = result[4] + result[5]
            obs_costs = collect_obs_costs(batch_costs, batch_reps)
            carry_idx = batch_idx[np.argsort(-obs_costs)[0:carry_size]]
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 1000) == 0):
                # record and then reset the cost trackers
                costs = [(v / 1000.0) for v in costs]
                str_1 = "-- batch {0:d} --".format(i)
                str_2 = " joint_cost: {0:.4f}".format(costs[0])
                str_3 = " nll_cost : {0:.4f}".format(costs[1])
                str_4 = " kld_cost : {0:.4f}".format(costs[2])
                str_5 = " reg_cost : {0:.4f}".format(costs[3])
                costs = [0.0 for v in costs]
                # print out some diagnostic information
                joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
            if ((i % 2000) == 0):
                Xva = row_shuffle(Xva)
                model_samps = OSM.sample_from_prior(500)
                file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=20)
                file_name = RESULT_PATH + "pt_osm_inf_weights_b{0:d}.png".format(i)
                utils.visualize_samples(
                    OSM.inf_weights.get_value(borrow=False).T,
                    file_name, num_rows=30)
                file_name = RESULT_PATH + "pt_osm_gen_weights_b{0:d}.png".format(i)
                utils.visualize_samples(
                    OSM.gen_weights.get_value(borrow=False),
                    file_name, num_rows=30)
                # compute information about free-energy on validation set
                file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
                fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                fe_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(fe_str)
                out_file.write(fe_str + "\n")
                utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                                   x_label='Posterior KLd',
                                   y_label='Negative Log-likelihood')
                # compute information about posterior KLds on validation set
                file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
                post_klds = OSM.compute_post_klds(Xva[0:2500])
                post_dim_klds = np.mean(post_klds, axis=0)
                utils.plot_stem(np.arange(post_dim_klds.shape[0]),
                                post_dim_klds, file_name)
            if ((i % 5000) == 0):
                # numbered checkpoint of both networks
                IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_IN.pkl".format(i))
                GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_GN.pkl".format(i))
    # final parameters, saved under names that do not depend on the update
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
def pretrain_osm(lam_kld=0.0):
    """Pretrain a OneStageModel on MNIST and checkpoint its networks.

    Builds a generator network (PRIOR_DIM -> 1000 -> 1000 -> data_dim) and an
    inference network (data_dim -> 1000 -> 1000 -> PRIOR_DIM), wraps them in
    a OneStageModel with Bernoulli observations, and runs 150000 minibatch
    updates on uniformly sampled batches of 100 examples.

    Side effects (all paths are prefixed with RESULT_PATH): a text log
    "pt_osm_results.txt"; sample, free-energy and posterior-KLd images every
    2000 updates; pickled IN / GN parameters every 5000 updates and once
    more when training finishes.

    Args:
        lam_kld: final value of ``lam_kld_1`` passed to
            ``OSM.set_lam_kld``.  The weight starts at 1.0 and moves
            linearly to ``lam_kld`` over the first 10000 updates.

    Returns:
        None.
    """
    # Initialize a source of randomness
    rng = np.random.RandomState(1234)

    # Load some data to train/validate/test with
    dataset = 'data/mnist.pkl.gz'
    datasets = load_udm(dataset, zero_mean=False)
    # the loader returns shared variables; pull out plain array copies
    Xtr = datasets[0][0]
    Xtr = Xtr.get_value(borrow=False)
    Xva = datasets[2][0]
    Xva = Xva.get_value(borrow=False)
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape), str(Xva.shape)))

    # get and set some basic dataset information
    Xtr_mean = np.mean(Xtr, axis=0)
    tr_samples = Xtr.shape[0]
    batch_size = 100
    batch_reps = 5

    # setup some symbolic variables and stuff
    Xd = T.matrix('Xd_base')
    Xc = T.matrix('Xc_base')
    Xm = T.matrix('Xm_base')
    data_dim = Xtr.shape[1]
    prior_sigma = 1.0

    ##########################
    # NETWORK CONFIGURATIONS #
    ##########################
    # parameters for the generator network (latent -> data)
    gn_params = {}
    shared_config = [PRIOR_DIM, 1000, 1000]
    top_config = [shared_config[-1], data_dim]
    gn_params['shared_config'] = shared_config
    gn_params['mu_config'] = top_config
    gn_params['sigma_config'] = top_config
    gn_params['activation'] = relu_actfun
    gn_params['init_scale'] = 1.4
    gn_params['lam_l2a'] = 0.0
    gn_params['vis_drop'] = 0.0
    gn_params['hid_drop'] = 0.0
    gn_params['bias_noise'] = 0.0
    gn_params['input_noise'] = 0.0
    # choose some parameters for the continuous inferencer (data -> latent)
    in_params = {}
    shared_config = [data_dim, 1000, 1000]
    top_config = [shared_config[-1], PRIOR_DIM]
    in_params['shared_config'] = shared_config
    in_params['mu_config'] = top_config
    in_params['sigma_config'] = top_config
    in_params['activation'] = relu_actfun
    in_params['init_scale'] = 1.4
    in_params['lam_l2a'] = 0.0
    in_params['vis_drop'] = 0.0
    in_params['hid_drop'] = 0.0
    in_params['bias_noise'] = 0.0
    in_params['input_noise'] = 0.0
    # Initialize the base networks for this OneStageModel
    IN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=in_params, shared_param_dicts=None)
    GN = InfNet(rng=rng, Xd=Xd, prior_sigma=prior_sigma,
                params=gn_params, shared_param_dicts=None)
    # Initialize biases in IN and GN
    IN.init_biases(0.2)
    GN.init_biases(0.2)

    #########################
    # INITIALIZE THE GIPAIR #
    #########################
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    osm_params['logvar_bound'] = LOGVAR_BOUND
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=data_dim, z_dim=PRIOR_DIM, params=osm_params)
    OSM.set_lam_l2w(1e-5)
    # squash the data mean into [0.05, 0.95] so that its logit stays finite
    safe_mean = (0.9 * Xtr_mean) + 0.05
    safe_mean_logit = np.log(safe_mean / (1.0 - safe_mean))
    OSM.set_output_bias(safe_mean_logit)
    OSM.set_input_bias(-Xtr_mean)

    ######################
    # BASIC VAE TRAINING #
    ######################
    # The log receives str objects, so it is opened in text mode ('wb' makes
    # every write fail on Python 3); the with-block guarantees it is closed.
    with open(RESULT_PATH + "pt_osm_results.txt", 'w') as out_file:
        # Set initial learning rate and basic SGD hyper parameters
        costs = [0.0] * 10
        learn_rate = 0.0005
        for i in range(150000):
            # warm-up factor for the learning rate and the KLd weight
            scale = min(1.0, float(i) / 10000.0)
            if ((i > 1) and ((i % 20000) == 0)):
                learn_rate = learn_rate * 0.9
            # sample a uniformly random minibatch
            tr_idx = npr.randint(low=0, high=tr_samples, size=(batch_size, ))
            Xd_batch = Xtr.take(tr_idx, axis=0)
            Xc_batch = 0.0 * Xd_batch
            Xm_batch = 0.0 * Xd_batch
            # do a minibatch update of the model, and compute some costs
            OSM.set_sgd_params(lr_1=(scale * learn_rate), mom_1=0.5, mom_2=0.98)
            OSM.set_lam_nll(1.0)
            OSM.set_lam_kld(lam_kld_1=(1.0 + (scale * (lam_kld - 1.0))),
                            lam_kld_2=0.0)
            result = OSM.train_joint(Xd_batch, Xc_batch, Xm_batch, batch_reps)
            costs = [(costs[j] + result[j]) for j in range(len(result))]
            if ((i % 1000) == 0):
                # record and then reset the cost trackers
                costs = [(v / 1000.0) for v in costs]
                str_1 = "-- batch {0:d} --".format(i)
                str_2 = " joint_cost: {0:.4f}".format(costs[0])
                str_3 = " nll_cost : {0:.4f}".format(costs[1])
                str_4 = " kld_cost : {0:.4f}".format(costs[2])
                str_5 = " reg_cost : {0:.4f}".format(costs[3])
                costs = [0.0 for v in costs]
                # print out some diagnostic information
                joint_str = "\n".join([str_1, str_2, str_3, str_4, str_5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
            if ((i % 2000) == 0):
                Xva = row_shuffle(Xva)
                model_samps = OSM.sample_from_prior(500)
                file_name = RESULT_PATH + "pt_osm_samples_b{0:d}_XG.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=20)
                # compute information about free-energy on validation set
                file_name = RESULT_PATH + "pt_osm_free_energy_b{0:d}.png".format(i)
                fe_terms = OSM.compute_fe_terms(Xva[0:2500], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                fe_str = " nll_bound : {0:.4f}".format(fe_mean)
                print(fe_str)
                out_file.write(fe_str + "\n")
                utils.plot_scatter(fe_terms[1], fe_terms[0], file_name,
                                   x_label='Posterior KLd',
                                   y_label='Negative Log-likelihood')
                # compute information about posterior KLds on validation set
                file_name = RESULT_PATH + "pt_osm_post_klds_b{0:d}.png".format(i)
                post_klds = OSM.compute_post_klds(Xva[0:2500])
                post_dim_klds = np.mean(post_klds, axis=0)
                utils.plot_stem(np.arange(post_dim_klds.shape[0]),
                                post_dim_klds, file_name)
            if ((i % 5000) == 0):
                # numbered checkpoint of both networks
                IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_IN.pkl".format(i))
                GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_b{0:d}_GN.pkl".format(i))
    # final parameters, saved under names that do not depend on the update
    IN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_IN.pkl")
    GN.save_to_file(f_name=RESULT_PATH + "pt_osm_params_GN.pkl")
    return
def __init__(self, rng=None, Xd=None, Xc=None, Xm=None, Xt=None, \
        i_net=None, g_net=None, d_net=None, chain_len=None, \
        data_dim=None, prior_dim=None, params=None):
    """Build the unrolled VAE chain, its discriminator, costs and updates.

    The constructor wraps ``i_net`` / ``g_net`` in a OneStageModel, unrolls
    ``chain_len`` shared-parameter clones of that inferencer/generator pair
    into a chain (each step fed with the transformed output mean of the
    previous one), clones ``d_net`` over the vertical stack of ``Xt`` and
    all chain outputs, builds the discriminator and generator/inferencer
    costs, takes their gradients, and compiles ``self.train_joint``.

    Args:
        rng: numpy-style random state; ``rng.randint`` seeds the RandStream
            and ``rng`` itself is forwarded to every network that is built.
        Xd: symbolic input used to initialise the chain.
        Xc: symbolic input for controlling subsets of the state variables.
        Xm: symbolic input for masking subsets of the state variables.
        Xt: symbolic input holding samples from the target distribution.
        i_net: inference network, passed to OneStageModel as q_z_given_x.
        g_net: generator network, passed to OneStageModel as p_x_given_z.
        d_net: discriminator network; must provide ``shared_param_clone``,
            ``proto_nets`` and ``act_reg_cost``.
        chain_len: integer number of steps to unroll.
        data_dim: forwarded to OneStageModel as ``x_dim``.
        prior_dim: forwarded to OneStageModel as ``z_dim``.
        params: dict of options.  Optional keys: 'cost_decay' (default
            0.1), 'chain_type' ('walkback' or 'walkout', default
            'walkout'), 'xt_transform' ('sigmoid' or 'none', default
            sigmoid), 'logvar_bound' (default 10).  Required keys:
            'x_type' ('bernoulli' or 'gaussian') and 'lam_l2d'.  Although
            the default is None, the required keys mean a dict must be
            supplied in practice.
    """
    # Do some stuff!
    self.rng = RandStream(rng.randint(100000))
    self.data_dim = data_dim
    self.prior_dim = prior_dim
    self.prior_mean = 0.0
    self.prior_logvar = 0.0
    if params is None:
        self.params = {}
    else:
        self.params = params
    if 'cost_decay' in self.params:
        self.cost_decay = self.params['cost_decay']
    else:
        self.cost_decay = 0.1
    if 'chain_type' in self.params:
        assert((self.params['chain_type'] == 'walkback') or \
            (self.params['chain_type'] == 'walkout'))
        self.chain_type = self.params['chain_type']
    else:
        self.chain_type = 'walkout'
    if 'xt_transform' in self.params:
        assert((self.params['xt_transform'] == 'sigmoid') or \
                (self.params['xt_transform'] == 'none'))
        if self.params['xt_transform'] == 'sigmoid':
            self.xt_transform = lambda x: T.nnet.sigmoid(x)
        else:
            self.xt_transform = lambda x: x
    else:
        self.xt_transform = lambda x: T.nnet.sigmoid(x)
    if 'logvar_bound' in self.params:
        self.logvar_bound = self.params['logvar_bound']
    else:
        self.logvar_bound = 10
    #
    # x_type: this tells if we're using bernoulli or gaussian model for
    #         the observations
    #
    # NOTE: 'x_type' has no default, so a missing key raises KeyError here.
    self.x_type = self.params['x_type']
    assert ((self.x_type == 'bernoulli') or (self.x_type == 'gaussian'))
    # symbolic var for inputting samples for initializing the VAE chain
    self.Xd = Xd
    # symbolic var for masking subsets of the state variables
    self.Xm = Xm
    # symbolic var for controlling subsets of the state variables
    self.Xc = Xc
    # symbolic var for inputting samples from the target distribution
    self.Xt = Xt
    # integer number of times to cycle the VAE loop
    self.chain_len = chain_len
    # symbolic vector of row indices for data inputs
    self.It = T.arange(self.Xt.shape[0])
    # symbolic vector of row indices for noise/generated inputs; these
    # follow the Xt rows because Xt is stacked first in the DN input below
    self.Id = T.arange(
            self.chain_len * self.Xd.shape[0]) + self.Xt.shape[0]

    # get a clone of the desired VAE, for easy access
    self.OSM = OneStageModel(rng=rng, Xd=self.Xd, Xc=self.Xc, Xm=self.Xm, \
            p_x_given_z=g_net, q_z_given_x=i_net, x_dim=self.data_dim, \
            z_dim=self.prior_dim, params=self.params)
    self.IN = self.OSM.q_z_given_x
    self.GN = self.OSM.p_x_given_z
    self.transform_x_to_z = self.OSM.transform_x_to_z
    self.transform_z_to_x = self.OSM.transform_z_to_x
    self.bounded_logvar = self.OSM.bounded_logvar

    # self-loop some clones of the main VAE into a chain.
    # ** All VAEs in the chain share the same Xc and Xm, which are the
    #    symbolic inputs for providing the observed portion of the input
    #    and a mask indicating which part of the input is "observed".
    #    These inputs are used for training "reconstruction" policies.
    self.IN_chain = []
    self.GN_chain = []
    self.Xg_chain = []
    _Xd = self.Xd
    print("Unrolling chain...")
    for i in range(self.chain_len):
        # create a VAE infer/generate pair with _Xd as input and with
        # masking variables shared by all VAEs in this chain
        _IN = self.IN.shared_param_clone(rng=rng, \
                Xd=apply_mask(Xd=_Xd, Xc=self.Xc, Xm=self.Xm), \
                build_funcs=False)
        _GN = self.GN.shared_param_clone(rng=rng, Xd=_IN.output, \
                build_funcs=False)
        # the transformed generator mean becomes the next step's input
        _Xd = self.xt_transform(_GN.output_mean)
        self.IN_chain.append(_IN)
        self.GN_chain.append(_GN)
        self.Xg_chain.append(_Xd)
        print(" step {}...".format(i))

    # make a clone of the desired discriminator network, which will try
    # to discriminate between samples from the training data and samples
    # generated by the self-looped VAE chain.
    self.DN = d_net.shared_param_clone(rng=rng, \
            Xd=T.vertical_stack(self.Xt, *self.Xg_chain))

    # template value used to initialise every shared scalar below
    zero_ary = np.zeros((1, )).astype(theano.config.floatX)
    # init shared var for weighting nll of data given posterior sample
    self.lam_chain_nll = theano.shared(value=zero_ary, name='vcg_lam_chain_nll')
    self.set_lam_chain_nll(lam_chain_nll=1.0)
    # init shared var for weighting posterior KL-div from prior
    self.lam_chain_kld = theano.shared(value=zero_ary, name='vcg_lam_chain_kld')
    self.set_lam_chain_kld(lam_chain_kld=1.0)
    # init shared var for controlling l2 regularization on params
    self.lam_l2w = theano.shared(value=zero_ary, name='vcg_lam_l2w')
    self.set_lam_l2w(lam_l2w=1e-4)
    # shared var learning rates for all networks
    self.lr_dn = theano.shared(value=zero_ary, name='vcg_lr_dn')
    self.lr_gn = theano.shared(value=zero_ary, name='vcg_lr_gn')
    self.lr_in = theano.shared(value=zero_ary, name='vcg_lr_in')
    # shared var momentum parameters for all networks
    self.mom_1 = theano.shared(value=zero_ary, name='vcg_mom_1')
    self.mom_2 = theano.shared(value=zero_ary, name='vcg_mom_2')
    self.it_count = theano.shared(value=zero_ary, name='vcg_it_count')
    # shared var weights for adversarial classification objective
    self.dw_dn = theano.shared(value=zero_ary, name='vcg_dw_dn')
    self.dw_gn = theano.shared(value=zero_ary, name='vcg_dw_gn')
    # init parameters for controlling learning dynamics
    self.set_all_sgd_params()
    self.set_disc_weights()  # init adversarial cost weights for GN/DN
    # set a shared var for regularizing the output of the discriminator
    # NOTE(review): this reads the `params` argument rather than
    # self.params, so params=None raises TypeError here and a dict without
    # 'lam_l2d' raises KeyError.
    self.lam_l2d = theano.shared(value=(zero_ary + params['lam_l2d']), \
            name='vcg_lam_l2d')

    # Grab the full set of "optimizable" parameters from the generator
    # and discriminator networks that we'll be working with. We need to
    # ignore parameters in the final layers of the proto-networks in the
    # discriminator network (a generalized pseudo-ensemble). We ignore them
    # because the VCGair requires that they be "bypassed" in favor of some
    # binary classification layers that will be managed by this VCGair.
    self.dn_params = []
    for pn in self.DN.proto_nets:
        for pnl in pn[0:-1]:
            self.dn_params.extend(pnl.params)
    self.in_params = [p for p in self.IN.mlp_params]
    # the OSM's output log-variance is optimised with the inferencer params
    self.in_params.append(self.OSM.output_logvar)
    self.gn_params = [p for p in self.GN.mlp_params]
    self.joint_params = self.in_params + self.gn_params + self.dn_params

    # Now construct a binary discriminator layer for each proto-net in the
    # discriminator network. And, add their params to optimization list.
    self._construct_disc_layers(rng)
    self.disc_reg_cost = self.lam_l2d[0] * \
            T.sum([dl.act_l2_sum for dl in self.disc_layers])

    # Construct costs for the generator and discriminator networks based
    # on adversarial binary classification
    self.disc_cost_dn, self.disc_cost_gn = self._construct_disc_costs()

    # first, build the cost to be optimized by the discriminator network,
    # in general this will be treated somewhat independently of the
    # optimization of the generator and inferencer networks.
    self.dn_cost = self.disc_cost_dn + self.DN.act_reg_cost + \
            self.disc_reg_cost

    # construct costs relevant to the optimization of the generator and
    # discriminator networks
    self.chain_nll_cost = self.lam_chain_nll[0] * \
            self._construct_chain_nll_cost(cost_decay=self.cost_decay)
    self.chain_kld_cost = self.lam_chain_kld[0] * \
            self._construct_chain_kld_cost(cost_decay=self.cost_decay)
    self.other_reg_cost = self._construct_other_reg_cost()
    self.osm_cost = self.disc_cost_gn + self.chain_nll_cost + \
            self.chain_kld_cost + self.other_reg_cost
    # compute total cost on the discriminator and VB generator/inferencer
    self.joint_cost = self.dn_cost + self.osm_cost

    # Get the gradient of the joint cost for all optimizable parameters.
    # DN params are differentiated against dn_cost only; IN and GN params
    # against osm_cost only.
    self.joint_grads = OrderedDict()
    print("Computing VCGLoop DN cost gradients...")
    grad_list = T.grad(self.dn_cost, self.dn_params, disconnected_inputs='warn')
    for i, p in enumerate(self.dn_params):
        self.joint_grads[p] = grad_list[i]
    print("Computing VCGLoop IN cost gradients...")
    grad_list = T.grad(self.osm_cost, self.in_params, disconnected_inputs='warn')
    for i, p in enumerate(self.in_params):
        self.joint_grads[p] = grad_list[i]
    print("Computing VCGLoop GN cost gradients...")
    grad_list = T.grad(self.osm_cost, self.gn_params, disconnected_inputs='warn')
    for i, p in enumerate(self.gn_params):
        self.joint_grads[p] = grad_list[i]

    # construct the updates for the discriminator, generator and
    # inferencer networks. all networks share the same first/second
    # moment momentum and iteration count. the networks each have their
    # own learning rates, which lets you turn their learning on/off.
    self.dn_updates = get_param_updates(params=self.dn_params, \
            grads=self.joint_grads, alpha=self.lr_dn, \
            beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
            mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
    self.gn_updates = get_param_updates(params=self.gn_params, \
            grads=self.joint_grads, alpha=self.lr_gn, \
            beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
            mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)
    self.in_updates = get_param_updates(params=self.in_params, \
            grads=self.joint_grads, alpha=self.lr_in, \
            beta1=self.mom_1, beta2=self.mom_2, it_count=self.it_count, \
            mom2_init=1e-3, smoothing=1e-8, max_grad_norm=10.0)

    # bag up all the updates required for training; on a key that appears
    # in more than one dict, the later assignment (gn, then in) wins
    self.joint_updates = OrderedDict()
    for k in self.dn_updates:
        self.joint_updates[k] = self.dn_updates[k]
    for k in self.gn_updates:
        self.joint_updates[k] = self.gn_updates[k]
    for k in self.in_updates:
        self.joint_updates[k] = self.in_updates[k]
    # construct an update for tracking the mean KL divergence of
    # approximate posteriors for this chain (exponential moving average:
    # 0.98 of the old value plus 0.02 of the chain-averaged batch mean)
    new_kld_mean = (0.98 * self.IN.kld_mean) + ((0.02 / self.chain_len) * \
            sum([T.mean(I_N.kld_cost) for I_N in self.IN_chain]))
    self.joint_updates[self.IN.kld_mean] = T.cast(new_kld_mean, 'floatX')

    # construct the function for training on training data
    print("Compiling VCGLoop theano functions....")
    self.train_joint = self._construct_train_joint()
    return
def test_gip_sigma_scale_tfd():
    """Visual sanity check of a saved TFD inferencer/generator pair.

    Loads the IN / GN networks stored under "TFD_WALKOUT_TEST_KLD/", wraps
    them in a OneStageModel with Gaussian observations, and writes:

    * "A_TFD_KLD_FREE_ENERGY.png" -- scatter of posterior KLd against
      negative log-likelihood for 5000 shuffled rows of the TFD test set;
    * "A_TFD_KLD_CHAIN_VIDEO_<k>.avi" for k = 0..4 -- a 3x3 grid of sampling
      chains, each grid cell seeded by the same 10 training rows;
    * "A_TFD_KLD_PRIOR_SAMPLE.png" -- 400 samples drawn from the prior.

    Earlier revisions also carried disabled (commented-out) experiments: a
    posterior-KLd stem plot, log-likelihood bound statistics, posterior
    statistics plots, single-chain sample images, and a Parzen density
    estimate cross-validated with ``cross_validate_sigma``.
    """
    # Not referenced by the active code below; kept as in the original.
    from LogPDFs import cross_validate_sigma
    # Simple test code, to check that everything is basically functional.
    print("TESTING...")

    # Initialize a source of randomness
    rng = np.random.RandomState(12345)

    # Training rows are the 'unlabeled' and 'train' sets stacked in that
    # order; the "validation" rows used here come from the 'test' set.
    tfd_path = 'data/tfd_data_48x48.pkl'
    Xtr = np.vstack([
        load_tfd(tfd_pkl_name=tfd_path, which_set=split, fold='all')[0]
        for split in ('unlabeled', 'train')])
    Xva = load_tfd(tfd_pkl_name=tfd_path, which_set='test', fold='all')[0]
    print("Xtr.shape: {0:s}, Xva.shape: {1:s}".format(str(Xtr.shape), str(Xva.shape)))
    tr_samples = Xtr.shape[0]

    # Symbolic inputs
    Xd, Xc, Xm, Xt = [T.matrix(name=sym) for sym in ('Xd', 'Xc', 'Xm', 'Xt')]

    # Load inferencer and generator from saved parameters
    in_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_IN.pkl"
    gn_fname = "TFD_WALKOUT_TEST_KLD/pt_walk_params_b25000_GN.pkl"
    IN = load_infnet_from_file(f_name=in_fname, rng=rng, Xd=Xd)
    GN = load_infnet_from_file(f_name=gn_fname, rng=rng, Xd=Xd)
    # recover the data / latent dimensions from the loaded inferencer
    x_dim = IN.shared_layers[0].in_dim
    z_dim = IN.mu_layers[-1].out_dim

    # construct a GIPair with the loaded InfNet and GenNet
    osm_params = {
        'x_type': 'gaussian',
        'xt_transform': 'sigmoid',
        'logvar_bound': LOGVAR_BOUND,
    }
    OSM = OneStageModel(rng=rng, Xd=Xd, Xc=Xc, Xm=Xm,
                        p_x_given_z=GN, q_z_given_x=IN,
                        x_dim=x_dim, z_dim=z_dim, params=osm_params)

    # compute information about free-energy on a shuffled validation subset
    Xva = row_shuffle(Xva)
    Xb = Xva[0:5000]
    fe_terms = OSM.compute_fe_terms(Xb, 20)
    utils.plot_scatter(fe_terms[1], fe_terms[0], "A_TFD_KLD_FREE_ENERGY.png",
                       x_label='Posterior KLd',
                       y_label='Negative Log-likelihood')

    # draw many samples from the GIP
    for i in range(5):
        seed_rows = npr.randint(low=0, high=tr_samples, size=(100, ))
        Xd_batch = Xtr.take(seed_rows, axis=0)
        # 3x3 grid of independently sampled chains (filled row by row)
        Xs = [
            [group_chains(
                OSM.sample_from_chain(Xd_batch[0:10, :], loop_iters=100,
                                      sigma_scale=1.0)['data samples'])
             for _col in range(3)]
            for _row in range(3)]
        Xs, block_im_dim = block_video(Xs, (48, 48), (3, 3))
        to_video(Xs, block_im_dim,
                 "A_TFD_KLD_CHAIN_VIDEO_{0:d}.avi".format(i), frame_rate=10)

    # finally, visualize a 20x20 grid of samples from the prior
    prior_samples = OSM.sample_from_prior(20 * 20)
    utils.visualize_samples(prior_samples, "A_TFD_KLD_PRIOR_SAMPLE.png",
                            num_rows=20)
    return
def test_svhn(occ_dim=15, drop_prob=0.0):
    """Train a OneStageModel on grayscale SVHN and score it as an imputer.

    Three networks are built: ``p_zi_given_xi`` and ``q_zi_given_x_xi``
    (InfNets) and ``p_xip1_given_zi`` (a HydraNet).  They are handed both to
    a GPSImputer (configured with 'use_osm_mode' = True) and, for the first
    and third, to a OneStageModel.  Only ``OSM.train_joint`` is called for
    training; the GPSImputer is used for the periodic evaluation and for
    drawing sample imputations.

    Side effects (paths start with "IMP_SVHN_VAE/VAE_OD<occ_dim>_DP<pct>"):
    a "_RESULTS.txt" log, plus sample and weight images every 10000 updates.

    Args:
        occ_dim: forwarded as ``occ_dim`` to ``construct_masked_data`` and
            embedded in the result file names.
        drop_prob: forwarded as ``drop_prob`` to ``construct_masked_data``;
            ``int(100 * drop_prob)`` is embedded in the result file names.

    Returns:
        None.
    """
    RESULT_PATH = "IMP_SVHN_VAE/"
    #########################################
    # Format the result tag more thoroughly #
    #########################################
    dp_int = int(100.0 * drop_prob)
    result_tag = "{}VAE_OD{}_DP{}".format(RESULT_PATH, occ_dim, dp_int)

    ##########################
    # Get some training data #
    ##########################
    tr_file = 'data/svhn_train_gray.pkl'
    te_file = 'data/svhn_test_gray.pkl'
    ex_file = 'data/svhn_extra_gray.pkl'
    data = load_svhn_gray(tr_file, te_file, ex_file=ex_file, ex_count=200000)
    Xtr = to_fX( shift_and_scale_into_01(np.vstack([data['Xtr'], data['Xex']])) )
    Xva = to_fX( shift_and_scale_into_01(data['Xte']) )
    tr_samples = Xtr.shape[0]
    batch_size = 250
    # a single grand mean over all pixels, replicated for every dimension
    all_pix_mean = np.mean(np.mean(Xtr, axis=1))
    data_mean = to_fX( all_pix_mean * np.ones((Xtr.shape[1],)) )

    ############################################################
    # Setup some parameters for the Iterative Refinement Model #
    ############################################################
    obs_dim = Xtr.shape[1]
    z_dim = 100
    imp_steps = 15 # we'll check for the best step count (found oracularly)
    init_scale = 1.0

    x_in_sym = T.matrix('x_in_sym')
    x_out_sym = T.matrix('x_out_sym')
    x_mask_sym = T.matrix('x_mask_sym')

    # NOTE(review): `rng` is never assigned inside this function, so the
    # constructor calls below rely on a module-level `rng` that is not
    # visible here -- confirm it exists, otherwise they raise NameError.

    #################
    # p_zi_given_xi #
    #################
    params = {}
    shared_config = [obs_dim, 1000, 1000]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_zi_given_xi = InfNet(rng=rng, Xd=x_in_sym,
                           params=params, shared_param_dicts=None)
    p_zi_given_xi.init_biases(0.2)
    ###################
    # p_xip1_given_zi #
    ###################
    params = {}
    shared_config = [z_dim, 1000, 1000]
    output_config = [obs_dim, obs_dim]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    p_xip1_given_zi = HydraNet(rng=rng, Xd=x_in_sym,
                               params=params, shared_param_dicts=None)
    p_xip1_given_zi.init_biases(0.2)
    ###################
    # q_zi_given_x_xi #
    ###################
    params = {}
    shared_config = [(obs_dim + obs_dim), 1000, 1000]
    top_config = [shared_config[-1], z_dim]
    params['shared_config'] = shared_config
    params['mu_config'] = top_config
    params['sigma_config'] = top_config
    params['activation'] = relu_actfun
    params['init_scale'] = init_scale
    params['lam_l2a'] = 0.0
    params['vis_drop'] = 0.0
    params['hid_drop'] = 0.0
    params['bias_noise'] = 0.0
    params['input_noise'] = 0.0
    params['build_theano_funcs'] = False
    q_zi_given_x_xi = InfNet(rng=rng, Xd=x_in_sym,
                             params=params, shared_param_dicts=None)
    q_zi_given_x_xi.init_biases(0.2)

    ###########################################################
    # Define parameters for the GPSImputer, and initialize it #
    ###########################################################
    print("Building the GPSImputer...")
    gpsi_params = {}
    gpsi_params['obs_dim'] = obs_dim
    gpsi_params['z_dim'] = z_dim
    gpsi_params['imp_steps'] = imp_steps
    gpsi_params['step_type'] = 'jump'
    gpsi_params['x_type'] = 'bernoulli'
    gpsi_params['obs_transform'] = 'sigmoid'
    gpsi_params['use_osm_mode'] = True
    GPSI = GPSImputer(rng=rng,
                      x_in=x_in_sym, x_out=x_out_sym, x_mask=x_mask_sym,
                      p_zi_given_xi=p_zi_given_xi,
                      p_xip1_given_zi=p_xip1_given_zi,
                      q_zi_given_x_xi=q_zi_given_x_xi,
                      params=gpsi_params,
                      shared_param_dicts=None)
    #########################################################################
    # Define parameters for the underlying OneStageModel, and initialize it #
    #########################################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = 'bernoulli'
    osm_params['xt_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng,
                        x_in=x_in_sym,
                        p_x_given_z=p_xip1_given_zi,
                        q_z_given_x=p_zi_given_xi,
                        x_dim=obs_dim, z_dim=z_dim,
                        params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format(result_tag)
    # The log receives str objects, so it is opened in text mode ('wb' makes
    # every write fail on Python 3); the with-block guarantees it is closed.
    with open(log_name, 'w') as out_file:
        costs = [0.0] * 10
        learn_rate = 0.0002
        # start past the end of the data so that the first iteration
        # triggers the shuffle-and-restart branch below
        batch_idx = np.arange(batch_size) + tr_samples
        for i in range(200005):
            # warm-up factor for learning rate and momentum
            scale = min(1.0, ((i+1) / 5000.0))
            if (((i + 1) % 15000) == 0):
                learn_rate = learn_rate * 0.92
            if (i > 10000):
                momentum = 0.90
            else:
                momentum = 0.50
            # get the indices of training samples for this batch update
            batch_idx += batch_size
            if (np.max(batch_idx) >= tr_samples):
                # we finished an "epoch", so we rejumble the training set
                Xtr = row_shuffle(Xtr)
                batch_idx = np.arange(batch_size)
            # set sgd and objective function hyperparams for this update
            OSM.set_sgd_params(lr=scale*learn_rate,
                               mom_1=scale*momentum, mom_2=0.99)
            OSM.set_lam_nll(lam_nll=1.0)
            OSM.set_lam_kld(lam_kld_1=1.0, lam_kld_2=0.0)
            OSM.set_lam_l2w(1e-4)
            # perform a minibatch update and record the cost for this batch
            xb = to_fX( Xtr.take(batch_idx, axis=0) )
            # NOTE(review): `batch_reps` is not assigned in this function
            # either; it must come from module scope -- confirm.
            result = OSM.train_joint(xb, batch_reps)
            # the last entry of `result` is left out of the running sums
            costs = [(costs[j] + result[j]) for j in range(len(result)-1)]
            if ((i % 250) == 0):
                costs = [(v / 250.0) for v in costs]
                str1 = "-- batch {0:d} --".format(i)
                str2 = " joint_cost: {0:.4f}".format(costs[0])
                str3 = " nll_cost : {0:.4f}".format(costs[1])
                str4 = " kld_cost : {0:.4f}".format(costs[2])
                str5 = " reg_cost : {0:.4f}".format(costs[3])
                joint_str = "\n".join([str1, str2, str3, str4, str5])
                print(joint_str)
                out_file.write(joint_str+"\n")
                out_file.flush()
                costs = [0.0 for v in costs]
            if ((i % 1000) == 0):
                Xva = row_shuffle(Xva)
                # record an estimate of performance on the test set
                xi, xo, xm = construct_masked_data(Xva[0:5000],
                                                   drop_prob=drop_prob,
                                                   occ_dim=occ_dim,
                                                   data_mean=data_mean)
                step_nll, step_kld = GPSI.compute_per_step_cost(
                    xi, xo, xm, sample_count=10)
                min_nll = np.min(step_nll)
                # the first two lines intentionally report the same value,
                # as in the original log format
                str1 = " va_nll_bound : {}".format(min_nll)
                str2 = " va_nll_min : {}".format(min_nll)
                str3 = " va_nll_final : {}".format(step_nll[-1])
                joint_str = "\n".join([str1, str2, str3])
                print(joint_str)
                out_file.write(joint_str+"\n")
                out_file.flush()
            if ((i % 10000) == 0):
                # Get some validation samples for evaluating model performance
                xb = to_fX( Xva[0:100] )
                xi, xo, xm = construct_masked_data(xb, drop_prob=drop_prob,
                                                   occ_dim=occ_dim,
                                                   data_mean=data_mean)
                # duplicate every row so each input gets two imputations
                xi = np.repeat(xi, 2, axis=0)
                xo = np.repeat(xo, 2, axis=0)
                xm = np.repeat(xm, 2, axis=0)
                # draw some sample imputations from the model
                samp_count = xi.shape[0]
                _, model_samps = GPSI.sample_imputer(xi, xo, xm,
                                                     use_guide_policy=False)
                seq_len = len(model_samps)
                seq_samps = np.zeros((seq_len*samp_count,
                                      model_samps[0].shape[1]))
                # lay the steps of each sample's sequence out contiguously
                idx = 0
                for s1 in range(samp_count):
                    for s2 in range(seq_len):
                        seq_samps[idx] = model_samps[s2][s1]
                        idx += 1
                file_name = "{}_samples_ng_b{}.png".format(result_tag, i)
                utils.visualize_samples(seq_samps, file_name, num_rows=20)
                # get visualizations of policy parameters
                file_name = "{}_gen_gen_weights_b{}.png".format(result_tag, i)
                W = GPSI.gen_gen_weights.get_value(borrow=False)
                utils.visualize_samples(W[:,:obs_dim], file_name, num_rows=20)
                file_name = "{}_gen_inf_weights_b{}.png".format(result_tag, i)
                W = GPSI.gen_inf_weights.get_value(borrow=False).T
                utils.visualize_samples(W[:,:obs_dim], file_name, num_rows=20)
def test_one_stage_model():
    """Build a convolutional OneStageModel and train it as a smoke test.

    Loads the splits returned by ``load_binarized_mnist``, stacks the first
    two for training and uses the third for evaluation. Builds a generator
    (``p_x_given_z``) and an inferencer (``q_z_given_x``) as ``HydraNet``
    instances, wraps them in a ``OneStageModel``, and runs minibatch updates
    via ``OSM.train_joint``.

    Side effects:
        * Writes running costs and a validation bound estimate to
          ``OSM_TEST_RESULTS.txt`` in the current working directory.
        * Periodically writes sample images named ``OSM_SAMPLES_b<i>.png``
          through ``utils.visualize_samples``.

    Returns:
        None.
    """
    ##########################
    # Get some training data #
    ##########################
    rng = np.random.RandomState(1234)
    Xtr, Xva, Xte = load_binarized_mnist(data_path='./data/')
    # Train on the first two splits stacked together; evaluate on the third.
    Xtr = np.vstack((Xtr, Xva))
    Xva = Xte
    tr_samples = Xtr.shape[0]
    batch_size = 128
    batch_reps = 1

    ###############################################
    # Setup some parameters for the OneStageModel #
    ###############################################
    x_dim = Xtr.shape[1]
    z_dim = 64
    x_type = 'bernoulli'
    xin_sym = T.matrix('xin_sym')

    ###############
    # p_x_given_z #
    ###############
    params = {}
    shared_config = [
        {'layer_type': 'fc',
         'in_chans': z_dim,
         'out_chans': 256,
         'activation': relu_actfun,
         'apply_bn': True},
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': 7*7*128,
         'activation': relu_actfun,
         'apply_bn': True,
         'shape_func_out': lambda x: T.reshape(x, (-1, 128, 7, 7))},
        {'layer_type': 'conv',
         'in_chans': 128,   # in shape:  (batch, 128, 7, 7)
         'out_chans': 64,   # out shape: (batch, 64, 14, 14)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': True},
    ]
    # Two structurally identical output heads on top of the shared stack.
    output_config = [
        {'layer_type': 'conv',
         'in_chans': 64,    # in shape:  (batch, 64, 14, 14)
         'out_chans': 1,    # out shape: (batch, 1, 28, 28)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': False,
         'shape_func_out': lambda x: T.flatten(x, 2)},
        {'layer_type': 'conv',
         'in_chans': 64,
         'out_chans': 1,
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'half',
         'apply_bn': False,
         'shape_func_out': lambda x: T.flatten(x, 2)},
    ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    p_x_given_z = HydraNet(rng=rng, Xd=xin_sym,
                           params=params, shared_param_dicts=None)
    p_x_given_z.init_biases(0.0)

    ###############
    # q_z_given_x #
    ###############
    params = {}
    shared_config = [
        {'layer_type': 'conv',
         'in_chans': 1,     # in shape:  (batch, 784)
         'out_chans': 64,   # out shape: (batch, 64, 14, 14)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'double',
         'apply_bn': True,
         'shape_func_in': lambda x: T.reshape(x, (-1, 1, 28, 28))},
        {'layer_type': 'conv',
         'in_chans': 64,    # in shape:  (batch, 64, 14, 14)
         'out_chans': 128,  # out shape: (batch, 128, 7, 7)
         'activation': relu_actfun,
         'filt_dim': 5,
         'conv_stride': 'double',
         'apply_bn': True,
         'shape_func_out': lambda x: T.flatten(x, 2)},
        {'layer_type': 'fc',
         'in_chans': 128*7*7,
         'out_chans': 256,
         'activation': relu_actfun,
         'apply_bn': True},
    ]
    # Two structurally identical z_dim-wide output heads.
    output_config = [
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': z_dim,
         'activation': relu_actfun,
         'apply_bn': False},
        {'layer_type': 'fc',
         'in_chans': 256,
         'out_chans': z_dim,
         'activation': relu_actfun,
         'apply_bn': False},
    ]
    params['shared_config'] = shared_config
    params['output_config'] = output_config
    params['init_scale'] = 1.0
    params['build_theano_funcs'] = False
    q_z_given_x = HydraNet(rng=rng, Xd=xin_sym,
                           params=params, shared_param_dicts=None)
    q_z_given_x.init_biases(0.0)

    ##############################################################
    # Define parameters for the OneStageModel, and initialize it #
    ##############################################################
    print("Building the OneStageModel...")
    osm_params = {}
    osm_params['x_type'] = x_type
    osm_params['obs_transform'] = 'sigmoid'
    OSM = OneStageModel(rng=rng, x_in=xin_sym,
                        x_dim=x_dim, z_dim=z_dim,
                        p_x_given_z=p_x_given_z,
                        q_z_given_x=q_z_given_x,
                        params=osm_params)

    ################################################################
    # Apply some updates, to check that they aren't totally broken #
    ################################################################
    log_name = "{}_RESULTS.txt".format("OSM_TEST")
    costs = [0.0] * 10
    # Number of batches accumulated into `costs` since the last report, so
    # the reported values are true means (the first report covers 1 batch).
    batches_since_report = 0
    learn_rate = 0.0005
    momentum = 0.9
    # Start past the end of the training set so that the first iteration
    # triggers a shuffle and resets the batch window to the start.
    batch_idx = np.arange(batch_size) + tr_samples
    # Text mode: only str is written to this log. The context manager
    # guarantees the handle is closed even if training raises.
    with open(log_name, 'w') as out_file:
        for i in range(500000):
            # Linear warm-up of the lr/momentum multiplier, capped at 0.5.
            scale = min(0.5, ((i + 1) / 5000.0))
            if ((i + 1) % 10000) == 0:
                learn_rate = learn_rate * 0.95
            # get the indices of training samples for this batch update
            batch_idx += batch_size
            if np.max(batch_idx) >= tr_samples:
                # we finished an "epoch", so we rejumble the training set
                Xtr = row_shuffle(Xtr)
                batch_idx = np.arange(batch_size)
            Xb = to_fX(Xtr.take(batch_idx, axis=0))
            # set sgd and objective function hyperparams for this update
            OSM.set_sgd_params(lr=scale*learn_rate,
                               mom_1=(scale*momentum), mom_2=0.98)
            OSM.set_lam_nll(lam_nll=1.0)
            OSM.set_lam_kld(lam_kld=1.0)
            OSM.set_lam_l2w(1e-5)
            # perform a minibatch update and record the cost for this batch
            result = OSM.train_joint(Xb, batch_reps)
            costs = [(c + r) for c, r in zip(costs, result)]
            batches_since_report += 1
            if (i % 500) == 0:
                costs = [(v / float(batches_since_report)) for v in costs]
                str1 = "-- batch {0:d} --".format(i)
                str2 = "    joint_cost: {0:.4f}".format(costs[0])
                str3 = "    nll_cost  : {0:.4f}".format(costs[1])
                str4 = "    kld_cost  : {0:.4f}".format(costs[2])
                str5 = "    reg_cost  : {0:.4f}".format(costs[3])
                joint_str = "\n".join([str1, str2, str3, str4, str5])
                print(joint_str)
                out_file.write(joint_str + "\n")
                out_file.flush()
                costs = [0.0 for v in costs]
                batches_since_report = 0
            # Evaluate every 1000 batches early on, then every 5000.
            if ((i % 5000) == 0) or ((i < 10000) and ((i % 1000) == 0)):
                # draw some independent random samples from the model
                samp_count = 300
                model_samps = OSM.sample_from_prior(samp_count)
                file_name = "OSM_SAMPLES_b{0:d}.png".format(i)
                utils.visualize_samples(model_samps, file_name, num_rows=15)
                # compute free energy estimate for validation samples
                Xva = row_shuffle(Xva)
                fe_terms = OSM.compute_fe_terms(Xva[0:5000], 20)
                fe_mean = np.mean(fe_terms[0]) + np.mean(fe_terms[1])
                out_str = "    nll_bound : {0:.4f}".format(fe_mean)
                print(out_str)
                out_file.write(out_str + "\n")
                out_file.flush()
    return