def sample(self):
    """Draw one random point around each center, between the inner and outer radius.

    A Gaussian draw shaped like ``self.centers`` is divided by its norm along
    axis 1 to obtain a direction. The distance from the center is drawn
    uniformly from ``[self.inner_radius, self.outer_radius)``.

    Returns:
        Tensor computed as ``sample_norm * sample_dir + self.centers``.
    """
    dir_normal = tf.random.normal(shape=tf.shape(self.centers))
    dir_normal_norm = helper.safe_tf_sqrt(
        tf.reduce_sum(dir_normal**2, axis=1, keepdims=True))
    sample_dir = dir_normal / dir_normal_norm

    # tf.random.uniform is the namespaced spelling that matches
    # tf.random.normal above; the legacy alias tf.random_uniform is gone in
    # newer TensorFlow releases.
    unit_uniform = tf.random.uniform(
        tf.shape(self.inner_radius), 0, 1, dtype=tf.float32)
    sample_norm = (unit_uniform * (self.outer_radius - self.inner_radius)
                   + self.inner_radius)

    sample = sample_norm * sample_dir + self.centers
    return sample
def inference(self, batch, additional_inputs_tf):
    """Build the training graph for one batch and store every tensor on ``self``.

    The method samples a latent code from the prior and decodes it, encodes
    and reconstructs the observed batch, mixes reconstruction and input
    into a regularisation sample, evaluates the discriminator on the real,
    generated and regularisation samples, and assembles
    ``discriminator_cost``, ``generator_cost`` and ``transporter_cost``.

    Args:
        batch: Nested dict. This method reads ``batch['observed']['data']``
            and ``batch['observed']['properties']``, both keyed by ``'flat'``
            and ``'image'``. The ``'dist'`` field of the active modality's
            property entries is overwritten in place with ``'dirac'``.
        additional_inputs_tf: Sequence whose first two items are stored as
            ``self.epoch`` and ``self.b_identity``.

    Returns:
        None. All results are exposed as attributes.
    """
    self.epoch = additional_inputs_tf[0]
    self.b_identity = additional_inputs_tf[1]

    properties = batch['observed']['properties']
    if len(properties['flat']) > 0:
        for e in properties['flat']:
            e['dist'] = 'dirac'
    else:
        for e in properties['image']:
            e['dist'] = 'dirac'

    self.input_sample = batch['observed']['data']
    self.input_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.input_sample)

    if not self.bModules:
        self.generate_modules(batch)

    # The 'flat' modality is tried first and 'image' is the fallback.
    # `except Exception` (not a bare except) keeps KeyboardInterrupt and
    # SystemExit from being swallowed by this dispatch.
    try:
        self.n_time = properties['flat'][0]['size'][1]
    except Exception:
        self.n_time = properties['image'][0]['size'][1]
    try:
        batch_size_tf = tf.shape(self.input_sample['flat'])[0]
    except Exception:
        batch_size_tf = tf.shape(self.input_sample['image'])[0]

    # Prior sample and its decoding.
    self.prior_param = self.PriorMap.forward(
        (tf.zeros(shape=(batch_size_tf, 1)), ))
    self.prior_dist = distributions.DiagonalGaussianDistribution(
        params=self.prior_param)
    self.prior_latent_code = self.prior_dist.sample()
    self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
    self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

    # One mixing coefficient per batch element, drawn from U(0, 1).
    self.uniform_dist = distributions.UniformDistribution(
        params=tf.concat([
            tf.zeros(shape=(batch_size_tf, 1)),
            tf.ones(shape=(batch_size_tf, 1))
        ], axis=1))
    self.convex_mix = self.uniform_dist.sample()

    # Insert a singleton axis after the batch axis before decoding.
    self.prior_latent_code_expanded = tf.reshape(
        self.prior_latent_code,
        [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
    self.obs_sample_param = self.Decoder.forward(
        self.prior_latent_code_expanded)
    self.obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.obs_sample_param)
    self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

    #############################################################################
    # Encoder / reconstruction path.
    b_use_reconst_as_reg = True

    self.posterior_param_expanded = self.Encoder.forward(self.input_sample)
    self.posterior_param = self.posterior_param_expanded[:, 0, :]
    self.posterior_dist = distributions.DiagonalGaussianDistribution(
        params=self.posterior_param)
    self.posterior_latent_code = self.posterior_dist.sample()
    self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.newaxis, :]

    self.reconst_param = self.Decoder.forward(
        self.posterior_latent_code_expanded)
    self.reconst_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.reconst_param)
    self.reconst_sample = self.reconst_dist.sample(b_mode=True)

    # The regularisation sample is mixed with either the reconstruction or
    # the prior-generated sample.
    if b_use_reconst_as_reg:
        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
    else:
        self.reg_target_dist = self.obs_sample_dist
        self.reg_target_sample = self.obs_sample

    # Sample-based KL estimate: mean log q(z|x) minus mean log p(z), both
    # evaluated at the posterior sample z.
    self.neg_cross_ent_posterior = self.prior_dist.log_pdf(
        self.posterior_latent_code)
    self.mean_neg_cross_ent_posterior = tf.reduce_mean(
        self.neg_cross_ent_posterior)
    self.neg_ent_posterior = self.posterior_dist.log_pdf(
        self.posterior_latent_code)
    self.mean_neg_ent_posterior = tf.reduce_mean(self.neg_ent_posterior)
    self.mean_kl_posterior_prior = (-self.mean_neg_cross_ent_posterior
                                    + self.mean_neg_ent_posterior)

    #############################################################################
    # Convex mix of regularisation target and real input.
    self.reg_sample_param = {'image': None, 'flat': None}
    try:
        mix_image = self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis]
        self.reg_sample_param['image'] = (
            mix_image * self.reg_target_sample['image']
            + (1 - mix_image) * self.input_sample['image'])
    except Exception:
        mix_flat = self.convex_mix[:, np.newaxis, :]
        self.reg_sample_param['flat'] = (
            mix_flat * self.reg_target_sample['flat']
            + (1 - mix_flat) * self.input_sample['flat'])
    self.reg_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.reg_sample_param)
    self.reg_sample = self.reg_dist.sample(b_mode=True)

    # Critic evaluations.
    self.critic_real = self.Discriminator.forward(self.input_sample)
    self.critic_gen = self.Discriminator.forward(self.obs_sample)
    self.critic_reg = self.Discriminator.forward(self.reg_sample)

    lambda_t = 1
    self.real_reconst_distances_sq = self.metric_distance_sq(
        self.input_sample, self.reconst_sample)
    self.autoencode_costs = (self.real_reconst_distances_sq / 2
                             + lambda_t * self.mean_kl_posterior_prior)

    # Norm of the critic gradient with respect to the mixed sample.
    try:
        self.convex_grad = tf.gradients(
            self.critic_reg, [self.reg_sample['image']])[0]
        self.convex_grad_norm = helper.safe_tf_sqrt(
            tf.reduce_sum(self.convex_grad**2, axis=[-1, -2, -3],
                          keep_dims=False)[:, :, np.newaxis])
    except Exception:
        self.convex_grad = tf.gradients(
            self.critic_reg, [self.reg_sample['flat']])[0]
        self.convex_grad_norm = helper.safe_tf_sqrt(
            tf.reduce_sum(self.convex_grad**2, axis=[-1], keep_dims=True))
    # Penalise deviation of the gradient norm from 1.
    self.gradient_penalties = ((self.convex_grad_norm - 1)**2)

    self.mean_critic_real = tf.reduce_mean(self.critic_real)
    self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
    self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
    self.mean_autoencode_cost = tf.reduce_mean(self.autoencode_costs)
    self.mean_gradient_penalty = tf.reduce_mean(self.gradient_penalties)

    self.regularizer_cost = 10 * self.mean_gradient_penalty
    self.discriminator_cost = (-self.mean_critic_real + self.mean_critic_gen
                               + self.regularizer_cost)
    self.generator_cost = -self.mean_critic_gen
    self.transporter_cost = self.mean_autoencode_cost
