Example #1
0
    def inference(self, batch, additional_inputs_tf):
        """Build the generator / encoder / critic graph and the training objectives.

        Every intermediate tensor is stored on ``self``. The objectives are
        ``self.enc_cost`` (reconstruction distance + MMD latent penalty),
        ``self.disc_cost`` (critic loss with a gradient-norm penalty) and
        ``self.gen_cost``.

        Args:
            batch: dict holding ``batch['observed']['properties']`` (with
                ``'flat'`` and ``'image'`` property lists) and
                ``batch['observed']['data']`` (the observed tensors, keyed the
                same way). NOTE: the ``'dist'`` field of the observed
                properties is overwritten in place with ``'dirac'``.
            additional_inputs_tf: sequence whose entries 0 and 1 are stored as
                ``self.epoch`` and ``self.b_identity``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Treat observations as point masses. This mutates the caller's
        # property dicts, so every ProductDistribution built below from
        # batch['observed']['properties'] sees 'dirac' as well.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)
        # Prefer the 'flat' entries and fall back to 'image' when that lookup
        # fails. `except Exception` (not a bare except) so KeyboardInterrupt /
        # SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        # Insert a length-1 axis after the leading axis: (B, ...) -> (B, 1, ...).
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.obs_sample_param = self.Decoder.forward(self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # A fixed set of 400 standard-normal prior codes, cached on disk so the
        # same codes are reused across runs.
        n_latent = self.prior_latent_code.get_shape().as_list()[-1]
        # np.save appends '.npy' to any filename that does not already end in
        # '.npy'. The previous '.npz' name was therefore written as '*.npz.npy'
        # and the existence check never matched, so a new "constant" sample was
        # drawn on every run. Use a single '.npy' path for the check and the save.
        constant_sample_path = './np_constant_prior_sample_' + str(n_latent) + '.npy'
        if os.path.exists(constant_sample_path):
            np_constant_prior_sample = np.load(constant_sample_path)
        else:
            np_constant_prior_sample = np.random.normal(loc=0., scale=1., size=[400, n_latent])
            np.save(constant_sample_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(
            self.constant_prior_latent_code,
            [-1, 1, *self.constant_prior_latent_code.get_shape().as_list()[1:]])

        self.constant_obs_sample_param = self.Decoder.forward(self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # ENCODER

        b_deterministic = False
        if b_deterministic:
            self.epsilon = None
        else:
            self.epsilon_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
            # Parameterise the noise distribution with its own parameters
            # (epsilon_param was previously computed and then ignored in favour
            # of self.prior_param).
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded = self.EncoderSampler.forward(self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Decoder.forward(self.interpolated_posterior_latent_code)

        self.reconst_param = self.Decoder.forward(self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        # Kernel MMD between posterior codes and prior draws. The two
        # prior_dist.sample() calls below are separate draws.
        self.kernel_function = self.inv_multiquadratics_kernel
        self.k_post_prior = tf.reduce_mean(self.kernel_function(self.posterior_latent_code, self.prior_dist.sample()))
        self.k_post_post = tf.reduce_mean(self.kernel_function(self.posterior_latent_code))
        self.k_prior_prior = tf.reduce_mean(self.kernel_function(self.prior_dist.sample()))
        self.MMD = self.k_prior_prior + self.k_post_post - 2 * self.k_post_prior

        # Latent discriminator terms. They are not used by the objectives below
        # (enc_reg_cost uses the MMD). 10e-7 (== 1e-6) keeps the logs finite.
        self.disc_posterior = self.DivergenceLatent.forward(self.posterior_latent_code_expanded)
        self.disc_prior = self.DivergenceLatent.forward(self.prior_latent_code_expanded)
        self.mean_z_divergence = tf.reduce_mean(tf.log(10e-7 + self.disc_prior)) + tf.reduce_mean(tf.log(10e-7 + 1 - self.disc_posterior))
        self.mean_z_divergence_minimizer = tf.reduce_mean(tf.log(10e-7 + 1 - self.disc_posterior))

        #############################################################################
        # REGULARIZER

        # Per-example random weights used to mix generated and real samples; the
        # critic's gradient penalty is evaluated at these mixed points.
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)),
                              tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        self.uniform_sample = self.uniform_dist.sample()

        self.reg_trivial_sample_param = {'image': None, 'flat': None}
        try:
            mix = self.uniform_sample[:, np.newaxis, :, np.newaxis, np.newaxis]
            self.reg_trivial_sample_param['image'] = mix * self.obs_sample['image'] + \
                                                     (1 - mix) * self.input_sample['image']
        except Exception:
            mix = self.uniform_sample[:, np.newaxis, :]
            self.reg_trivial_sample_param['flat'] = mix * self.obs_sample['flat'] + \
                                                    (1 - mix) * self.input_sample['flat']
        self.reg_trivial_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_trivial_sample_param)
        self.reg_trivial_sample = self.reg_trivial_dist.sample(b_mode=True)

        self.reg_target_dist = self.obs_sample_dist
        self.reg_target_sample = self.obs_sample
        self.reg_dist = self.reg_trivial_dist
        self.reg_sample = self.reg_trivial_sample

        #############################################################################
        # CRITIC

        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg_trivial = self.Discriminator.forward(self.reg_trivial_sample)
        self.critic_reg = self.critic_reg_trivial

        # Norm of the critic's input gradient at the mixed points (image data:
        # reduce over the last three axes; flat data: over the last axis).
        try:
            self.trivial_grad = tf.gradients(self.critic_reg_trivial, [self.reg_trivial_sample['image']])[0]
            self.trivial_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.trivial_grad**2, axis=[-1, -2, -3], keep_dims=False)[:, :, np.newaxis])
        except Exception:
            self.trivial_grad = tf.gradients(self.critic_reg_trivial, [self.reg_trivial_sample['flat']])[0]
            self.trivial_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.trivial_grad**2, axis=[-1], keep_dims=True))

        self.trivial_grad_norm_1_penalties = ((self.trivial_grad_norm - 1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
        self.mean_OT_dual = self.mean_critic_real - self.mean_critic_gen

        #############################################################################
        # OBJECTIVES
        self.enc_reg_strength, self.disc_reg_strength = 100, 10

        self.real_reconst_distances_sq = self.metric_distance_sq(self.input_sample, self.reconst_sample)
        self.OT_primal = tf.reduce_mean(helper.safe_tf_sqrt(self.real_reconst_distances_sq))
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        self.enc_reg_cost = self.MMD  # alternative: self.mean_z_divergence_minimizer
        self.mean_POT_primal = self.mean_OT_primal + self.enc_reg_strength * self.enc_reg_cost

        self.disc_reg_cost = tf.reduce_mean(self.trivial_grad_norm_1_penalties)

        # WGAN-GP
        # self.div_cost = -self.mean_z_divergence
        self.enc_cost = self.mean_POT_primal
        self.disc_cost = -self.mean_OT_dual + self.disc_reg_strength * self.disc_reg_cost
        self.gen_cost = -self.mean_critic_gen
Example #2
0
    def inference(self, batch, additional_inputs_tf):
        """Build a Gaussian-encoder autoencoder graph plus a gradient-penalised critic.

        Sets, among other tensors on ``self``, the objectives
        ``self.discriminator_cost``, ``self.generator_cost`` and
        ``self.transporter_cost`` (mean of squared reconstruction distance / 2
        plus the estimated KL(posterior || prior)).

        Args:
            batch: dict with ``batch['observed']['properties']`` (``'flat'`` /
                ``'image'`` property lists) and ``batch['observed']['data']``.
                NOTE: the ``'dist'`` field of the observed properties is
                overwritten in place with ``'dirac'``.
            additional_inputs_tf: sequence; entries 0 and 1 are stored as
                ``self.epoch`` and ``self.b_identity``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Observations are modelled as point masses (mutates the caller's dicts).
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)
        # Prefer 'flat', fall back to 'image' when that lookup fails. Narrowed
        # from a bare except so KeyboardInterrupt / SystemExit propagate.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Prior p(z) and a sample-based estimate of its negative entropy.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        # Per-example mixing weights for the gradient-penalty points
        # (params are per-example (0, 1) pairs, presumably low/high bounds).
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ], axis=1))
        self.convex_mix = self.uniform_dist.sample()

        # Decode prior codes into generated observations.
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # Encoder: Gaussian posterior q(z|x), then reconstruction.
        b_use_reconst_as_reg = True
        self.posterior_param_expanded = self.Encoder.forward(self.input_sample)
        self.posterior_param = self.posterior_param_expanded[:, 0, :]
        self.posterior_dist = distributions.DiagonalGaussianDistribution(
            params=self.posterior_param)
        self.posterior_latent_code = self.posterior_dist.sample()
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.newaxis, :]

        self.reconst_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        # The gradient penalty is evaluated between real inputs and this target.
        if b_use_reconst_as_reg:
            self.reg_target_dist = self.reconst_dist
            self.reg_target_sample = self.reconst_sample
        else:
            self.reg_target_dist = self.obs_sample_dist
            self.reg_target_sample = self.obs_sample

        self.neg_cross_ent_posterior = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_cross_ent_posterior = tf.reduce_mean(
            self.neg_cross_ent_posterior)
        self.neg_ent_posterior = self.posterior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_ent_posterior = tf.reduce_mean(self.neg_ent_posterior)

        # self.kl_posterior_prior = distributions.KLDivDiagGaussianVsDiagGaussian().forward(self.posterior_dist, self.prior_dist)
        # self.mean_kl_posterior_prior = tf.reduce_mean(self.kl_posterior_prior)

        # Monte-Carlo KL(q || p) = E_q[log q(z)] - E_q[log p(z)] from one sample.
        self.mean_kl_posterior_prior = -self.mean_neg_cross_ent_posterior + self.mean_neg_ent_posterior
        #############################################################################

        # Random convex combinations of the regularisation target and the input.
        self.reg_sample_param = {'image': None, 'flat': None}
        try:
            mix = self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis]
            self.reg_sample_param['image'] = mix * self.reg_target_sample['image'] + \
                                             (1 - mix) * self.input_sample['image']
        except Exception:
            mix = self.convex_mix[:, np.newaxis, :]
            self.reg_sample_param['flat'] = mix * self.reg_target_sample['flat'] + \
                                            (1 - mix) * self.input_sample['flat']
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_sample_param)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg = self.Discriminator.forward(self.reg_sample)

        lambda_t = 1  # weight of the KL term in the autoencoding cost
        self.real_reconst_distances_sq = self.metric_distance_sq(
            self.input_sample, self.reconst_sample)
        self.autoencode_costs = self.real_reconst_distances_sq / 2 + lambda_t * self.mean_kl_posterior_prior

        # Critic input-gradient norm at the mixed points (image: last three
        # axes; flat: last axis).
        try:
            self.convex_grad = tf.gradients(self.critic_reg,
                                            [self.reg_sample['image']])[0]
            self.convex_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.convex_grad**2,
                              axis=[-1, -2, -3],
                              keep_dims=False)[:, :, np.newaxis])
        except Exception:
            self.convex_grad = tf.gradients(self.critic_reg,
                                            [self.reg_sample['flat']])[0]
            self.convex_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.convex_grad**2, axis=[-1], keep_dims=True))
        self.gradient_penalties = ((self.convex_grad_norm - 1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
        self.mean_autoencode_cost = tf.reduce_mean(self.autoencode_costs)
        self.mean_gradient_penalty = tf.reduce_mean(self.gradient_penalties)

        self.regularizer_cost = 10 * self.mean_gradient_penalty
        self.discriminator_cost = -self.mean_critic_real + self.mean_critic_gen + self.regularizer_cost
        self.generator_cost = -self.mean_critic_gen
        self.transporter_cost = self.mean_autoencode_cost
Example #3
0
    def inference(self, batch, additional_inputs_tf):
        """Build a Gaussian-encoder autoencoder graph trained by reconstruction likelihood.

        The objective set here is ``self.generator_cost``, equal to
        ``self.mean_autoencode_cost``: the negative mean log-density of the
        inputs under the reconstruction distribution. The KL term is computed
        (``self.mean_kl_posterior_prior``) but currently excluded from it.

        Args:
            batch: dict with ``batch['observed']['properties']`` (``'flat'`` /
                ``'image'`` property lists) and ``batch['observed']['data']``.
                The properties are deep-copied before being switched to
                ``'dirac'``, so this method does not modify them itself.
            additional_inputs_tf: sequence; entries 0 and 1 are stored as
                ``self.epoch`` and ``self.b_identity``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # A point-mass ('dirac') copy of the observed properties, used wherever
        # samples should be wrapped as deterministic distributions.
        properties_dirac_dict = copy.deepcopy(batch['observed']['properties'])
        if len(properties_dirac_dict['flat']) > 0:
            for e in properties_dirac_dict['flat']:
                e['dist'] = 'dirac'
        else:
            for e in properties_dirac_dict['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=properties_dirac_dict, params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)
        # Prefer 'flat', fall back to 'image' when that lookup fails. Narrowed
        # from a bare except so KeyboardInterrupt / SystemExit propagate.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Prior p(z) and a sample-based estimate of its negative entropy.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ], axis=1))
        self.convex_mix = self.uniform_dist.sample()

        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)

        # self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        # self.obs_sample = self.obs_sample_dist.sample()

        # Draw a (non-mode) sample from the decoder's output distribution, then
        # re-wrap that sample as a point mass.
        self.obs_sample_dist_pre = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist_pre.sample()
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=properties_dirac_dict, params=self.obs_sample)

        #############################################################################
        # Encoder: Gaussian posterior q(z|x), then reconstruction.

        self.posterior_param_expanded = self.Encoder.forward(self.input_sample)
        self.posterior_param = self.posterior_param_expanded[:, 0, :]
        self.posterior_dist = distributions.DiagonalGaussianDistribution(
            params=self.posterior_param)
        self.posterior_latent_code = self.posterior_dist.sample()
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.newaxis, :]

        self.reconst_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)
        self.reconst_sample_noised = self.reconst_dist.sample()

        # Log-density of the real inputs under the reconstruction distribution.
        self.real_log_pdf = self.reconst_dist.log_pdf(self.input_sample)

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=properties_dirac_dict,
            params=self.reconst_sample_noised)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        self.neg_cross_ent_posterior = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_cross_ent_posterior = tf.reduce_mean(
            self.neg_cross_ent_posterior)
        self.neg_ent_posterior = self.posterior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_ent_posterior = tf.reduce_mean(self.neg_ent_posterior)

        # self.kl_posterior_prior = distributions.KLDivDiagGaussianVsDiagGaussian().forward(self.posterior_dist, self.prior_dist)
        # self.mean_kl_posterior_prior = tf.reduce_mean(self.kl_posterior_prior)

        # Monte-Carlo KL(q || p) = E_q[log q(z)] - E_q[log p(z)] from one sample.
        self.mean_kl_posterior_prior = -self.mean_neg_cross_ent_posterior + self.mean_neg_ent_posterior
        self.mean_real_log_pdf = tf.reduce_mean(self.real_log_pdf)
        #############################################################################

        # NOTE(review): the KL term is deliberately left out of the objective
        # (see trailing comment); confirm this is intended before training.
        self.mean_autoencode_cost = -self.mean_real_log_pdf  #+self.mean_kl_posterior_prior
        self.generator_cost = self.mean_autoencode_cost
Example #4
0
    def inference(self, batch, additional_inputs_tf):
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.pre_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.pre_prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.pre_prior_param)

        self.pre_prior_latent_code = self.pre_prior_dist.sample()
        self.pre_prior_latent_code_expanded = self.pre_prior_latent_code[:, np.
                                                                         newaxis, :]
        self.neg_ent_prior = self.pre_prior_dist.log_pdf(
            self.pre_prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.prior_latent_code = self.PriorExpandMap.forward(
            self.pre_prior_latent_code_expanded)
        self.prior_latent_code_expanded = self.prior_latent_code[:,
                                                                 np.newaxis, :]

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        if not os.path.exists(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/'):
            os.makedirs(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        if os.path.exists(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.pre_prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.pre_prior_latent_code.get_shape().as_list()[-1]) +
                '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[
                    400,
                    self.pre_prior_latent_code.get_shape().as_list()[-1]
                ])
            np.save(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.pre_prior_latent_code.get_shape().as_list()[-1]) +
                '.npz', np_constant_prior_sample)

        self.constant_pre_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_pre_prior_latent_code_expanded = self.constant_pre_prior_latent_code[:,
                                                                                           np
                                                                                           .
                                                                                           newaxis, :]

        self.constant_prior_latent_code = self.PriorExpandMap.forward(
            self.constant_pre_prior_latent_code_expanded)
        self.constant_prior_latent_code_expanded = self.constant_prior_latent_code[:,
                                                                                   np
                                                                                   .
                                                                                   newaxis, :]

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_prior_grid_latent_code_expanded = self.constant_prior_grid_latent_code[:,
                                                                                                 np
                                                                                                 .
                                                                                                 newaxis, :]

            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code_expanded)
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)

        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] == 'Gaussian' or self.config[
                'encoder_mode'] == 'UnivApprox' or self.config[
                    'encoder_mode'] == 'UnivApproxNoSpatial' or self.config[
                        'encoder_mode'] == 'UnivApproxSine':
            self.epsilon = self.epsilon_dist.sample()

        self.pre_posterior_latent_code_expanded, _ = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.pre_posterior_latent_code = self.pre_posterior_latent_code_expanded[:,
                                                                                 0, :]

        self.nball_param = tf.concat([
            self.pre_posterior_latent_code,
            0.1 * tf.ones(shape=(self.batch_size_tf, 1))
        ],
                                     axis=1)
        self.nball_dist = distributions.UniformBallDistribution(
            params=self.tiny_perturb_param)
        self.posterior_latent_code = self.nball_dist.sample()
        self.posterior_latent_code_log_pdf = self.nball_dist.log_pdf(
            self.posterior_latent_code)
        # self.posterior_latent_code = self.pre_posterior_latent_code
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.
                                                                         newaxis, :]

        # self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        self.interpolated_pre_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_posterior_latent_code_collapsed = self.PriorExpandMap.forward(
            tf.reshape(self.interpolated_pre_posterior_latent_code, [
                -1, 1,
                self.interpolated_pre_posterior_latent_code.get_shape().
                as_list()[-1]
            ]))
        self.interpolated_posterior_latent_code = tf.reshape(
            self.interpolated_posterior_latent_code_collapsed, [
                -1, *self.interpolated_pre_posterior_latent_code.get_shape().
                as_list()[-2:]
            ])
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.pre_posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        self.enc_reg_cost = helper.compute_MMD(self.pre_prior_latent_code,
                                               self.pre_posterior_latent_code)
        self.div_posterior = self.Diverger.forward(
            self.pre_posterior_latent_code_expanded)
        self.div_prior = self.Diverger.forward(self.prior_latent_code_expanded)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################

        # OBJECTIVES
        # Divergence
        if self.config['divergence_mode'] == 'GAN' or self.config[
                'divergence_mode'] == 'NS-GAN':
            self.div_cost = -(tf.reduce_mean(
                tf.log(tf.nn.sigmoid(self.div_posterior) +
                       10e-7)) + tf.reduce_mean(
                           tf.log(1 - tf.nn.sigmoid(self.div_prior) + 10e-7)))
        if self.config['divergence_mode'] == 'WGAN-GP':
            uniform_dist = distributions.UniformDistribution(
                params=tf.concat([
                    tf.zeros(shape=(self.batch_size_tf, 1)),
                    tf.ones(shape=(self.batch_size_tf, 1))
                ],
                                 axis=1))
            uniform_w = uniform_dist.sample()
            self.trivial_line = uniform_w[:, np.
                                          newaxis, :] * self.pre_posterior_latent_code_expanded + (
                                              1 - uniform_w[:, np.newaxis, :]
                                          ) * self.prior_latent_code_expanded
            self.div_trivial_line = self.Diverger.forward(self.trivial_line)
            self.trivial_line_grad = tf.gradients(
                tf.reduce_sum(self.div_trivial_line), [self.trivial_line])[0]
            self.trivial_line_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.trivial_line_grad**2,
                              axis=-1,
                              keep_dims=False)[:, :, np.newaxis])
            self.trivial_line_grad_norm_1_penalties = ((
                self.trivial_line_grad_norm - 1)**2)
            self.div_reg_cost = tf.reduce_mean(
                self.trivial_line_grad_norm_1_penalties)
            # self.div_cost = -(tf.reduce_mean(self.div_posterior)-tf.reduce_mean(self.div_prior))+10*self.div_reg_cost

        # ### Encoder
        b_use_timer, timescale, starttime = False, 10, 5
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        # if b_use_timer:
        #     self.mean_POT_primal = self.mean_OT_primal+helper.hardstep((self.epoch-float(starttime))/float(timescale))*self.config['enc_reg_strength']*self.enc_reg_cost
        # else:
        #     self.mean_POT_primal = self.mean_OT_primal+self.config['enc_reg_strength']*self.enc_reg_cost
        self.mean_POT_primal = self.mean_OT_primal
        self.enc_cost = self.mean_POT_primal

        # ### Critic
        # # self.cri_cost = helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # if self.config['divergence_mode'] == 'NS-GAN':
        #     self.cri_cost = -tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'GAN':
        #     self.cri_cost = tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'WGAN-GP':
        #     self.cri_cost = -tf.reduce_mean(self.div_prior)+self.config['enc_reg_strength']*helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)

        ### Generator
        self.gen_cost = self.mean_OT_primal
Example #5
0
    def inference(self, batch, additional_inputs_tf):
        """Build the TF graph for the slope-penalised critic / transporter model.

        Wires up prior sampling and decoding, a (stochastic) encoding plan that
        maps observed inputs to latent codes, a per-example convex mix between
        real inputs and their decoded encodings, and a critic evaluated on
        real, generated and mixed samples with a penalty that pushes the
        squared finite-difference slope between real and mixed samples to 1.

        Args:
            batch: Nested dict. Reads ``batch['observed']['properties']``
                (with 'flat' and 'image' lists of per-entry property dicts,
                whose ``'dist'`` entries are overwritten with ``'dirac'``)
                and ``batch['observed']['data']``.
            additional_inputs_tf: Sequence whose first two entries are stored
                as ``self.epoch`` and ``self.b_identity``.

        Side effects:
            Calls ``self.generate_modules(batch)`` when ``self.bModules`` is
            falsy, and stores every intermediate tensor as an attribute,
            ending with ``self.discriminator_cost``, ``self.generator_cost``
            and ``self.transporter_cost``. Returns ``None``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Observed data is treated as exact (Dirac) samples of whichever
        # modality is present.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # Fall back to the image modality when the flat one is unavailable.
        # `except Exception` (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # ---- Prior and generated samples ----
        # PriorMap is fed a (batch, 1) zero tensor; presumably only its batch
        # size matters — TODO confirm.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.prior_latent_code_logpdf = self.prior_dist.log_pdf(
            self.prior_latent_code)
        self.prior_logpdf = tf.reduce_mean(self.prior_latent_code_logpdf)

        # Per-example mixing weights; params are a zeros column and a ones
        # column, presumably the [low, high] bounds of U(0, 1) — TODO confirm.
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ], axis=1))
        self.convex_mix = self.uniform_dist.sample()

        # Insert a singleton axis at position 1: (batch, ...) -> (batch, 1, ...).
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # ---- Encoding plan ----
        b_stochastic_encoder = True
        self.epsilon_param = self.EpsilonMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)
        self.epsilon = self.epsilon_dist.sample()
        if b_stochastic_encoder:
            self.posterior_latent_code_expanded = self.EncodingPlan.forward(
                self.input_sample, epsilon_sample=self.epsilon)
        else:
            self.posterior_latent_code_expanded = self.EncodingPlan.forward(
                self.input_sample)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        # Encoded codes are scored under the *prior* density.
        self.posterior_latent_code_logpdf = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.posterior_logpdf = tf.reduce_mean(
            self.posterior_latent_code_logpdf)
        self.reg_target_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reg_target_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_target_param)
        self.reg_target_sample = self.reg_target_dist.sample(b_mode=True)

        # ---- Mixed (regularisation) samples ----
        # reg = w * decode(encode(x)) + (1 - w) * x, with w broadcast over the
        # image layout if available, otherwise over the flat layout.
        self.reg_sample_param = {'image': None, 'flat': None}
        try:
            self.reg_sample_param['image'] = (
                self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis] *
                self.reg_target_sample['image'] +
                (1 - self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis]) *
                self.input_sample['image'])
        except Exception:
            self.reg_sample_param['flat'] = (
                self.convex_mix[:, np.newaxis, :] * self.reg_target_sample['flat'] +
                (1 - self.convex_mix[:, np.newaxis, :]) * self.input_sample['flat'])
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_sample_param)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        # ---- Critic ----
        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg = self.Discriminator.forward(self.reg_sample)

        # Squared finite-difference slope of the critic between each real
        # sample and its mixed counterpart; 1e-7 guards the division when the
        # two samples coincide.
        self.real_reg_distances_sq = self.metric_distance_sq(
            self.input_sample, self.reg_sample)
        self.real_reg_slopes_sq = ((self.critic_real - self.critic_reg)**
                                   2) / (self.real_reg_distances_sq + 1e-7)
        # Penalty pushing the squared slope towards 1.
        self.slope_penalties = ((self.real_reg_slopes_sq - 1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)

        self.mean_real_reg_slopes_sq = tf.reduce_mean(self.real_reg_slopes_sq)
        self.mean_slope_penalty = tf.reduce_mean(self.slope_penalties)

        # ---- Objectives ----
        self.regularizer_cost = 10 * self.mean_slope_penalty
        self.discriminator_cost = (-self.mean_critic_real + self.mean_critic_gen +
                                   self.regularizer_cost)
        self.generator_cost = -self.mean_critic_gen
        # Minimising this maximises the mean squared slope.
        self.transporter_cost = -self.mean_real_reg_slopes_sq
Example #6
0
    def inference(self, batch, additional_inputs_tf):
        """Build the TF graph for the flow-based prior / ball-posterior model.

        Generator side: samples from a diagonal Beta pre-prior are mapped
        through the inverse of a RealNVP flow stack and decoded by
        ``self.Generator``. Fixed samples (persisted on disk) and, for 2-D
        latents, a 20x20 grid are also decoded for visualisation.

        Encoder side: the encoder output is squashed into (0.1, 0.9). Uniform
        ball and uniform hollow-ball samples are drawn around it, pushed
        forward through the flow, and compared against diagonal Beta
        distributions via per-sample KL estimates averaged over the batch.
        Reconstructions are scored with ``self.sample_distance_function``.

        Args:
            batch: Nested dict. Reads ``batch['observed']['properties']``
                (with 'flat' and 'image' lists of per-entry property dicts,
                whose ``'dist'`` entries are overwritten with ``'dirac'``)
                and ``batch['observed']['data']``.
            additional_inputs_tf: Sequence whose first two entries are stored
                as ``self.epoch`` and ``self.b_identity``.

        Side effects:
            Calls ``self.generate_modules(batch)`` when ``self.bModules`` is
            falsy. May create ``~/ExperimentalResults/FixedSamples/`` and a
            fixed-sample file there. Stores all intermediate tensors as
            attributes, including ``self.div_cost``, ``self.enc_cost`` and
            ``self.gen_cost``. Returns ``None``.

        Raises:
            ValueError: if ``self.config['encoder_mode']`` is not one of
                'Deterministic', 'Gaussian', 'UnivApprox',
                'UnivApproxNoSpatial' or 'UnivApproxSine'.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Observed data is treated as exact (Dirac) samples of whichever
        # modality is present.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # Fall back to the image modality when the flat one is unavailable.
        # `except Exception` (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.pre_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.pre_prior_dist = distributions.DiagonalBetaDistribution(
            params=self.pre_prior_param)
        self.pre_prior_latent_code = self.pre_prior_dist.sample()
        self.pre_prior_latent_code_log_pdf = self.pre_prior_dist.log_pdf(
            self.pre_prior_latent_code)

        self.ambient_prior_param = self.PriorMapBetaInverted.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.ambient_prior_dist = distributions.DiagonalBetaDistribution(
            params=self.ambient_prior_param)

        # Flow: leave the open unit interval, four RealNVP couplings separated
        # by dimension reorderings, then map back into the open unit interval.
        n_latent = self.config['n_latent']
        self.flow_param_list = self.FlowMap.forward()
        self.flow_object = transforms.SerialFlow([
            transforms.InverseOpenIntervalDimensionFlow(input_dim=n_latent),
            transforms.RealNVPFlow(input_dim=n_latent, parameters=self.flow_param_list[0]),
            transforms.SpecificOrderDimensionFlow(input_dim=n_latent),
            transforms.RealNVPFlow(input_dim=n_latent, parameters=self.flow_param_list[1]),
            transforms.SpecificOrderDimensionFlow(input_dim=n_latent),
            transforms.RealNVPFlow(input_dim=n_latent, parameters=self.flow_param_list[2]),
            transforms.SpecificOrderDimensionFlow(input_dim=n_latent),
            transforms.RealNVPFlow(input_dim=n_latent, parameters=self.flow_param_list[3]),
            transforms.OpenIntervalDimensionFlow(input_dim=n_latent),
        ])
        self.gen_flow_object = self.flow_object
        self.prior_latent_code, self.prior_latent_code_log_pdf = self.flow_object.inverse_transform(
            self.pre_prior_latent_code, self.pre_prior_latent_code_log_pdf)

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # ---- Fixed pre-prior samples, persisted so they match across runs ----
        fixed_samples_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        os.makedirs(fixed_samples_dir, exist_ok=True)
        n_pre_prior_dim = self.pre_prior_latent_code.get_shape().as_list()[-1]
        fixed_sample_path = (fixed_samples_dir +
                             'np_constant_uniform01_prior_sample_' +
                             str(n_pre_prior_dim) + '.npz')
        # np.save() appends '.npy' to a path string that lacks it, so saving
        # to fixed_sample_path directly writes '<name>.npz.npy' and the
        # existence check below never succeeds (the sample would be redrawn
        # every run). Accept such a legacy file, and save through a file
        # object so the exact checked name is written. np.load() detects the
        # .npy format from the file header, not the extension.
        legacy_fixed_sample_path = fixed_sample_path + '.npy'
        if os.path.exists(fixed_sample_path):
            np_constant_prior_sample = np.load(fixed_sample_path)
        elif os.path.exists(legacy_fixed_sample_path):
            np_constant_prior_sample = np.load(legacy_fixed_sample_path)
        else:
            np_constant_prior_sample = np.random.uniform(
                low=0., high=1., size=[400, n_pre_prior_dim])
            with open(fixed_sample_path, 'wb') as f:
                np.save(f, np_constant_prior_sample)

        self.constant_pre_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        # The log-pdf input needs one row per fixed code, not one per batch
        # element; presumably the flow combines it element-wise with its
        # log-det terms — TODO confirm.
        n_constant_samples = np.asarray(np_constant_prior_sample).shape[0]
        self.constant_prior_latent_code, _ = self.flow_object.inverse_transform(
            self.constant_pre_prior_latent_code,
            tf.zeros(shape=(n_constant_samples, 1)))

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

        if self.config['n_latent'] == 2:
            # 20x20 grid over [0, 1]^2. Rows run from y=1 down to y=0, and
            # x increases within each row. The grid is fed to the Generator
            # directly, without the flow.
            x = np.linspace(0, 1, 20)
            y = np.linspace(1, 0, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]), axis=1)
            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)

            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        # NOTE(review): the epsilon parameters come from PriorMap, the same map
        # that parameterises the Beta pre-prior above, yet they are fed to a
        # Gaussian. Confirm this is intended rather than a dedicated epsilon map.
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)

        stochastic_encoder_modes = ('Gaussian', 'UnivApprox',
                                    'UnivApproxNoSpatial', 'UnivApproxSine')
        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        elif self.config['encoder_mode'] in stochastic_encoder_modes:
            self.epsilon = self.epsilon_dist.sample()
        else:
            raise ValueError('Unknown encoder_mode: ' +
                             str(self.config['encoder_mode']))

        # sigmoid output lies in (0, 1), so the centre lies in (0.1, 0.9).
        # Presumably this keeps the radius-0.05 ball inside the open unit
        # cube expected by the flow — TODO confirm.
        self.pre_posterior_latent_code = 0.8 * tf.nn.sigmoid(
            self.Encoder.forward(self.input_sample, noise=self.epsilon))[:, 0, :] + 0.1

        # NOTE(review): -log(50000) is an unexplained constant offset on the
        # log-densities below. It does not affect gradients; presumably it
        # only shifts the reported KL values.
        self.nball_param = tf.concat(
            [self.pre_posterior_latent_code,
             0.05 * tf.ones(shape=(self.batch_size_tf, 1))], axis=1)
        self.nball_dist = distributions.UniformBallDistribution(
            params=self.nball_param)
        self.posterior_latent_code = self.nball_dist.sample()
        self.posterior_latent_code_log_pdf = -np.log(50000) + self.nball_dist.log_pdf(
            self.posterior_latent_code)
        self.transformed_posterior_latent_code, self.transformed_posterior_latent_code_log_pdf = self.flow_object.transform(
            self.posterior_latent_code, self.posterior_latent_code_log_pdf)

        self.hollow_nball_param = tf.concat(
            [self.pre_posterior_latent_code,
             0.05 * tf.ones(shape=(self.batch_size_tf, 1)),
             0.1 * tf.ones(shape=(self.batch_size_tf, 1))], axis=1)
        self.hollow_nball_dist = distributions.UniformHollowBallDistribution(
            params=self.hollow_nball_param)
        self.hollow_posterior_latent_code = self.hollow_nball_dist.sample()
        self.hollow_posterior_latent_code_log_pdf = -np.log(50000) + self.hollow_nball_dist.log_pdf(
            self.hollow_posterior_latent_code)
        self.transformed_hollow_posterior_latent_code, self.transformed_hollow_posterior_latent_code_log_pdf = self.flow_object.transform(
            self.hollow_posterior_latent_code, self.hollow_posterior_latent_code_log_pdf)

        self.ambient_param = self.AmbientMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.ambient_dist = distributions.UniformDistribution(params=self.ambient_param)
        self.ambient_latent_code = self.ambient_dist.sample()
        self.ambient_latent_code_log_pdf = self.ambient_dist.log_pdf(
            self.ambient_latent_code)
        self.transformed_ambient_latent_code, self.transformed_ambient_latent_code_log_pdf = self.flow_object.transform(
            self.ambient_latent_code, self.ambient_latent_code_log_pdf)

        # Per-sample log q(z) - log p(z) at the transformed samples,
        # then averaged over the batch.
        self.KL_transformed_prior_per = (
            self.transformed_posterior_latent_code_log_pdf -
            self.pre_prior_dist.log_pdf(self.transformed_posterior_latent_code))
        self.KL_transformed_hollow_prior_per = (
            self.transformed_hollow_posterior_latent_code_log_pdf -
            self.ambient_prior_dist.log_pdf(self.transformed_hollow_posterior_latent_code))
        self.KL_transformed_ambient_prior_per = (
            self.transformed_ambient_latent_code_log_pdf -
            self.ambient_prior_dist.log_pdf(self.transformed_ambient_latent_code))
        self.KL_transformed_prior = tf.reduce_mean(self.KL_transformed_prior_per)
        self.KL_transformed_hollow_prior = tf.reduce_mean(self.KL_transformed_hollow_prior_per)
        self.KL_transformed_ambient_prior = tf.reduce_mean(self.KL_transformed_ambient_prior_per)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        #############################################################################
        # REGULARIZER (aliases of the reconstruction)

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        # Divergence: ball and hollow-ball KL terms. The ambient KL is
        # computed above but not included.
        self.div_cost = (self.KL_transformed_prior + self.KL_transformed_hollow_prior)

        # Encoder: mean reconstruction distance.
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        self.enc_cost = self.mean_OT_primal

        # Generator
        self.gen_cost = self.mean_OT_primal
Example #7
0
    def inference(self, batch, additional_inputs_tf):
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        if not os.path.exists(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/'):
            os.makedirs(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        if os.path.exists(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[400,
                      self.prior_latent_code.get_shape().as_list()[-1]])
            np.save(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz',
                np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ],
                             axis=1))

        if self.config['encoder_mode'] == 'Deterministic':
            pdb.set_trace()
        if self.config['encoder_mode'] == 'Gaussian' or self.config[
                'encoder_mode'] == 'UnivApprox' or self.config[
                    'encoder_mode'] == 'UnivApproxNoSpatial' or self.config[
                        'encoder_mode'] == 'UnivApproxSine':
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            # self.epsilon_dist = distributions.BernoulliDistribution(params = self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded, self.posterior_latent_code_det_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,
                                                                         0, :]
        self.posterior_latent_code_det = self.posterior_latent_code_det_expanded[:,
                                                                                 0, :]

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        self.enc_cost = self.mean_OT_primal

        ### Generator
        self.gen_cost = self.mean_OT_primal
Example #8
0
    def inference(self, batch, additional_inputs_tf):
        """Build the generator, encoder and objective ops for one batch.

        Every tensor and distribution is stored as an attribute of ``self``
        (``obs_sample``, ``reconst_sample``, ``MMD``, ``info_cost``,
        ``enc_cost``, ``gen_cost``, ``cri_cost``, ...); nothing is returned.

        Args:
            batch: Nested dict. ``batch['observed']['properties']`` holds
                ``'flat'`` and ``'image'`` lists of per-variable property
                dicts; ``batch['observed']['data']`` holds the input data
                keyed by ``'flat'`` / ``'image'``.
            additional_inputs_tf: Sequence whose item 0 is the current epoch
                and item 1 the ``b_identity`` input.

        Side effects:
            Sets ``'dist'`` to ``'dirac'`` on the observed property dicts of
            the populated modality (``batch`` is mutated in place). Creates
            ``~/ExperimentalResults/FixedSamples/`` and, the first time a
            given latent size is seen, saves a 400 x latent-dim
            standard-normal sample there that later calls load instead of
            redrawing.

        Raises:
            NotImplementedError: if ``config['encoder_mode']`` is
                ``'Deterministic'``; the encoder call and the info cost below
                need the noise ``self.epsilon``, which only the stochastic
                encoder modes draw.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Observed variables of the populated modality are modelled as 'dirac'.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)
        # Prefer the 'flat' modality; fall back to 'image' when it is absent.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior codes, persisted per latent size so that every run
        # decodes the same points.
        n_latent_dims = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_samples_dir = Path.home() / 'ExperimentalResults' / 'FixedSamples'
        fixed_samples_dir.mkdir(parents=True, exist_ok=True)
        # np.save appends '.npy' to any filename lacking it, so the existence
        # check, the load and the save must all use the same '.npy' path.
        sample_path = str(fixed_samples_dir /
                          ('np_constant_prior_sample_' + str(n_latent_dims) +
                           '.npy'))
        if os.path.exists(sample_path):
            np_constant_prior_sample = np.load(sample_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, n_latent_dims])
            np.save(sample_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        if self.config['n_latent'] == 2:
            # 20 x 20 regular grid over [-3, 3]^2; y runs top-to-bottom.
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ], axis=1))
        self.uniform_sample = self.uniform_dist.sample()

        encoder_mode = self.config['encoder_mode']
        if encoder_mode == 'Deterministic':
            raise NotImplementedError(
                "encoder_mode 'Deterministic' is not supported by this model: "
                "the encoder and the info cost require the noise epsilon.")
        if encoder_mode in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial',
                            'UnivApproxSine'):
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded, self.posterior_latent_code_det_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        self.posterior_latent_code_det = self.posterior_latent_code_det_expanded[:, 0, :]

        # InfoMap predicts a Gaussian over epsilon from the latent code; its
        # negative mean log-likelihood of the drawn epsilon is the info cost.
        self.info_param = self.InfoMap.forward(self.posterior_latent_code)
        self.info_dist = distributions.DiagonalGaussianDistribution(
            params=self.info_param)
        self.epsilon_info_log_pdf = self.info_dist.log_pdf(self.epsilon)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        self.MMD = helper.compute_MMD(self.posterior_latent_code,
                                      self.prior_dist.sample())
        self.info_cost = -tf.reduce_mean(self.epsilon_info_log_pdf)

        if self.config['divergence_mode'] in ('GAN', 'NS-GAN', 'WGAN-GP'):
            # Diverger outputs on posterior vs. prior codes; no cost built in
            # this method consumes them.
            self.div_posterior = self.Diverger.forward(
                self.posterior_latent_code[:, np.newaxis, :])
            self.div_prior = self.Diverger.forward(
                self.prior_latent_code[:, np.newaxis, :])

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        # Schedule for switching on the encoder regularisers (hardstep below).
        b_use_timer, timescale, starttime = True, 100, 5

        self.enc_reg_cost = self.MMD
        self.cri_reg_cost = self.info_cost

        ### Critic
        self.cri_cost = self.cri_reg_cost

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        enc_reg_terms = (
            self.config['enc_reg_strength'] * self.enc_reg_cost +
            self.config['enc_inv_MMD_strength'] * self.info_cost)
        if b_use_timer:
            # NOTE(review): presumably helper.hardstep switches the
            # regularisers on once the epoch passes `starttime`, scaled by
            # `timescale` -- confirm against helper.
            self.enc_overall_cost = helper.hardstep(
                (self.epoch - float(starttime)) / float(timescale) +
                0.0001) * enc_reg_terms
        else:
            self.enc_overall_cost = enc_reg_terms
        self.enc_cost = self.mean_OT_primal + self.enc_overall_cost

        ### Generator
        self.gen_cost = self.mean_OT_primal
