示例#1
0
	def sample(self):
		"""Draw one random point per row of ``self.centers`` from a radial shell.

		The direction comes from normalising an isotropic Gaussian draw row by
		row (norm taken over ``axis=1``). Its length is
		``inner_radius + u * (outer_radius - inner_radius)``, with ``u`` drawn
		uniformly from [0, 1) using the shape of ``self.inner_radius``. The
		scaled direction is then offset by ``self.centers``.

		NOTE(review): the radius itself is uniform, which is not the same as a
		volume-uniform distribution over the shell; confirm this is intended.
		"""
		# Op creation order below mirrors the original exactly, because in TF1
		# op-level random seeds depend on the op's position in the graph.
		gaussian = tf.random.normal(shape=tf.shape(self.centers))
		squared_len = tf.reduce_sum(gaussian**2, axis=1, keepdims=True)
		unit_dir = gaussian/helper.safe_tf_sqrt(squared_len)
		u = tf.random_uniform(tf.shape(self.inner_radius), 0, 1, dtype=tf.float32)
		radius = u*(self.outer_radius-self.inner_radius)+self.inner_radius
		return radius*unit_dir+self.centers
示例#2
0
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph for one batch of observations.

        Creates the following as attributes on ``self``:

        - a prior sample decoded by ``Decoder``;
        - a diagonal-Gaussian posterior from ``Encoder``, with its
          reconstruction;
        - a single-sample estimate of KL(posterior || prior);
        - per-row convex combinations of the regularisation target and the
          input;
        - critic scores, with a gradient penalty on those combinations;
        - the scalar costs ``discriminator_cost``, ``generator_cost`` and
          ``transporter_cost``.

        Args:
            batch: nested dict. Reads ``batch['observed']['data']`` and
                ``batch['observed']['properties']``. The latter holds lists
                under ``'flat'`` and ``'image'`` of per-variable property
                dicts, whose ``'dist'`` entries are overwritten with
                ``'dirac'``.
            additional_inputs_tf: sequence whose first two entries are stored
                as ``self.epoch`` and ``self.b_identity``.

        Returns:
            None; all results are stored on ``self``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Mark observed variables as 'dirac' (the flat ones if any exist,
        # otherwise the image ones).
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        # Build submodules on first use (presumably generate_modules sets
        # self.bModules).
        if not self.bModules: self.generate_modules(batch)
        # Prefer the 'flat' entries; fall back to 'image' if that lookup fails
        # (bare except: any error in the flat path triggers the fallback).
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Prior: PriorMap applied to a (batch, 1) zero tensor parameterises a
        # diagonal Gaussian. The batch mean of log_pdf at its own samples is a
        # single-sample estimate of the negative prior entropy.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        # Per-row mixing weight for the regularisation interpolates below
        # (presumably U(0, 1), given the [0, 1] bounds passed as params).
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ],
                             axis=1))
        self.convex_mix = self.uniform_dist.sample()

        # Insert a singleton axis 1 before decoding the prior sample.
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # Hard-coded switch: use reconstructions (True) or prior samples
        # (False) as the regularisation target. Because it is True, the else
        # branch further down is never taken.
        b_use_reconst_as_reg = True
        self.posterior_param_expanded = self.Encoder.forward(self.input_sample)
        # Keep only index 0 along axis 1 of the encoder output.
        self.posterior_param = self.posterior_param_expanded[:, 0, :]
        self.posterior_dist = distributions.DiagonalGaussianDistribution(
            params=self.posterior_param)
        self.posterior_latent_code = self.posterior_dist.sample()
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.
                                                                         newaxis, :]

        self.reconst_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        if b_use_reconst_as_reg:
            self.reg_target_dist = self.reconst_dist
            self.reg_target_sample = self.reconst_sample
        else:
            self.reg_target_dist = self.obs_sample_dist
            self.reg_target_sample = self.obs_sample

        # Single-sample Monte Carlo KL(posterior || prior): the batch mean of
        # log q(z) - log p(z), evaluated at the posterior sample z.
        self.neg_cross_ent_posterior = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_cross_ent_posterior = tf.reduce_mean(
            self.neg_cross_ent_posterior)
        self.neg_ent_posterior = self.posterior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_ent_posterior = tf.reduce_mean(self.neg_ent_posterior)

        # Disabled analytic alternative to the Monte Carlo KL estimate:
        # self.kl_posterior_prior = distributions.KLDivDiagGaussianVsDiagGaussian().forward(self.posterior_dist, self.prior_dist)
        # self.mean_kl_posterior_prior = tf.reduce_mean(self.kl_posterior_prior)

        self.mean_kl_posterior_prior = -self.mean_neg_cross_ent_posterior + self.mean_neg_ent_posterior
        #############################################################################

        # Interpolate between the regularisation target and the input using the
        # per-row weight convex_mix. The weight is expanded to rank 5 for
        # 'image' data or rank 3 for 'flat' data. The image path is tried
        # first; any error falls back to the flat path.
        self.reg_sample_param = {'image': None, 'flat': None}
        try:
            self.reg_sample_param['image'] = self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis]*self.reg_target_sample['image']+\
                                            (1-self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis])*self.input_sample['image']
        except:
            self.reg_sample_param['flat'] = self.convex_mix[:, np.newaxis, :]*self.reg_target_sample['flat']+\
                                            (1-self.convex_mix[:, np.newaxis, :])*self.input_sample['flat']
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_sample_param)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        # Critic scores on real data, prior samples and the interpolates.
        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg = self.Discriminator.forward(self.reg_sample)

        # Autoencoding cost per sample: half the squared reconstruction
        # distance plus lambda_t times the (batch-mean) KL estimate.
        lambda_t = 1
        self.real_reconst_distances_sq = self.metric_distance_sq(
            self.input_sample, self.reconst_sample)
        self.autoencode_costs = self.real_reconst_distances_sq / 2 + lambda_t * self.mean_kl_posterior_prior

        # Gradient penalty: L2 norm of d(critic)/d(interpolate). The squared
        # gradient is summed over the last three axes for images or over the
        # last axis for flat data, and the norm is penalised as (norm - 1)^2.
        try:
            self.convex_grad = tf.gradients(self.critic_reg,
                                            [self.reg_sample['image']])[0]
            self.convex_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.convex_grad**2,
                              axis=[-1, -2, -3],
                              keep_dims=False)[:, :, np.newaxis])
        except:
            self.convex_grad = tf.gradients(self.critic_reg,
                                            [self.reg_sample['flat']])[0]
            self.convex_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.convex_grad**2, axis=[-1], keep_dims=True))
        self.gradient_penalties = ((self.convex_grad_norm - 1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
        self.mean_autoencode_cost = tf.reduce_mean(self.autoencode_costs)
        self.mean_gradient_penalty = tf.reduce_mean(self.gradient_penalties)

        # Scalar objectives. The critic loss adds the gradient penalty with
        # weight 10; the generator maximises the critic score on prior samples.
        self.regularizer_cost = 10 * self.mean_gradient_penalty
        self.discriminator_cost = -self.mean_critic_real + self.mean_critic_gen + self.regularizer_cost
        self.generator_cost = -self.mean_critic_gen
        self.transporter_cost = self.mean_autoencode_cost
示例#3
0
 def euclidean_distance(self, a, b):
     """Return the square root of ``self.metric_distance_sq(a, b)``.

     The root is taken with ``helper.safe_tf_sqrt`` (presumably a
     numerically guarded sqrt; confirm in ``helper``).
     """
     squared_distance = self.metric_distance_sq(a, b)
     return helper.safe_tf_sqrt(squared_distance)
示例#4
0
文件: Model.py 项目: avellal14/pdwg
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph for one batch of observations.

        Creates the following as attributes on ``self``:

        - prior samples decoded through ``Generator`` and ``PostGen``;
        - decodings of a fixed set of 400 latent codes persisted on disk, plus
          a 20x20 latent grid when ``config['n_latent'] == 2``;
        - the encoder path: the ``PreEnc`` feature, the
          ``HouseholdRotationFlow`` transform, the ``Encoder`` latent code,
          reconstructions and latent interpolations;
        - the scalar costs ``div_cost``, ``enc_cost``, ``gen_cost`` and
          ``cri_cost``.

        Args:
            batch: nested dict. Reads ``batch['observed']['data']`` and
                ``batch['observed']['properties']``. The latter holds lists
                under ``'flat'`` and ``'image'`` of per-variable property
                dicts, whose ``'dist'`` entries are overwritten with
                ``'dirac'``.
            additional_inputs_tf: sequence whose first two entries are stored
                as ``self.epoch`` and ``self.b_identity``.

        Side effects:
            Creates ``~/ExperimentalResults/FixedSamples/`` if it is missing.
            The first time a given latent size is seen, writes the fixed prior
            sample there as ``np_constant_prior_sample_<n_latent>.npy``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Mark observed variables as 'dirac' (the flat ones if any exist,
        # otherwise the image ones).
        if len(batch['observed']['properties']['flat'])>0:
            for e in batch['observed']['properties']['flat']: e['dist']='dirac'
        else:
            for e in batch['observed']['properties']['image']: e['dist']='dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.input_sample)

        # Build submodules on first use (presumably generate_modules sets self.bModules).
        if not self.bModules: self.generate_modules(batch)
        # Prefer the 'flat' entries; fall back to 'image' if that lookup fails.
        try: self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except: self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try: self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except: self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # PRIOR

        self.prior_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(params = self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        # Insert a singleton axis 1 for the modules below.
        self.prior_latent_code_expanded = self.prior_latent_code[:,np.newaxis,:]

        # GENERATOR

        self.prior_feature_expanded = self.Generator.forward(self.prior_latent_code_expanded)
        self.prior_feature = self.prior_feature_expanded[:,0,:]

        self.obs_sample_param = self.PostGen.forward(self.prior_feature_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed set of 400 prior codes, persisted so that every run decodes the
        # same codes.
        # np.save appends '.npy' to any file name that does not already end in
        # '.npy'. The former '...npz' name was therefore written to
        # '...npz.npy', so the exists() check never matched and a new random
        # sample was drawn on every run. Use an explicit '.npy' path instead.
        n_latent_dim = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_samples_dir = os.path.join(str(Path.home()), 'ExperimentalResults', 'FixedSamples')
        os.makedirs(fixed_samples_dir, exist_ok=True)
        constant_sample_path = os.path.join(fixed_samples_dir, 'np_constant_prior_sample_'+str(n_latent_dim)+'.npy')
        if os.path.exists(constant_sample_path):
            np_constant_prior_sample = np.load(constant_sample_path)
        else:
            np_constant_prior_sample = np.random.normal(loc=0., scale=1., size=[400, n_latent_dim])
            np.save(constant_sample_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = self.constant_prior_latent_code[:, np.newaxis, :]

        self.constant_prior_feature_expanded = self.Generator.forward(self.constant_prior_latent_code_expanded)
        self.constant_prior_feature = self.constant_prior_feature_expanded[:,0,:]
        self.constant_obs_sample_param = self.PostGen.forward(self.constant_prior_feature_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

        # For a 2-D latent space, also decode a 20x20 grid over [-3, 3]^2
        # (x increasing, y decreasing).
        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate((xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]), axis=1)

            self.constant_prior_grid_latent_code = tf.constant(np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_prior_grid_latent_code_expanded = self.constant_prior_grid_latent_code[:, np.newaxis, :]

            self.constant_prior_grid_feature_expanded = self.Generator.forward(self.constant_prior_grid_latent_code_expanded)
            self.constant_prior_grid_feature = self.constant_prior_grid_feature_expanded[:,0,:]
            self.constant_obs_grid_sample_param = self.PostGen.forward(self.constant_prior_grid_feature_expanded)
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(b_mode=True)

        #############################################################################
        # ENCODER
        self.epsilon_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(params = self.epsilon_param)

        # NOTE(review): for any other encoder_mode, self.epsilon is not assigned
        # here. Confirm such modes cannot occur, or that epsilon is set elsewhere.
        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial', 'UnivApproxSine'):
            self.epsilon = self.epsilon_dist.sample()

        self.pre_posterior_feature_expanded = self.PreEnc.forward(self.input_sample, noise=self.epsilon)
        self.pre_posterior_feature = self.pre_posterior_feature_expanded[:,0,:]

        # self.tiny_perturb_dist = distributions.DiagonalGaussianDistribution(params = tf.zeros(shape=(self.batch_size_tf, 2*512)))
        # self.tiny_perturb_sample_expanded = 0.1*self.tiny_perturb_dist.sample()[:,np.newaxis,:]

        # Transform the pre-encoder feature with the flow. The second argument
        # (512) is hard-coded; presumably it is the feature size.
        self.flow_param = self.FlowMap.forward()
        self.flow_object = transforms.HouseholdRotationFlow(self.flow_param, 512)
        self.transformed_pre_posterior_feature, _ = self.flow_object.transform(self.pre_posterior_feature, tf.zeros(shape=(self.batch_size_tf, 1)))

        self.transformed_pre_posterior_feature_abs = helper.relu_abs(self.transformed_pre_posterior_feature)
        self.transformed_pre_posterior_feature_abs_means = tf.reduce_mean(self.transformed_pre_posterior_feature_abs, axis=0)

        # self.reconst_param = self.PostGen.forward(self.pre_posterior_feature_expanded+self.tiny_perturb_sample_expanded)
        # Reconstruction decoded directly from the pre-encoder feature.
        self.reconst_param = self.PostGen.forward(self.pre_posterior_feature_expanded)
        self.reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        # Full round trip: feature -> latent code (Encoder) -> feature
        # (Generator) -> observation (PostGen).
        self.posterior_latent_code_expanded = self.Encoder.forward(self.pre_posterior_feature_expanded)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,0,:]
        self.posterior_feature_expanded = self.Generator.forward(self.posterior_latent_code_expanded)
        self.posterior_feature = self.posterior_feature_expanded[:,0,:]

        self.full_reconst_param = self.PostGen.forward(self.posterior_feature_expanded)
        self.full_reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.full_reconst_param)
        self.full_reconst_sample = self.full_reconst_dist.sample(b_mode=True)

        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        self.interpolated_posterior_latent_code_expanded = self.interpolated_posterior_latent_code[:,np.newaxis,:]
        self.interpolated_posterior_feature_expanded = self.Generator.forward(self.interpolated_posterior_latent_code_expanded)
        self.interpolated_posterior_feature = self.interpolated_posterior_feature_expanded[:,0,:]

        self.interpolated_reconst_param = self.PostGen.forward(self.interpolated_posterior_feature_expanded)
        self.interpolated_reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.interpolated_reconst_param)
        self.interpolated_obs = self.interpolated_reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        # self.enc_reg_cost = self.compute_MMD(self.prior_latent_code, self.posterior_latent_code)
        # self.div_posterior = self.Diverger.forward(self.posterior_latent_code_expanded)
        # self.div_prior = self.Diverger.forward(self.prior_latent_code_expanded)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################

        # # OBJECTIVES
        # # Divergence
        # if self.config['divergence_mode'] == 'GAN' or self.config['divergence_mode'] == 'NS-GAN':
        #     self.div_cost = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_posterior)+10e-7))+tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7)))
        # if self.config['divergence_mode'] == 'WGAN-GP':
        #     uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        #     uniform_w = uniform_dist.sample()
        #     self.trivial_line = uniform_w[:,np.newaxis,:]*self.posterior_latent_code_expanded+(1-uniform_w[:,np.newaxis,:])*self.prior_latent_code_expanded
        #     self.div_trivial_line = self.Diverger.forward(self.trivial_line)
        #     self.trivial_line_grad = tf.gradients(tf.reduce_sum(self.div_trivial_line), [self.trivial_line])[0]
        #     self.trivial_line_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_line_grad**2, axis=-1, keep_dims=False)[:,:,np.newaxis])
        #     self.trivial_line_grad_norm_1_penalties = ((self.trivial_line_grad_norm-1)**2)
        #     self.div_reg_cost = tf.reduce_mean(self.trivial_line_grad_norm_1_penalties)
        #     self.div_cost = -(tf.reduce_mean(self.div_posterior)-tf.reduce_mean(self.div_prior))+10*self.div_reg_cost

        # ### Encoder
        # b_use_timer, timescale, starttime = False, 10, 5
        # self.OT_primal = self.sample_distance_function(self.input_sample, self.reconst_sample)
        # self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        # if b_use_timer:
        #     self.mean_POT_primal = self.mean_OT_primal+helper.hardstep((self.epoch-float(starttime))/float(timescale))*self.config['enc_reg_strength']*self.enc_reg_cost
        # else:
        #     self.mean_POT_primal = self.mean_OT_primal+self.config['enc_reg_strength']*self.enc_reg_cost
        # self.enc_cost = self.mean_POT_primal
        # self.enc_cost = self.mean_OT_primal

        # ### Critic
        # # self.cri_cost = self.compute_MMD(self.prior_latent_code, self.prior_latent_code)
        # if self.config['divergence_mode'] == 'NS-GAN':
        #     self.cri_cost = -tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*self.compute_MMD(self.prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'GAN':
        #     self.cri_cost = tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*self.compute_MMD(self.prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'WGAN-GP':
        #     self.cri_cost = -tf.reduce_mean(self.div_prior)+self.config['enc_reg_strength']*self.compute_MMD(self.prior_latent_code, self.prior_latent_code)
        # self.cri_reg_cost = self.shouldbezero_cost
        # self.cri_cost = self.shouldbenormal_cost+self.cri_reg_cost

        ### Generator
        # self.gen_cost = self.mean_OT_primal

        # div_cost: batch mean of the summed relu_abs of the transformed
        # feature's coordinates beyond the first n_latent.
        # self.div_cost = tf.reduce_mean(tf.reduce_sum(self.transformed_pre_posterior_feature[:, self.config['n_latent']:]**2, axis=1))
        self.div_cost = tf.reduce_mean(tf.reduce_sum(helper.relu_abs(self.transformed_pre_posterior_feature[:, self.config['n_latent']:]), axis=1))
        # enc_cost and gen_cost are the same expression: the mean Euclidean
        # distance between the pre-encoder feature and its round-trip feature.
        self.enc_cost = tf.reduce_mean(helper.safe_tf_sqrt(tf.reduce_sum((self.pre_posterior_feature-self.posterior_feature)**2, axis=1)))
        self.gen_cost = tf.reduce_mean(helper.safe_tf_sqrt(tf.reduce_sum((self.pre_posterior_feature-self.posterior_feature)**2, axis=1)))
        # cri_cost: mean sample distance between input and reconstruction.
        self.cri_cost = tf.reduce_mean(self.sample_distance_function(self.input_sample, self.reconst_sample))
示例#5
0
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph for one batch of observations.

        Creates the following as attributes on ``self``:

        - prior samples decoded by ``Decoder``;
        - decodings of a fixed set of 400 latent codes persisted in the
          working directory;
        - an ``EncoderSampler`` encoder (noise-driven unless
          ``b_deterministic``), with reconstructions and latent
          interpolations;
        - a kernel MMD estimate between posterior and prior codes;
        - latent-space discriminator scores;
        - a critic with a gradient penalty on input/generated interpolates;
        - the scalar costs ``enc_cost``, ``disc_cost`` and ``gen_cost``.

        Args:
            batch: nested dict. Reads ``batch['observed']['data']`` and
                ``batch['observed']['properties']``. The latter holds lists
                under ``'flat'`` and ``'image'`` of per-variable property
                dicts, whose ``'dist'`` entries are overwritten with
                ``'dirac'``.
            additional_inputs_tf: sequence whose first two entries are stored
                as ``self.epoch`` and ``self.b_identity``.

        Side effects:
            The first time a given latent size is seen, writes the fixed prior
            sample to ``./np_constant_prior_sample_<n_latent>.npy``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Mark observed variables as 'dirac' (the flat ones if any exist,
        # otherwise the image ones).
        if len(batch['observed']['properties']['flat'])>0:
            for e in batch['observed']['properties']['flat']: e['dist']='dirac'
        else:
            for e in batch['observed']['properties']['image']: e['dist']='dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.input_sample)

        # Build submodules on first use (presumably generate_modules sets self.bModules).
        if not self.bModules: self.generate_modules(batch)
        # Prefer the 'flat' entries; fall back to 'image' if that lookup fails.
        try: self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except: self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try: self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except: self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(params = self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        # Insert a singleton axis 1 before decoding.
        self.prior_latent_code_expanded = tf.reshape(self.prior_latent_code, [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.obs_sample_param = self.Decoder.forward(self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed set of 400 prior codes, persisted so that every run decodes the
        # same codes.
        # np.save appends '.npy' to any file name that does not already end in
        # '.npy'. The former '...npz' name was therefore written to
        # '...npz.npy', so the exists() check never matched and a new random
        # sample was drawn on every run. Use an explicit '.npy' path instead.
        n_latent_dim = self.prior_latent_code.get_shape().as_list()[-1]
        constant_sample_path = './np_constant_prior_sample_'+str(n_latent_dim)+'.npy'
        if os.path.exists(constant_sample_path):
            np_constant_prior_sample = np.load(constant_sample_path)
        else:
            np_constant_prior_sample = np.random.normal(loc=0., scale=1., size=[400, n_latent_dim])
            np.save(constant_sample_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(self.constant_prior_latent_code, [-1, 1, *self.constant_prior_latent_code.get_shape().as_list()[1:]])

        self.constant_obs_sample_param = self.Decoder.forward(self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # ENCODER

        b_deterministic = False
        if b_deterministic:
            self.epsilon = None
        else:
            self.epsilon_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
            # Build the noise distribution from epsilon_param, which is computed
            # for exactly this purpose (this matches the sibling encoder
            # variants).
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(params = self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded = self.EncoderSampler.forward(self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,0,:]
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        # NOTE(review): unlike the other Decoder calls, this one receives the
        # code without the singleton axis 1. Confirm that Decoder accepts it.
        self.interpolated_obs = self.Decoder.forward(self.interpolated_posterior_latent_code)

        self.reconst_param = self.Decoder.forward(self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        # Kernel MMD estimate between posterior codes and fresh prior draws
        # (each prior term draws its own new sample).
        self.kernel_function = self.inv_multiquadratics_kernel
        self.k_post_prior = tf.reduce_mean(self.kernel_function(self.posterior_latent_code, self.prior_dist.sample()))
        self.k_post_post = tf.reduce_mean(self.kernel_function(self.posterior_latent_code))
        self.k_prior_prior = tf.reduce_mean(self.kernel_function(self.prior_dist.sample()))
        self.MMD = self.k_prior_prior+self.k_post_post-2*self.k_post_prior

        # Latent-space discriminator terms (a small constant guards each log);
        # currently unused by the costs below.
        self.disc_posterior = self.DivergenceLatent.forward(self.posterior_latent_code_expanded)
        self.disc_prior = self.DivergenceLatent.forward(self.prior_latent_code_expanded)
        self.mean_z_divergence = tf.reduce_mean(tf.log(10e-7+self.disc_prior))+tf.reduce_mean(tf.log(10e-7+1-self.disc_posterior))
        self.mean_z_divergence_minimizer = tf.reduce_mean(tf.log(10e-7+1-self.disc_posterior))

        #############################################################################
        # REGULARIZER

        # Per-row weight (presumably U(0, 1), given the [0, 1] bounds) used to
        # interpolate between generated and real samples. The image path is
        # tried first; any error falls back to the flat path.
        self.uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        self.uniform_sample = self.uniform_dist.sample()

        self.reg_trivial_sample_param = {'image': None, 'flat': None}
        try:
            self.reg_trivial_sample_param['image'] = self.uniform_sample[:, np.newaxis, :, np.newaxis, np.newaxis]*self.obs_sample['image']+\
                                                    (1-self.uniform_sample[:, np.newaxis, :, np.newaxis, np.newaxis])*self.input_sample['image']
        except:
            self.reg_trivial_sample_param['flat'] = self.uniform_sample[:, np.newaxis, :]*self.obs_sample['flat']+\
                                                   (1-self.uniform_sample[:, np.newaxis, :])*self.input_sample['flat']
        self.reg_trivial_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reg_trivial_sample_param)
        self.reg_trivial_sample = self.reg_trivial_dist.sample(b_mode=True)

        self.reg_target_dist = self.obs_sample_dist
        self.reg_target_sample = self.obs_sample
        self.reg_dist = self.reg_trivial_dist
        self.reg_sample = self.reg_trivial_sample

        #############################################################################
        # CRITIC

        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg_trivial = self.Discriminator.forward(self.reg_trivial_sample)
        self.critic_reg = self.critic_reg_trivial

        # Gradient penalty: L2 norm of d(critic)/d(interpolate). The squared
        # gradient is summed over the last three axes for images or over the
        # last axis for flat data, and the norm is penalised as (norm - 1)^2.
        try:
            self.trivial_grad = tf.gradients(self.critic_reg_trivial, [self.reg_trivial_sample['image']])[0]
            self.trivial_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_grad**2, axis=[-1,-2,-3], keep_dims=False)[:,:,np.newaxis])
        except:
            self.trivial_grad = tf.gradients(self.critic_reg_trivial, [self.reg_trivial_sample['flat']])[0]
            self.trivial_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_grad**2, axis=[-1,], keep_dims=True))

        self.trivial_grad_norm_1_penalties = ((self.trivial_grad_norm-1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
        self.mean_OT_dual = self.mean_critic_real-self.mean_critic_gen

        #############################################################################
        # OBJECTIVES
        self.enc_reg_strength, self.disc_reg_strength = 100, 10

        # Mean Euclidean reconstruction distance. OT_primal is already a batch
        # mean, so the second reduce_mean does not change its value.
        self.real_reconst_distances_sq = self.metric_distance_sq(self.input_sample, self.reconst_sample)
        self.OT_primal = tf.reduce_mean(helper.safe_tf_sqrt(self.real_reconst_distances_sq))
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        self.enc_reg_cost = self.MMD # self.mean_z_divergence_minimizer
        self.mean_POT_primal = self.mean_OT_primal+self.enc_reg_strength*self.enc_reg_cost

        self.disc_reg_cost = tf.reduce_mean(self.trivial_grad_norm_1_penalties)

        # WGAN-GP
        # self.div_cost = -self.mean_z_divergence
        self.enc_cost = self.mean_POT_primal
        self.disc_cost = -self.mean_OT_dual+self.disc_reg_strength*self.disc_reg_cost
        self.gen_cost = -self.mean_critic_gen
示例#6
0
    def inference(self, batch, additional_inputs_tf):
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        # self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        # self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        if not os.path.exists('./fixed_samples/'):
            os.makedirs('./fixed_samples/')
        if os.path.exists(
                './fixed_samples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                './np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[400,
                      self.prior_latent_code.get_shape().as_list()[-1]])
            np.save(
                './fixed_samples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz',
                np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(
            self.constant_prior_latent_code, [
                -1, 1,
                *self.constant_prior_latent_code.get_shape().as_list()[1:]
            ])

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        #############################################################################
        # ENCODER

        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] == 'Gaussian' or self.config[
                'encoder_mode'] == 'UnivApprox' or self.config[
                    'encoder_mode'] == 'UnivApproxNoSpatial' or self.config[
                        'encoder_mode'] == 'UnivApproxSine':
            self.epsilon_param = self.PriorMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,
                                                                         0, :]
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        if self.config['divergence_mode'] == 'MMD':
            self.MMD = self.compute_MMD(self.posterior_latent_code,
                                        self.prior_dist.sample())
            self.enc_reg_cost = self.MMD
        if self.config['divergence_mode'] == 'INV-MMD':
            batch_rand_vectors = tf.random_normal(shape=[
                self.config['enc_inv_MMD_n_trans'], self.config['n_latent']
            ])
            batch_rand_dirs = batch_rand_vectors / helper.safe_tf_sqrt(
                tf.reduce_sum((batch_rand_vectors**2), axis=1, keep_dims=True))
            self.Inv_MMD = self.stable_div(self.compute_MMD,
                                           self.posterior_latent_code,
                                           batch_rand_dirs)
            self.MMD = self.compute_MMD(self.posterior_latent_code,
                                        self.prior_dist.sample())
            self.enc_reg_cost = self.MMD + self.config[
                'enc_inv_MMD_strength'] * self.Inv_MMD
        elif self.config['divergence_mode'] == 'GAN' or self.config[
                'divergence_mode'] == 'NS-GAN':
            self.div_posterior = self.Diverger.forward(
                self.posterior_latent_code_expanded)
            self.div_prior = self.Diverger.forward(
                self.prior_latent_code_expanded)
            self.mean_z_divergence = tf.reduce_mean(
                tf.log(10e-7 + self.div_prior)) + tf.reduce_mean(
                    tf.log(10e-7 + 1 - self.div_posterior))
            if self.config['divergence_mode'] == 'NS-GAN':
                self.enc_reg_cost = -tf.reduce_mean(
                    tf.log(10e-7 + self.div_posterior))
            elif self.config['divergence_mode'] == 'GAN':
                self.enc_reg_cost = tf.reduce_mean(
                    tf.log(10e-7 + 1 - self.div_posterior))

        #############################################################################
        # REGULARIZER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ],
                             axis=1))

        ### Visualized regularization sample
        self.reg_target_dist = self.reconst_dist
        self.reg_dist = None

        ### Uniform Sample From Geodesic of Trivial Coupling Pairs
        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
           'Trivial Lipschitz' in self.config['critic_reg_mode']:
            self.uniform_sample_b = self.uniform_dist.sample()
            self.trivial_line_sample_param = {'image': None, 'flat': None}
            try:
                self.uniform_sample_b_expanded = self.uniform_sample_b[:, np.
                                                                       newaxis, :,
                                                                       np.
                                                                       newaxis,
                                                                       np.
                                                                       newaxis]
                self.trivial_line_sample_param['image'] = self.uniform_sample_b_expanded*self.obs_sample['image']+\
                                                          (1-self.uniform_sample_b_expanded)*self.input_sample['image']
            except:
                self.uniform_sample_b_expanded = self.uniform_sample_b[:, np.
                                                                       newaxis, :]
                self.trivial_line_sample_param['flat'] = self.uniform_sample_b_expanded*self.obs_sample['flat']+\
                                                         (1-self.uniform_sample_b_expanded)*self.input_sample['flat']
            self.trivial_line_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.trivial_line_sample_param)
            self.trivial_line_sample = self.trivial_line_dist.sample(
                b_mode=True)
            if self.reg_dist is None: self.reg_dist = self.trivial_line_dist

        #############################################################################
        # CRITIC

        self.critic_real = self.Critic.forward(self.input_sample)
        self.critic_gen = self.Critic.forward(self.obs_sample)
        self.critic_rec = self.Critic.forward(self.reconst_sample)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_rec = tf.reduce_mean(self.critic_rec)

        self.mean_critic_reg = None

        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
           'Trivial Lipschitz' in self.config['critic_reg_mode']:
            self.critic_trivial_line = self.Critic.forward(
                self.trivial_line_sample)
            if self.mean_critic_reg is None:
                self.mean_critic_reg = tf.reduce_mean(self.critic_trivial_line)
            try:
                self.trivial_line_grad = tf.gradients(
                    self.critic_trivial_line,
                    [self.trivial_line_sample['image']])[0]
                self.trivial_line_grad_norm = tf.sqrt(
                    tf.reduce_sum(tf.square(self.trivial_line_grad),
                                  axis=[-1, -2, -3],
                                  keep_dims=False))[:, np.newaxis, :]
            except:
                self.trivial_line_grad = tf.gradients(
                    self.critic_trivial_line,
                    [self.trivial_line_sample['flat']])[0]
                self.trivial_line_grad_norm = helper.safe_tf_sqrt(
                    tf.reduce_sum(self.trivial_line_grad**2,
                                  axis=[
                                      -1,
                                  ],
                                  keep_dims=True))

        self.cri_reg_cost = 0

        if self.compute_all_regularizers or 'Trivial Gradient Norm' in self.config[
                'critic_reg_mode']:
            self.trivial_line_grad_norm_1_penalties = ((
                self.trivial_line_grad_norm - 1.)**2)
            self.mean_trivial_line_grad_norm_1_penalties = tf.reduce_mean(
                self.trivial_line_grad_norm_1_penalties)
            if 'Trivial Gradient Norm' in self.config['critic_reg_mode']:
                print('Adding Trivial Gradient Norm Penalty')
                self.cri_reg_cost += self.mean_trivial_line_grad_norm_1_penalties

        if len(self.config['critic_reg_mode']) > 0:
            self.cri_reg_cost /= len(self.config['critic_reg_mode'])

        #############################################################################

        # OBJECTIVES
        ### Divergence
        if self.config['divergence_mode'] == 'GAN' or self.config[
                'divergence_mode'] == 'NS-GAN':
            self.div_cost = -self.config[
                'enc_reg_strength'] * self.mean_z_divergence

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        # self.mean_OT_primal = helper.tf_print(self.mean_OT_primal, [self.mean_OT_primal])

        self.mean_POT_primal = self.mean_OT_primal  #+self.config['enc_reg_strength']*self.enc_reg_cost
        self.enc_cost = self.mean_POT_primal

        ### Generator
        if self.config['dual_dist_mode'] == 'Coupling':
            self.gen_cost = -self.mean_critic_rec
        elif self.config['dual_dist_mode'] == 'Prior':
            self.gen_cost = -self.mean_critic_gen
        elif self.config['dual_dist_mode'] == 'CouplingAndPrior':
            self.gen_cost = -0.5 * (self.mean_critic_rec +
                                    self.mean_critic_gen)

        ### Critic
        if self.config['dual_dist_mode'] == 'Coupling':
            self.mean_OT_dual = self.mean_critic_real - self.mean_critic_rec
        elif self.config['dual_dist_mode'] == 'Prior':
            self.mean_OT_dual = self.mean_critic_real - self.mean_critic_gen
        elif self.config['dual_dist_mode'] == 'CouplingAndPrior':
            self.mean_OT_dual = self.mean_critic_real - 0.5 * (
                self.mean_critic_rec + self.mean_critic_gen)

        self.cri_cost = -self.mean_OT_dual + self.config[
            'cri_reg_strength'] * self.cri_reg_cost
示例#7
0
	def sample(self):
		"""Return one random point on a sphere of radius ``self.radius`` around each center.

		A direction is drawn per row by normalising an isotropic Gaussian
		sample with the same shape as ``self.centers``. That direction is
		scaled by ``self.radius`` and shifted by ``self.centers``.

		The norm is reduced over axis 1, so ``self.centers`` presumably has
		shape (num_centers, dim) -- TODO confirm. ``self.radius`` must
		broadcast against that shape.
		"""
		gaussian_draw = tf.random.normal(shape=tf.shape(self.centers))
		# Squared Euclidean length of each row, kept 2-D so it broadcasts back.
		squared_length = tf.reduce_sum(gaussian_draw**2, axis=1, keepdims=True)
		unit_direction = gaussian_draw/helper.safe_tf_sqrt(squared_length)
		return self.radius*unit_direction+self.centers
示例#8
0
# Decoder head: two further 500-unit ReLU layers stacked on `layer_1_rec`
# (defined outside this snippet), then a linear read-out layer.
layer_2_rec = tf.layers.dense(inputs=layer_1_rec,
                              units=500,
                              use_bias=True,
                              activation=tf.nn.relu)
layer_3_rec = tf.layers.dense(inputs=layer_2_rec,
                              units=500,
                              use_bias=True,
                              activation=tf.nn.relu)
# Linear output layer. Its width (input_dim * tile_rate) must match the
# last axis of `x_input` for the reconstruction cost below.
# Assumes x_input is (batch, input_dim * tile_rate) -- TODO confirm.
x_rec = tf.layers.dense(inputs=layer_3_rec,
                        units=input_dim * tile_rate,
                        use_bias=True,
                        activation=None)

# Reconstruction cost: batch mean of the per-example Euclidean distance
# between reconstruction and input (sum of squares over axis 1, then sqrt).
rec_cost = tf.reduce_mean(
    helper.safe_tf_sqrt(tf.reduce_sum((x_rec - x_input)**2, axis=1)))
# Latent penalty between encoded codes `z` and prior samples `z_prior`,
# computed by the project helper `compute_MMD`.
MMD = helper.compute_MMD(z, z_prior, positive_only=True)

# Weight schedule for the reconstruction term: hardstep is applied to
# (iter_tf - start) / timescale. With (0, 1) the argument is just iter_tf.
# An alternative setting, (200, 1500), shifts the schedule by 200
# iterations and stretches it by a factor of 1500.
start, timescale = 0, 1
lambda_z_comp = 100 * helper.hardstep(
    (iter_tf - float(start)) / float(timescale))
# Note: despite its name, lambda_z_comp weights the reconstruction cost.
cost = MMD + lambda_z_comp * rec_cost

optimizer = tf.train.AdamOptimizer(learning_rate=0.0001,
                                   beta1=0.9,
                                   beta2=0.999,
                                   epsilon=1e-08)
cost_step = optimizer.minimize(cost)

# `tf.initialize_all_variables` is deprecated (and absent from the TF2
# namespace). `global_variables_initializer` is its drop-in replacement and
# initializes the same set of global variables.
init = tf.global_variables_initializer()