def euclidean_distance(self, a, b):
    """Return ``helper.safe_tf_sqrt`` applied to ``self.metric_distance_sq(a, b)``."""
    squared_distance = self.metric_distance_sq(a, b)
    return helper.safe_tf_sqrt(squared_distance)
def inference(self, batch, additional_inputs_tf):
    """Build the training graph for one batch and store every tensor on ``self``.

    Latent codes are mapped to features by ``self.Generator`` and features to
    observation parameters by ``self.PostGen``. The observed batch is mapped
    to features by ``self.PreEnc`` and back to latent codes by
    ``self.Encoder``. The method sets ``div_cost``, ``enc_cost``,
    ``gen_cost`` and ``cri_cost``.

    A set of 400 standard-normal latent codes is cached on disk under
    ``~/ExperimentalResults/FixedSamples/`` so that the same "constant"
    samples are decoded on every run.

    Args:
        batch: Nested dict. This method reads ``batch['observed']['data']``
            and ``batch['observed']['properties']``, both keyed by ``'flat'``
            and ``'image'``. The ``'dist'`` field of the active modality's
            property entries is overwritten in place with ``'dirac'``.
        additional_inputs_tf: Sequence whose first two items are stored as
            ``self.epoch`` and ``self.b_identity``.

    Returns:
        None. All results are exposed as attributes.
    """
    self.epoch = additional_inputs_tf[0]
    self.b_identity = additional_inputs_tf[1]

    properties = batch['observed']['properties']
    if len(properties['flat']) > 0:
        for e in properties['flat']:
            e['dist'] = 'dirac'
    else:
        for e in properties['image']:
            e['dist'] = 'dirac'

    self.input_sample = batch['observed']['data']
    self.input_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.input_sample)

    if not self.bModules:
        self.generate_modules(batch)

    # The 'flat' modality is tried first and 'image' is the fallback.
    # `except Exception` (not a bare except) keeps KeyboardInterrupt and
    # SystemExit from being swallowed by this dispatch.
    try:
        self.n_time = properties['flat'][0]['size'][1]
    except Exception:
        self.n_time = properties['image'][0]['size'][1]
    try:
        self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
    except Exception:
        self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

    #############################################################################
    # GENERATOR
    self.prior_param = self.PriorMap.forward(
        (tf.zeros(shape=(self.batch_size_tf, 1)),))
    self.prior_dist = distributions.DiagonalGaussianDistribution(
        params=self.prior_param)
    self.prior_latent_code = self.prior_dist.sample()
    self.prior_latent_code_expanded = self.prior_latent_code[:, np.newaxis, :]

    self.prior_feature_expanded = self.Generator.forward(
        self.prior_latent_code_expanded)
    self.prior_feature = self.prior_feature_expanded[:, 0, :]
    self.obs_sample_param = self.PostGen.forward(self.prior_feature_expanded)
    self.obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.obs_sample_param)
    self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

    # Fixed latent codes, cached on disk. The file uses the '.npy' suffix
    # because np.save appends '.npy' to any other name; with the previous
    # '.npz' name the existence check looked at a path that was never
    # written, so the sample was regenerated on every run.
    latent_width = self.prior_latent_code.get_shape().as_list()[-1]
    fixed_sample_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
    os.makedirs(fixed_sample_dir, exist_ok=True)
    fixed_sample_path = (fixed_sample_dir + 'np_constant_prior_sample_'
                         + str(latent_width) + '.npy')
    if os.path.exists(fixed_sample_path):
        np_constant_prior_sample = np.load(fixed_sample_path)
    else:
        np_constant_prior_sample = np.random.normal(
            loc=0., scale=1., size=[400, latent_width])
        np.save(fixed_sample_path, np_constant_prior_sample)

    self.constant_prior_latent_code = tf.constant(
        np.asarray(np_constant_prior_sample), dtype=np.float32)
    self.constant_prior_latent_code_expanded = self.constant_prior_latent_code[:, np.newaxis, :]
    self.constant_prior_feature_expanded = self.Generator.forward(
        self.constant_prior_latent_code_expanded)
    self.constant_prior_feature = self.constant_prior_feature_expanded[:, 0, :]
    self.constant_obs_sample_param = self.PostGen.forward(
        self.constant_prior_feature_expanded)
    self.constant_obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.constant_obs_sample_param)
    self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

    if self.config['n_latent'] == 2:
        # Regular 20x20 grid over [-3, 3]^2 in latent space (y descending).
        grid_scale = 3
        x = np.linspace(-grid_scale, grid_scale, 20)
        y = np.linspace(grid_scale, -grid_scale, 20)
        xv, yv = np.meshgrid(x, y)
        np_constant_prior_grid_sample = np.concatenate(
            (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]), axis=1)

        self.constant_prior_grid_latent_code = tf.constant(
            np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
        self.constant_prior_grid_latent_code_expanded = self.constant_prior_grid_latent_code[:, np.newaxis, :]
        self.constant_prior_grid_feature_expanded = self.Generator.forward(
            self.constant_prior_grid_latent_code_expanded)
        self.constant_prior_grid_feature = self.constant_prior_grid_feature_expanded[:, 0, :]
        self.constant_obs_grid_sample_param = self.PostGen.forward(
            self.constant_prior_grid_feature_expanded)
        self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
            sample_properties=properties,
            params=self.constant_obs_grid_sample_param)
        self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(b_mode=True)

    #############################################################################
    # ENCODER
    self.epsilon_param = self.PriorMap.forward(
        (tf.zeros(shape=(self.batch_size_tf, 1)),))
    self.epsilon_dist = distributions.DiagonalGaussianDistribution(
        params=self.epsilon_param)

    # NOTE(review): self.epsilon is only assigned for the modes listed here;
    # any other 'encoder_mode' leaves it unset before the PreEnc call below.
    encoder_mode = self.config['encoder_mode']
    if encoder_mode == 'Deterministic':
        self.epsilon = None
    if encoder_mode in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial',
                        'UnivApproxSine'):
        self.epsilon = self.epsilon_dist.sample()

    self.pre_posterior_feature_expanded = self.PreEnc.forward(
        self.input_sample, noise=self.epsilon)
    self.pre_posterior_feature = self.pre_posterior_feature_expanded[:, 0, :]

    # NOTE(review): 512 is hard-coded; presumably the feature width — confirm.
    self.flow_param = self.FlowMap.forward()
    self.flow_object = transforms.HouseholdRotationFlow(self.flow_param, 512)
    self.transformed_pre_posterior_feature, _ = self.flow_object.transform(
        self.pre_posterior_feature, tf.zeros(shape=(self.batch_size_tf, 1)))
    self.transformed_pre_posterior_feature_abs = helper.relu_abs(
        self.transformed_pre_posterior_feature)
    self.transformed_pre_posterior_feature_abs_means = tf.reduce_mean(
        self.transformed_pre_posterior_feature_abs, axis=0)

    # Reconstruction straight from the pre-encoder features.
    self.reconst_param = self.PostGen.forward(
        self.pre_posterior_feature_expanded)
    self.reconst_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.reconst_param)
    self.reconst_sample = self.reconst_dist.sample(b_mode=True)

    # Full round trip: features -> latent code -> features -> observation.
    self.posterior_latent_code_expanded = self.Encoder.forward(
        self.pre_posterior_feature_expanded)
    self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
    self.posterior_feature_expanded = self.Generator.forward(
        self.posterior_latent_code_expanded)
    self.posterior_feature = self.posterior_feature_expanded[:, 0, :]
    self.full_reconst_param = self.PostGen.forward(
        self.posterior_feature_expanded)
    self.full_reconst_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.full_reconst_param)
    self.full_reconst_sample = self.full_reconst_dist.sample(b_mode=True)

    # Decode interpolations between posterior latent codes.
    self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
        self.posterior_latent_code, size=self.batch_size_tf//2)
    self.interpolated_posterior_latent_code_expanded = self.interpolated_posterior_latent_code[:, np.newaxis, :]
    self.interpolated_posterior_feature_expanded = self.Generator.forward(
        self.interpolated_posterior_latent_code_expanded)
    self.interpolated_posterior_feature = self.interpolated_posterior_feature_expanded[:, 0, :]
    self.interpolated_reconst_param = self.PostGen.forward(
        self.interpolated_posterior_feature_expanded)
    self.interpolated_reconst_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.interpolated_reconst_param)
    self.interpolated_obs = self.interpolated_reconst_dist.sample(b_mode=True)

    #############################################################################
    # REGULARIZER
    self.reg_target_dist = self.reconst_dist
    self.reg_target_sample = self.reconst_sample
    self.reg_dist = self.reconst_dist
    self.reg_sample = self.reconst_sample

    #############################################################################
    # OBJECTIVES
    # Penalise the flow-transformed feature coordinates beyond the first
    # n_latent ones.
    self.div_cost = tf.reduce_mean(tf.reduce_sum(
        helper.relu_abs(
            self.transformed_pre_posterior_feature[:, self.config['n_latent']:]),
        axis=1))

    # Encoder and generator share one objective: the mean distance between
    # the pre-encoder features and the features regenerated from the latent
    # code.
    feature_gap = tf.reduce_mean(helper.safe_tf_sqrt(tf.reduce_sum(
        (self.pre_posterior_feature - self.posterior_feature)**2, axis=1)))
    self.enc_cost = feature_gap
    self.gen_cost = feature_gap

    self.cri_cost = tf.reduce_mean(
        self.sample_distance_function(self.input_sample, self.reconst_sample))