Example #9
0
    def inference(self, batch, additional_inputs_tf):
        """Build generator, encoder, critic and objective ops for one batch.

        Every tensor and distribution is stored as an attribute of ``self``
        (``obs_sample``, ``reconst_sample``, ``critic_*``, ``enc_cost``,
        ``gen_cost``, ``cri_cost``, ...); nothing is returned.

        Args:
            batch: Nested dict. ``batch['observed']['properties']`` holds
                ``'flat'`` and ``'image'`` lists of per-variable property
                dicts; ``batch['observed']['data']`` holds the input data
                keyed by ``'flat'`` / ``'image'``.
            additional_inputs_tf: Sequence whose item 0 is the current epoch
                and item 1 the ``b_identity`` input.

        Side effects:
            Sets ``'dist'`` to ``'dirac'`` on the observed property dicts of
            the populated modality (``batch`` is mutated in place). Creates
            ``./fixed_samples/`` and, the first time a given latent size is
            seen, saves a 400 x latent-dim standard-normal sample there that
            later calls load instead of redrawing.

        Raises:
            ValueError: if ``config['dual_dist_mode']`` is not one of
                ``'Coupling'``, ``'Prior'`` or ``'CouplingAndPrior'``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Observed variables of the populated modality are modelled as 'dirac'.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)
        # Prefer the 'flat' modality; fall back to 'image' when it is absent.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        # Insert a length-1 axis at position 1: [batch, 1, latent...].
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior codes, persisted per latent size so that every run
        # decodes the same points.
        n_latent_dims = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_samples_dir = './fixed_samples/'
        os.makedirs(fixed_samples_dir, exist_ok=True)
        # np.save appends '.npy' to any filename lacking it, so the existence
        # check, the load and the save must all use this one '.npy' path.
        sample_path = os.path.join(
            fixed_samples_dir,
            'np_constant_prior_sample_' + str(n_latent_dims) + '.npy')
        if os.path.exists(sample_path):
            np_constant_prior_sample = np.load(sample_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, n_latent_dims])
            np.save(sample_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(
            self.constant_prior_latent_code, [
                -1, 1,
                *self.constant_prior_latent_code.get_shape().as_list()[1:]
            ])

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        #############################################################################
        # ENCODER

        encoder_mode = self.config['encoder_mode']
        if encoder_mode == 'Deterministic':
            self.epsilon = None
        if encoder_mode in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial',
                            'UnivApproxSine'):
            # Encoder noise is drawn from the distribution PriorMap produces.
            self.epsilon_param = self.PriorMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        divergence_mode = self.config['divergence_mode']
        if divergence_mode == 'MMD':
            self.MMD = self.compute_MMD(self.posterior_latent_code,
                                        self.prior_dist.sample())
            self.enc_reg_cost = self.MMD
        elif divergence_mode == 'INV-MMD':
            # Random unit directions in latent space for the invariance term.
            batch_rand_vectors = tf.random_normal(shape=[
                self.config['enc_inv_MMD_n_trans'], self.config['n_latent']
            ])
            batch_rand_dirs = batch_rand_vectors / helper.safe_tf_sqrt(
                tf.reduce_sum((batch_rand_vectors**2), axis=1, keep_dims=True))
            self.Inv_MMD = self.stable_div(self.compute_MMD,
                                           self.posterior_latent_code,
                                           batch_rand_dirs)
            self.MMD = self.compute_MMD(self.posterior_latent_code,
                                        self.prior_dist.sample())
            self.enc_reg_cost = self.MMD + self.config[
                'enc_inv_MMD_strength'] * self.Inv_MMD
        elif divergence_mode in ('GAN', 'NS-GAN'):
            self.div_posterior = self.Diverger.forward(
                self.posterior_latent_code_expanded)
            self.div_prior = self.Diverger.forward(
                self.prior_latent_code_expanded)
            # NOTE(review): 10e-7 == 1e-6; confirm this (rather than 1e-7) is
            # the intended epsilon inside the logs.
            self.mean_z_divergence = tf.reduce_mean(
                tf.log(10e-7 + self.div_prior)) + tf.reduce_mean(
                    tf.log(10e-7 + 1 - self.div_posterior))
            if divergence_mode == 'NS-GAN':
                self.enc_reg_cost = -tf.reduce_mean(
                    tf.log(10e-7 + self.div_posterior))
            else:
                self.enc_reg_cost = tf.reduce_mean(
                    tf.log(10e-7 + 1 - self.div_posterior))

        #############################################################################
        # REGULARIZER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ], axis=1))

        ### Visualized regularization sample
        self.reg_target_dist = self.reconst_dist
        self.reg_dist = None

        ### Uniform Sample From Geodesic of Trivial Coupling Pairs
        # Each point is w * obs_sample + (1 - w) * input_sample, w ~ U(0, 1).
        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
           'Trivial Lipschitz' in self.config['critic_reg_mode']:
            self.uniform_sample_b = self.uniform_dist.sample()
            self.trivial_line_sample_param = {'image': None, 'flat': None}
            try:
                self.uniform_sample_b_expanded = self.uniform_sample_b[
                    :, np.newaxis, :, np.newaxis, np.newaxis]
                self.trivial_line_sample_param['image'] = (
                    self.uniform_sample_b_expanded * self.obs_sample['image'] +
                    (1 - self.uniform_sample_b_expanded) *
                    self.input_sample['image'])
            except Exception:
                self.uniform_sample_b_expanded = self.uniform_sample_b[
                    :, np.newaxis, :]
                self.trivial_line_sample_param['flat'] = (
                    self.uniform_sample_b_expanded * self.obs_sample['flat'] +
                    (1 - self.uniform_sample_b_expanded) *
                    self.input_sample['flat'])
            self.trivial_line_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.trivial_line_sample_param)
            self.trivial_line_sample = self.trivial_line_dist.sample(
                b_mode=True)
            if self.reg_dist is None:
                self.reg_dist = self.trivial_line_dist

        #############################################################################
        # CRITIC

        self.critic_real = self.Critic.forward(self.input_sample)
        self.critic_gen = self.Critic.forward(self.obs_sample)
        self.critic_rec = self.Critic.forward(self.reconst_sample)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_rec = tf.reduce_mean(self.critic_rec)

        self.mean_critic_reg = None

        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
           'Trivial Lipschitz' in self.config['critic_reg_mode']:
            self.critic_trivial_line = self.Critic.forward(
                self.trivial_line_sample)
            if self.mean_critic_reg is None:
                self.mean_critic_reg = tf.reduce_mean(self.critic_trivial_line)
            try:
                self.trivial_line_grad = tf.gradients(
                    self.critic_trivial_line,
                    [self.trivial_line_sample['image']])[0]
                # NOTE(review): the flat branch uses helper.safe_tf_sqrt; plain
                # tf.sqrt here may give non-finite gradients at zero norm.
                self.trivial_line_grad_norm = tf.sqrt(
                    tf.reduce_sum(tf.square(self.trivial_line_grad),
                                  axis=[-1, -2, -3],
                                  keep_dims=False))[:, np.newaxis, :]
            except Exception:
                self.trivial_line_grad = tf.gradients(
                    self.critic_trivial_line,
                    [self.trivial_line_sample['flat']])[0]
                self.trivial_line_grad_norm = helper.safe_tf_sqrt(
                    tf.reduce_sum(self.trivial_line_grad**2,
                                  axis=[-1],
                                  keep_dims=True))

        self.cri_reg_cost = 0

        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode']:
            # Penalize deviation of the critic's gradient norm from 1.
            self.trivial_line_grad_norm_1_penalties = (
                (self.trivial_line_grad_norm - 1.)**2)
            self.mean_trivial_line_grad_norm_1_penalties = tf.reduce_mean(
                self.trivial_line_grad_norm_1_penalties)
            if 'Trivial Gradient Norm' in self.config['critic_reg_mode']:
                print('Adding Trivial Gradient Norm Penalty')
                self.cri_reg_cost += self.mean_trivial_line_grad_norm_1_penalties

        # Average over the number of configured critic regularizer modes.
        if len(self.config['critic_reg_mode']) > 0:
            self.cri_reg_cost /= len(self.config['critic_reg_mode'])

        #############################################################################

        # OBJECTIVES
        ### Divergence
        if divergence_mode in ('GAN', 'NS-GAN'):
            self.div_cost = -self.config[
                'enc_reg_strength'] * self.mean_z_divergence

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        # enc_reg_cost (computed above) is not added to the encoder cost here.
        self.mean_POT_primal = self.mean_OT_primal
        self.enc_cost = self.mean_POT_primal

        ### Generator and Critic, selected by which distribution is compared
        ### against the real data: reconstructions, prior samples, or both.
        dual_dist_mode = self.config['dual_dist_mode']
        if dual_dist_mode == 'Coupling':
            self.gen_cost = -self.mean_critic_rec
            self.mean_OT_dual = self.mean_critic_real - self.mean_critic_rec
        elif dual_dist_mode == 'Prior':
            self.gen_cost = -self.mean_critic_gen
            self.mean_OT_dual = self.mean_critic_real - self.mean_critic_gen
        elif dual_dist_mode == 'CouplingAndPrior':
            self.gen_cost = -0.5 * (self.mean_critic_rec +
                                    self.mean_critic_gen)
            self.mean_OT_dual = self.mean_critic_real - 0.5 * (
                self.mean_critic_rec + self.mean_critic_gen)
        else:
            raise ValueError(
                "Unknown dual_dist_mode: {!r}; expected 'Coupling', 'Prior' "
                "or 'CouplingAndPrior'.".format(dual_dist_mode))

        self.cri_cost = -self.mean_OT_dual + self.config[
            'cri_reg_strength'] * self.cri_reg_cost