def inference(self, batch, additional_inputs_tf):
    """Build the training graph for one batch and store every tensor on ``self``.

    The method decodes prior samples, encodes and reconstructs the observed
    batch, computes a kernel MMD between posterior and prior latent codes,
    evaluates the discriminator with a gradient-norm penalty, and sets
    ``enc_cost``, ``disc_cost`` and ``gen_cost``.

    A set of 400 standard-normal latent codes is cached in the working
    directory so that the same "constant" samples are decoded on every run.

    Args:
        batch: Nested dict. This method reads ``batch['observed']['data']``
            and ``batch['observed']['properties']``, both keyed by ``'flat'``
            and ``'image'``. The ``'dist'`` field of the active modality's
            property entries is overwritten in place with ``'dirac'``.
        additional_inputs_tf: Sequence whose first two items are stored as
            ``self.epoch`` and ``self.b_identity``.

    Returns:
        None. All results are exposed as attributes.
    """
    self.epoch = additional_inputs_tf[0]
    self.b_identity = additional_inputs_tf[1]

    properties = batch['observed']['properties']
    if len(properties['flat']) > 0:
        for e in properties['flat']:
            e['dist'] = 'dirac'
    else:
        for e in properties['image']:
            e['dist'] = 'dirac'

    self.input_sample = batch['observed']['data']
    self.input_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.input_sample)

    if not self.bModules:
        self.generate_modules(batch)

    # The 'flat' modality is tried first and 'image' is the fallback.
    # `except Exception` (not a bare except) keeps KeyboardInterrupt and
    # SystemExit from being swallowed by this dispatch.
    try:
        self.n_time = properties['flat'][0]['size'][1]
    except Exception:
        self.n_time = properties['image'][0]['size'][1]
    try:
        self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
    except Exception:
        self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

    #############################################################################
    # GENERATOR
    self.prior_param = self.PriorMap.forward(
        (tf.zeros(shape=(self.batch_size_tf, 1)),))
    self.prior_dist = distributions.DiagonalGaussianDistribution(
        params=self.prior_param)
    self.prior_latent_code = self.prior_dist.sample()
    self.prior_latent_code_expanded = tf.reshape(
        self.prior_latent_code,
        [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
    self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
    self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

    self.obs_sample_param = self.Decoder.forward(
        self.prior_latent_code_expanded)
    self.obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.obs_sample_param)
    self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

    # Fixed latent codes, cached on disk. The file uses the '.npy' suffix
    # because np.save appends '.npy' to any other name; with the previous
    # '.npz' name the existence check looked at a path that was never
    # written, so the sample was regenerated on every run.
    latent_width = self.prior_latent_code.get_shape().as_list()[-1]
    fixed_sample_path = ('./np_constant_prior_sample_' + str(latent_width)
                         + '.npy')
    if os.path.exists(fixed_sample_path):
        np_constant_prior_sample = np.load(fixed_sample_path)
    else:
        np_constant_prior_sample = np.random.normal(
            loc=0., scale=1., size=[400, latent_width])
        np.save(fixed_sample_path, np_constant_prior_sample)

    self.constant_prior_latent_code = tf.constant(
        np.asarray(np_constant_prior_sample), dtype=np.float32)
    self.constant_prior_latent_code_expanded = tf.reshape(
        self.constant_prior_latent_code,
        [-1, 1, *self.constant_prior_latent_code.get_shape().as_list()[1:]])
    self.constant_obs_sample_param = self.Decoder.forward(
        self.constant_prior_latent_code_expanded)
    self.constant_obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.constant_obs_sample_param)
    self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

    #############################################################################
    # ENCODER
    b_deterministic = False
    if b_deterministic:
        self.epsilon = None
    else:
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)),))
        # Build the noise distribution from its own parameters (it was
        # previously built from self.prior_param, leaving epsilon_param
        # unused).
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)
        self.epsilon = self.epsilon_dist.sample()

    self.posterior_latent_code_expanded = self.EncoderSampler.forward(
        self.input_sample, noise=self.epsilon)
    self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
    self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
        self.posterior_latent_code, size=self.batch_size_tf//2)
    self.interpolated_obs = self.Decoder.forward(
        self.interpolated_posterior_latent_code)

    self.reconst_param = self.Decoder.forward(
        self.posterior_latent_code_expanded)
    self.reconst_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.reconst_param)
    self.reconst_sample = self.reconst_dist.sample(b_mode=True)

    # Kernel MMD between posterior codes and fresh prior samples.
    self.kernel_function = self.inv_multiquadratics_kernel
    self.k_post_prior = tf.reduce_mean(self.kernel_function(
        self.posterior_latent_code, self.prior_dist.sample()))
    self.k_post_post = tf.reduce_mean(
        self.kernel_function(self.posterior_latent_code))
    self.k_prior_prior = tf.reduce_mean(
        self.kernel_function(self.prior_dist.sample()))
    self.MMD = self.k_prior_prior + self.k_post_post - 2*self.k_post_prior

    # Latent-space discriminator terms; 10e-7 keeps the log arguments
    # away from zero.
    self.disc_posterior = self.DivergenceLatent.forward(
        self.posterior_latent_code_expanded)
    self.disc_prior = self.DivergenceLatent.forward(
        self.prior_latent_code_expanded)
    self.mean_z_divergence = (
        tf.reduce_mean(tf.log(10e-7 + self.disc_prior))
        + tf.reduce_mean(tf.log(10e-7 + 1 - self.disc_posterior)))
    self.mean_z_divergence_minimizer = tf.reduce_mean(
        tf.log(10e-7 + 1 - self.disc_posterior))

    #############################################################################
    # REGULARIZER
    self.uniform_dist = distributions.UniformDistribution(
        params=tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)),
                          tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
    self.uniform_sample = self.uniform_dist.sample()

    # Random convex mix of a generated sample and a real input.
    self.reg_trivial_sample_param = {'image': None, 'flat': None}
    try:
        mix_image = self.uniform_sample[:, np.newaxis, :, np.newaxis, np.newaxis]
        self.reg_trivial_sample_param['image'] = (
            mix_image * self.obs_sample['image']
            + (1 - mix_image) * self.input_sample['image'])
    except Exception:
        mix_flat = self.uniform_sample[:, np.newaxis, :]
        self.reg_trivial_sample_param['flat'] = (
            mix_flat * self.obs_sample['flat']
            + (1 - mix_flat) * self.input_sample['flat'])
    self.reg_trivial_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.reg_trivial_sample_param)
    self.reg_trivial_sample = self.reg_trivial_dist.sample(b_mode=True)

    self.reg_target_dist = self.obs_sample_dist
    self.reg_target_sample = self.obs_sample
    self.reg_dist = self.reg_trivial_dist
    self.reg_sample = self.reg_trivial_sample

    #############################################################################
    # CRITIC
    self.critic_real = self.Discriminator.forward(self.input_sample)
    self.critic_gen = self.Discriminator.forward(self.obs_sample)
    self.critic_reg_trivial = self.Discriminator.forward(
        self.reg_trivial_sample)
    self.critic_reg = self.critic_reg_trivial

    # Norm of the critic gradient with respect to the mixed sample.
    try:
        self.trivial_grad = tf.gradients(
            self.critic_reg_trivial, [self.reg_trivial_sample['image']])[0]
        self.trivial_grad_norm = helper.safe_tf_sqrt(
            tf.reduce_sum(self.trivial_grad**2, axis=[-1, -2, -3],
                          keep_dims=False)[:, :, np.newaxis])
    except Exception:
        self.trivial_grad = tf.gradients(
            self.critic_reg_trivial, [self.reg_trivial_sample['flat']])[0]
        self.trivial_grad_norm = helper.safe_tf_sqrt(
            tf.reduce_sum(self.trivial_grad**2, axis=[-1, ], keep_dims=True))
    self.trivial_grad_norm_1_penalties = ((self.trivial_grad_norm - 1)**2)

    self.mean_critic_real = tf.reduce_mean(self.critic_real)
    self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
    self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
    self.mean_OT_dual = self.mean_critic_real - self.mean_critic_gen

    #############################################################################
    # OBJECTIVES
    self.enc_reg_strength, self.disc_reg_strength = 100, 10
    self.real_reconst_distances_sq = self.metric_distance_sq(
        self.input_sample, self.reconst_sample)
    self.OT_primal = tf.reduce_mean(
        helper.safe_tf_sqrt(self.real_reconst_distances_sq))
    self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

    self.enc_reg_cost = self.MMD
    self.mean_POT_primal = (self.mean_OT_primal
                            + self.enc_reg_strength * self.enc_reg_cost)
    # Gradient-norm penalty on the critic.
    self.disc_reg_cost = tf.reduce_mean(self.trivial_grad_norm_1_penalties)

    self.enc_cost = self.mean_POT_primal
    self.disc_cost = (-self.mean_OT_dual
                      + self.disc_reg_strength * self.disc_reg_cost)
    self.gen_cost = -self.mean_critic_gen
def inference(self, batch, additional_inputs_tf):
    """Build the training graph for one batch and store every tensor on ``self``.

    The method decodes prior samples with ``self.Generator``, encodes and
    reconstructs the observed batch, optionally computes a latent
    divergence selected by ``config['divergence_mode']``, evaluates
    ``self.Critic`` with optional gradient regularisers selected by
    ``config['critic_reg_mode']``, and sets ``enc_cost``, ``gen_cost`` and
    ``cri_cost`` (plus ``div_cost`` in the GAN modes).

    A set of 400 standard-normal latent codes is cached under
    ``./fixed_samples/`` so that the same "constant" samples are decoded on
    every run.

    Args:
        batch: Nested dict. This method reads ``batch['observed']['data']``
            and ``batch['observed']['properties']``, both keyed by ``'flat'``
            and ``'image'``. The ``'dist'`` field of the active modality's
            property entries is overwritten in place with ``'dirac'``.
        additional_inputs_tf: Sequence whose first two items are stored as
            ``self.epoch`` and ``self.b_identity``.

    Returns:
        None. All results are exposed as attributes.
    """
    self.epoch = additional_inputs_tf[0]
    self.b_identity = additional_inputs_tf[1]

    properties = batch['observed']['properties']
    if len(properties['flat']) > 0:
        for e in properties['flat']:
            e['dist'] = 'dirac'
    else:
        for e in properties['image']:
            e['dist'] = 'dirac'

    self.input_sample = batch['observed']['data']
    self.input_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.input_sample)

    if not self.bModules:
        self.generate_modules(batch)

    # The 'flat' modality is tried first and 'image' is the fallback.
    # `except Exception` (not a bare except) keeps KeyboardInterrupt and
    # SystemExit from being swallowed by this dispatch.
    try:
        self.n_time = properties['flat'][0]['size'][1]
    except Exception:
        self.n_time = properties['image'][0]['size'][1]
    try:
        self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
    except Exception:
        self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

    #############################################################################
    # GENERATOR
    self.prior_param = self.PriorMap.forward(
        (tf.zeros(shape=(self.batch_size_tf, 1)), ))
    self.prior_dist = distributions.DiagonalGaussianDistribution(
        params=self.prior_param)
    self.prior_latent_code = self.prior_dist.sample()
    self.prior_latent_code_expanded = tf.reshape(
        self.prior_latent_code,
        [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])

    self.obs_sample_param = self.Generator.forward(
        self.prior_latent_code_expanded)
    self.obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.obs_sample_param)
    self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

    # Fixed latent codes, cached on disk. One path is used for the
    # existence check, the load and the save. Previously the check looked
    # in ./fixed_samples/ while the load read from ./, and np.save appended
    # '.npy' to the '.npz' name, so the cache could never be hit.
    latent_width = self.prior_latent_code.get_shape().as_list()[-1]
    fixed_sample_dir = './fixed_samples/'
    os.makedirs(fixed_sample_dir, exist_ok=True)
    fixed_sample_path = (fixed_sample_dir + 'np_constant_prior_sample_'
                         + str(latent_width) + '.npy')
    if os.path.exists(fixed_sample_path):
        np_constant_prior_sample = np.load(fixed_sample_path)
    else:
        np_constant_prior_sample = np.random.normal(
            loc=0., scale=1., size=[400, latent_width])
        np.save(fixed_sample_path, np_constant_prior_sample)

    self.constant_prior_latent_code = tf.constant(
        np.asarray(np_constant_prior_sample), dtype=np.float32)
    self.constant_prior_latent_code_expanded = tf.reshape(
        self.constant_prior_latent_code,
        [-1, 1, *self.constant_prior_latent_code.get_shape().as_list()[1:]])
    self.constant_obs_sample_param = self.Generator.forward(
        self.constant_prior_latent_code_expanded)
    self.constant_obs_sample_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.constant_obs_sample_param)
    self.constant_obs_sample = self.constant_obs_sample_dist.sample(
        b_mode=True)

    #############################################################################
    # ENCODER
    # NOTE(review): self.epsilon is only assigned for the modes listed here;
    # any other 'encoder_mode' leaves it unset before the Encoder call below.
    encoder_mode = self.config['encoder_mode']
    if encoder_mode == 'Deterministic':
        self.epsilon = None
    if encoder_mode in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial',
                        'UnivApproxSine'):
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)
        self.epsilon = self.epsilon_dist.sample()

    self.posterior_latent_code_expanded = self.Encoder.forward(
        self.input_sample, noise=self.epsilon)
    self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
    self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
        self.posterior_latent_code, size=self.batch_size_tf // 2)
    self.interpolated_obs = self.Generator.forward(
        self.interpolated_posterior_latent_code)

    self.reconst_param = self.Generator.forward(
        self.posterior_latent_code_expanded)
    self.reconst_dist = distributions.ProductDistribution(
        sample_properties=properties, params=self.reconst_param)
    self.reconst_sample = self.reconst_dist.sample(b_mode=True)

    ### Primal Penalty
    divergence_mode = self.config['divergence_mode']
    if divergence_mode == 'MMD':
        self.MMD = self.compute_MMD(self.posterior_latent_code,
                                    self.prior_dist.sample())
        self.enc_reg_cost = self.MMD
    elif divergence_mode == 'INV-MMD':
        # Random directions: Gaussian vectors divided by their norm.
        batch_rand_vectors = tf.random_normal(shape=[
            self.config['enc_inv_MMD_n_trans'], self.config['n_latent']
        ])
        batch_rand_dirs = batch_rand_vectors / helper.safe_tf_sqrt(
            tf.reduce_sum((batch_rand_vectors**2), axis=1, keep_dims=True))
        self.Inv_MMD = self.stable_div(self.compute_MMD,
                                       self.posterior_latent_code,
                                       batch_rand_dirs)
        self.MMD = self.compute_MMD(self.posterior_latent_code,
                                    self.prior_dist.sample())
        self.enc_reg_cost = (self.MMD
                             + self.config['enc_inv_MMD_strength'] * self.Inv_MMD)
    elif divergence_mode == 'GAN' or divergence_mode == 'NS-GAN':
        self.div_posterior = self.Diverger.forward(
            self.posterior_latent_code_expanded)
        self.div_prior = self.Diverger.forward(
            self.prior_latent_code_expanded)
        # 10e-7 keeps the log arguments away from zero.
        self.mean_z_divergence = (
            tf.reduce_mean(tf.log(10e-7 + self.div_prior))
            + tf.reduce_mean(tf.log(10e-7 + 1 - self.div_posterior)))
        if divergence_mode == 'NS-GAN':
            self.enc_reg_cost = -tf.reduce_mean(
                tf.log(10e-7 + self.div_posterior))
        elif divergence_mode == 'GAN':
            self.enc_reg_cost = tf.reduce_mean(
                tf.log(10e-7 + 1 - self.div_posterior))

    #############################################################################
    # REGULARIZER
    self.uniform_dist = distributions.UniformDistribution(
        params=tf.concat([
            tf.zeros(shape=(self.batch_size_tf, 1)),
            tf.ones(shape=(self.batch_size_tf, 1))
        ], axis=1))

    ### Visualized regularization sample
    self.reg_target_dist = self.reconst_dist
    self.reg_dist = None

    ### Uniform Sample From Geodesic of Trivial Coupling Pairs
    if self.compute_all_regularizers or \
       'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
       'Trivial Lipschitz' in self.config['critic_reg_mode']:
        self.uniform_sample_b = self.uniform_dist.sample()
        self.trivial_line_sample_param = {'image': None, 'flat': None}
        try:
            self.uniform_sample_b_expanded = self.uniform_sample_b[:, np.newaxis, :, np.newaxis, np.newaxis]
            self.trivial_line_sample_param['image'] = (
                self.uniform_sample_b_expanded * self.obs_sample['image']
                + (1 - self.uniform_sample_b_expanded) * self.input_sample['image'])
        except Exception:
            self.uniform_sample_b_expanded = self.uniform_sample_b[:, np.newaxis, :]
            self.trivial_line_sample_param['flat'] = (
                self.uniform_sample_b_expanded * self.obs_sample['flat']
                + (1 - self.uniform_sample_b_expanded) * self.input_sample['flat'])
        self.trivial_line_dist = distributions.ProductDistribution(
            sample_properties=properties,
            params=self.trivial_line_sample_param)
        self.trivial_line_sample = self.trivial_line_dist.sample(
            b_mode=True)
        if self.reg_dist is None:
            self.reg_dist = self.trivial_line_dist

    #############################################################################
    # CRITIC
    self.critic_real = self.Critic.forward(self.input_sample)
    self.critic_gen = self.Critic.forward(self.obs_sample)
    self.critic_rec = self.Critic.forward(self.reconst_sample)

    self.mean_critic_real = tf.reduce_mean(self.critic_real)
    self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
    self.mean_critic_rec = tf.reduce_mean(self.critic_rec)

    self.mean_critic_reg = None
    if self.compute_all_regularizers or \
       'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
       'Trivial Lipschitz' in self.config['critic_reg_mode']:
        self.critic_trivial_line = self.Critic.forward(
            self.trivial_line_sample)
        if self.mean_critic_reg is None:
            self.mean_critic_reg = tf.reduce_mean(self.critic_trivial_line)

        try:
            self.trivial_line_grad = tf.gradients(
                self.critic_trivial_line,
                [self.trivial_line_sample['image']])[0]
            # NOTE(review): this branch uses tf.sqrt, while the 'flat'
            # branch uses helper.safe_tf_sqrt; kept as-is, confirm intent.
            self.trivial_line_grad_norm = tf.sqrt(
                tf.reduce_sum(tf.square(self.trivial_line_grad),
                              axis=[-1, -2, -3],
                              keep_dims=False))[:, np.newaxis, :]
        except Exception:
            self.trivial_line_grad = tf.gradients(
                self.critic_trivial_line,
                [self.trivial_line_sample['flat']])[0]
            self.trivial_line_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.trivial_line_grad**2, axis=[
                    -1,
                ], keep_dims=True))

    self.cri_reg_cost = 0
    if self.compute_all_regularizers or \
       'Trivial Gradient Norm' in self.config['critic_reg_mode']:
        # Penalise deviation of the gradient norm from 1.
        self.trivial_line_grad_norm_1_penalties = ((
            self.trivial_line_grad_norm - 1.)**2)
        self.mean_trivial_line_grad_norm_1_penalties = tf.reduce_mean(
            self.trivial_line_grad_norm_1_penalties)
        if 'Trivial Gradient Norm' in self.config['critic_reg_mode']:
            print('Adding Trivial Gradient Norm Penalty')
            self.cri_reg_cost += self.mean_trivial_line_grad_norm_1_penalties

    # Divide the accumulated penalty by the number of configured modes.
    if len(self.config['critic_reg_mode']) > 0:
        self.cri_reg_cost /= len(self.config['critic_reg_mode'])

    #############################################################################
    # OBJECTIVES

    ### Divergence
    if divergence_mode == 'GAN' or divergence_mode == 'NS-GAN':
        self.div_cost = -self.config[
            'enc_reg_strength'] * self.mean_z_divergence

    ### Encoder
    self.OT_primal = self.sample_distance_function(self.input_sample,
                                                   self.reconst_sample)
    self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
    # The encoder regulariser (enc_reg_cost) is currently not added here.
    self.mean_POT_primal = self.mean_OT_primal
    self.enc_cost = self.mean_POT_primal

    ### Generator
    dual_dist_mode = self.config['dual_dist_mode']
    if dual_dist_mode == 'Coupling':
        self.gen_cost = -self.mean_critic_rec
    elif dual_dist_mode == 'Prior':
        self.gen_cost = -self.mean_critic_gen
    elif dual_dist_mode == 'CouplingAndPrior':
        self.gen_cost = -0.5 * (self.mean_critic_rec + self.mean_critic_gen)

    ### Critic
    if dual_dist_mode == 'Coupling':
        self.mean_OT_dual = self.mean_critic_real - self.mean_critic_rec
    elif dual_dist_mode == 'Prior':
        self.mean_OT_dual = self.mean_critic_real - self.mean_critic_gen
    elif dual_dist_mode == 'CouplingAndPrior':
        self.mean_OT_dual = self.mean_critic_real - 0.5 * (
            self.mean_critic_rec + self.mean_critic_gen)

    self.cri_cost = (-self.mean_OT_dual
                     + self.config['cri_reg_strength'] * self.cri_reg_cost)
def sample(self):
    """Draw one random point at distance ``self.radius`` around each center.

    A Gaussian draw shaped like ``self.centers`` is divided by its norm along
    axis 1 (computed with ``helper.safe_tf_sqrt``), scaled by
    ``self.radius`` and shifted by ``self.centers``.
    """
    gaussian_draw = tf.random.normal(shape=tf.shape(self.centers))
    squared_norm = tf.reduce_sum(gaussian_draw**2, axis=1, keepdims=True)
    direction = gaussian_draw / helper.safe_tf_sqrt(squared_norm)
    return self.radius * direction + self.centers
# Reconstruction head: two 500-unit ReLU layers followed by a linear layer
# of width input_dim * tile_rate.
layer_2_rec = tf.layers.dense(inputs=layer_1_rec,
                              units=500,
                              use_bias=True,
                              activation=tf.nn.relu)
layer_3_rec = tf.layers.dense(inputs=layer_2_rec,
                              units=500,
                              use_bias=True,
                              activation=tf.nn.relu)
x_rec = tf.layers.dense(inputs=layer_3_rec,
                        units=input_dim * tile_rate,
                        use_bias=True,
                        activation=None)

# Mean distance (safe sqrt of the squared error summed over axis 1) between
# reconstruction and input.
rec_cost = tf.reduce_mean(
    helper.safe_tf_sqrt(tf.reduce_sum((x_rec - x_input)**2, axis=1)))
MMD = helper.compute_MMD(z, z_prior, positive_only=True)

# Weight of the reconstruction term, gated by helper.hardstep on the
# iteration counter. An alternative schedule that was tried is
# start, timescale = 200, 1500.
start, timescale = 0, 1
lambda_z_comp = 100 * helper.hardstep(
    (iter_tf - float(start)) / float(timescale))
cost = MMD + lambda_z_comp * rec_cost

optimizer = tf.train.AdamOptimizer(learning_rate=0.0001,
                                   beta1=0.9,
                                   beta2=0.999,
                                   epsilon=1e-08)
cost_step = optimizer.minimize(cost)
# tf.initialize_all_variables() is deprecated; this is its replacement.
init = tf.global_variables_initializer()