Exemplo n.º 1
0
    def generative_model(self, batch, additional_inputs_tf):
        """Build the generative sampling sub-graph: sample a pre-prior code,
        map it through the inverse flow, and decode an observation sample.

        Args:
            batch: dict with batch['observed']['properties'] (per-modality
                property dicts under 'flat'/'image') and batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.gen_epoch = additional_inputs_tf[0]
        self.gen_b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic ('dirac') for generation.
        # NOTE: this mutates the shared batch['observed']['properties'] dicts.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.gen_input_sample = batch['observed']['data']
        self.gen_input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.gen_input_sample)

        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        # NOTE(review): reads self.input_sample (set by inference()), not
        # self.gen_input_sample -- assumes inference() ran first; confirm.
        try:
            self.gen_batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.gen_batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Pre-prior: Beta distribution parameterized from a zero input.
        self.gen_pre_prior_param = self.PriorMap.forward((tf.zeros(shape=(self.gen_batch_size_tf, 1)),))
        # self.gen_pre_prior_dist = distributions.UniformDistribution(params = self.gen_pre_prior_param)
        self.gen_pre_prior_dist = distributions.DiagonalBetaDistribution(params = self.gen_pre_prior_param)
        self.gen_pre_prior_latent_code = self.gen_pre_prior_dist.sample()
        self.gen_pre_prior_latent_code_log_pdf = self.gen_pre_prior_dist.log_pdf(self.gen_pre_prior_latent_code)
        # NOTE(review): self.gen_flow_object is not created in this method --
        # presumably built elsewhere (e.g. inference()); verify call order.
        self.gen_prior_latent_code, self.gen_prior_latent_code_log_pdf = self.gen_flow_object.inverse_transform(self.gen_pre_prior_latent_code, self.gen_pre_prior_latent_code_log_pdf)

        # Decode: add a singleton time axis before the generator, then wrap the
        # output parameters in a product distribution and take its mode sample.
        self.gen_obs_sample_param = self.Generator.forward(self.gen_prior_latent_code[:, np.newaxis, :])
        self.gen_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.gen_obs_sample_param)
        self.gen_obs_sample = self.gen_obs_sample_dist.sample(b_mode=True)
Exemplo n.º 2
0
    def generative_model(self, batch, additional_inputs_tf):
        """Build the generative sampling sub-graph: sample a Gaussian prior
        code and decode it with the Decoder into an observation sample.

        Args:
            batch: dict with batch['observed']['properties'] and
                batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.gen_epoch = additional_inputs_tf[0]
        self.gen_b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic ('dirac') for generation.
        # NOTE: this mutates the shared batch['observed']['properties'] dicts.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.gen_input_sample = batch['observed']['data']
        self.gen_input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.gen_input_sample)

        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        # NOTE(review): reads self.input_sample (set by inference()), not
        # self.gen_input_sample -- assumes inference() ran first; confirm.
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        self.gen_prior_param = self.PriorMap.forward((tf.zeros(shape=(batch_size_tf, 1)),))
        self.gen_prior_dist = distributions.DiagonalGaussianDistribution(params = self.gen_prior_param)
        self.gen_prior_latent_code = self.gen_prior_dist.sample()
        # Fixed: was self.prior_dist (an attribute set only by inference()),
        # which evaluated the log-pdf under a different graph node than the
        # distribution the code was just sampled from.
        self.gen_neg_ent_prior = self.gen_prior_dist.log_pdf(self.gen_prior_latent_code)
        self.gen_mean_neg_ent_prior = tf.reduce_mean(self.gen_neg_ent_prior)

        # Insert a singleton time axis ahead of the latent dimensions.
        self.gen_prior_latent_code_expanded = tf.reshape(self.gen_prior_latent_code, [-1, 1, *self.gen_prior_latent_code.get_shape().as_list()[1:]])
        self.gen_obs_sample_param = self.Decoder.forward(self.gen_prior_latent_code_expanded)
        self.gen_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.gen_obs_sample_param)
        self.gen_obs_sample = self.gen_obs_sample_dist.sample(b_mode=True)
Exemplo n.º 3
0
    def generative_model(self, batch, additional_inputs_tf):
        """Build the generative sampling sub-graph: sample a Gaussian prior
        code, compute Generator features, and decode with PostGen.

        Args:
            batch: dict with batch['observed']['properties'] and
                batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.gen_epoch = additional_inputs_tf[0]
        self.gen_b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic ('dirac') for generation.
        # NOTE: this mutates the shared batch['observed']['properties'] dicts.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.gen_input_sample = batch['observed']['data']
        self.gen_input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.gen_input_sample)

        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        # NOTE(review): reads self.input_sample (set by inference()), not
        # self.gen_input_sample -- assumes inference() ran first; confirm.
        try:
            self.gen_batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.gen_batch_size_tf = tf.shape(self.input_sample['image'])[0]

        self.gen_prior_param = self.PriorMap.forward((tf.zeros(shape=(self.gen_batch_size_tf, 1)),))
        self.gen_prior_dist = distributions.DiagonalGaussianDistribution(params = self.gen_prior_param)
        self.gen_prior_latent_code = self.gen_prior_dist.sample()
        # Singleton time axis for the sequence-shaped Generator input.
        self.gen_prior_latent_code_expanded = self.gen_prior_latent_code[:, np.newaxis, :]

        # Two-stage decoding: Generator produces features, PostGen maps the
        # (still time-expanded) features to observation parameters.
        self.gen_prior_feature_expanded = self.Generator.forward(self.gen_prior_latent_code_expanded)
        self.gen_prior_feature = self.gen_prior_feature_expanded[:, 0, :]
        self.gen_obs_sample_param = self.PostGen.forward(self.gen_prior_feature_expanded)
        self.gen_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.gen_obs_sample_param)
        self.gen_obs_sample = self.gen_obs_sample_dist.sample(b_mode=True)
Exemplo n.º 4
0
    def generative_model(self, batch, additional_inputs_tf):
        """Build the generative sampling sub-graph: sample a Gaussian
        pre-prior code, rotate it through a Householder flow, and decode.

        Args:
            batch: dict with batch['observed']['properties'] and
                batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.gen_epoch = additional_inputs_tf[0]
        self.gen_b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic ('dirac') for generation.
        # NOTE: this mutates the shared batch['observed']['properties'] dicts.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.gen_input_sample = batch['observed']['data']
        self.gen_input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.gen_input_sample)

        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        # NOTE(review): reads self.input_sample (set by inference()), not
        # self.gen_input_sample -- assumes inference() ran first; confirm.
        try:
            self.gen_batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.gen_batch_size_tf = tf.shape(self.input_sample['image'])[0]

        self.gen_pre_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.gen_batch_size_tf, 1)), ))
        self.gen_pre_prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.gen_pre_prior_param)
        # Optional masking templates (disabled), kept for reference:
        # self.gen_pre_prior_latent_template = tf.concat([tf.zeros(shape=[1, self.config['n_latent']-48]), tf.ones(shape=[1, 48])], axis=1)
        # self.gen_pre_prior_latent_template = tf.concat([tf.ones(shape=[1, 48]), tf.zeros(shape=[1, self.config['n_latent']-48])], axis=1)

        self.gen_pre_prior_latent_code = self.gen_pre_prior_dist.sample()
        # (To re-enable masking: multiply by self.gen_pre_prior_latent_template.)
        self.gen_pre_prior_latent_code_expanded = self.gen_pre_prior_latent_code[:, np.newaxis, :]

        # Rotate the pre-prior code with a parameterized Householder flow;
        # the log-pdf delta is discarded (zeros passed as the base log-pdf).
        self.gen_flow_param = self.FlowMap.forward()
        self.gen_flow_object = transforms.HouseholdRotationFlow(
            self.gen_flow_param, self.config['n_latent'])
        self.gen_prior_latent_code, _ = self.gen_flow_object.inverse_transform(
            self.gen_pre_prior_latent_code,
            tf.zeros(shape=(self.gen_batch_size_tf, 1)))
        self.gen_prior_latent_code_expanded = self.gen_prior_latent_code[:, np.newaxis, :]

        self.gen_obs_sample_param = self.Generator.forward(
            self.gen_prior_latent_code_expanded)
        self.gen_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.gen_obs_sample_param)
        self.gen_obs_sample = self.gen_obs_sample_dist.sample(b_mode=True)
Exemplo n.º 5
0
    def generative_model(self, batch, additional_inputs_tf):
        """Build the generative sampling sub-graph: sample a Gaussian prior
        code and decode it with the Generator into an observation sample.

        Unlike the sibling variants, this one deep-copies the property dicts
        before forcing 'dirac', so the caller's batch is left untouched.

        Args:
            batch: dict with batch['observed']['properties'] and
                batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.gen_epoch = additional_inputs_tf[0]
        self.gen_b_identity = additional_inputs_tf[1]

        # Copy the properties so marking them 'dirac' does not leak into the
        # shared batch dict; both modalities are marked unconditionally.
        empirical_observed_properties = copy.deepcopy(
            batch['observed']['properties'])
        for e in empirical_observed_properties['flat']:
            e['dist'] = 'dirac'
        for e in empirical_observed_properties['image']:
            e['dist'] = 'dirac'

        self.gen_input_sample = batch['observed']['data']
        self.gen_input_dist = distributions.ProductDistribution(
            sample_properties=empirical_observed_properties,
            params=self.gen_input_sample)

        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        # NOTE(review): reads self.input_sample (set by inference()), not
        # self.gen_input_sample -- assumes inference() ran first; confirm.
        try:
            self.gen_batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.gen_batch_size_tf = tf.shape(self.input_sample['image'])[0]

        self.gen_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.gen_batch_size_tf, 1)), ))
        self.gen_prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.gen_prior_param)
        self.gen_prior_latent_code = self.gen_prior_dist.sample()

        # Singleton time axis before decoding; note this variant samples from
        # the observation distribution (no b_mode flag), unlike its siblings.
        self.gen_obs_sample_param = self.Generator.forward(
            self.gen_prior_latent_code[:, np.newaxis, :])
        self.gen_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.gen_obs_sample_param)
        self.gen_obs_sample = self.gen_obs_sample_dist.sample()
Exemplo n.º 6
0
    def inference(self, batch, additional_inputs_tf):
        """Build the full inference graph: flow-based generator/sampling path,
        encoder posterior path, regularizer wiring, and training objectives.

        Args:
            batch: dict with batch['observed']['properties'] and
                batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic ('dirac').
        # NOTE: this mutates the shared batch['observed']['properties'] dicts.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR
        self.flow_param_list = self.FlowMap.forward(batch)
        self.wolf_param_list = self.WolfMap.forward(batch)
        # Flattened per-timestep image size (product of dims after batch/time).
        n_output = np.prod(
            batch['observed']['properties']['image'][0]['size'][2:])

        # Euclidean_flow_class = transforms.PiecewisePlanarScalingFlow
        # Euclidean_flow_class = transforms.RealNVPFlow
        Euclidean_flow_class = transforms.NonLinearIARFlow
        # Pre-flow: IAR flows interleaved with fixed rotations (latent -> latent).
        self.pre_flow_object = transforms.SerialFlow([\
                                                        Euclidean_flow_class(input_dim=self.config['n_latent'], parameters=self.wolf_param_list[0]),
                                                        transforms.SpecificRotationFlow(input_dim=self.config['n_latent']),
                                                        Euclidean_flow_class(input_dim=self.config['n_latent'], parameters=self.wolf_param_list[1]),
                                                        transforms.SpecificRotationFlow(input_dim=self.config['n_latent']),
                                                        Euclidean_flow_class(input_dim=self.config['n_latent'], parameters=self.wolf_param_list[2]),
                                                     ])

        # Main flow: latent-space flows followed by a Riemannian flow that maps
        # latent dim -> n_output (image dim), then a compound rotation.
        self.flow_object = transforms.SerialFlow([\
                                                    Euclidean_flow_class(input_dim=self.config['n_latent'], parameters=self.flow_param_list[0]),
                                                    transforms.SpecificRotationFlow(input_dim=self.config['n_latent']),
                                                    Euclidean_flow_class(input_dim=self.config['n_latent'], parameters=self.flow_param_list[1]),
                                                    transforms.RiemannianFlow(input_dim=self.config['n_latent'], output_dim=n_output, n_input_NOM=self.config['rnf_prop']['n_input_NOM'], n_output_NOM=self.config['rnf_prop']['n_output_NOM'], parameters=self.flow_param_list[-2]),
                                                    transforms.CompoundRotationFlow(input_dim=n_output, parameters=self.flow_param_list[-1]),
                                                 ])

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()

        # Sampling path: prior code -> inverse pre-flow -> forward main flow.
        self.pre_prior_latent_code, _ = self.pre_flow_object.inverse_transform(
            self.prior_latent_code, tf.zeros(shape=(self.batch_size_tf, 1)))
        self.transformed_prior_latent_code, _ = self.flow_object.transform(
            self.pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))

        self.obs_sample_param = {
            'flat':
            None,
            'image':
            tf.reshape(self.transformed_prior_latent_code, [
                -1, 1, *batch['observed']['properties']['image'][0]['size'][2:]
            ])
        }
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior samples cached on disk so visualizations are comparable
        # across runs; keyed by latent dimensionality.
        if not os.path.exists(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/'):
            os.makedirs(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        if os.path.exists(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[400,
                      self.prior_latent_code.get_shape().as_list()[-1]])
            # Write through a file handle: np.save on a string path appends
            # '.npy' when the name does not end in '.npy', so the '.npz' file
            # checked above was never actually created and the cache never hit.
            with open(
                    str(Path.home()) +
                    '/ExperimentalResults/FixedSamples/np_constant_prior_sample_'
                    + str(self.prior_latent_code.get_shape().as_list()[-1]) +
                    '.npz', 'wb') as f:
                np.save(f, np_constant_prior_sample)
        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)

        # NOTE(review): the zeros log-pdf here is shaped (batch_size_tf, 1)
        # while the constant code has 400 rows -- assumes the flows ignore or
        # broadcast the log-pdf argument; confirm.
        self.constant_pre_prior_latent_code, _ = self.pre_flow_object.inverse_transform(
            self.constant_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))
        self.constant_transformed_prior_latent_code, _ = self.flow_object.transform(
            self.constant_pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))

        self.constant_obs_sample_param = {
            'flat':
            None,
            'image':
            tf.reshape(self.constant_transformed_prior_latent_code, [
                -1, 1, *batch['observed']['properties']['image'][0]['size'][2:]
            ])
        }
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # if self.config['n_latent'] == 2:
        #     grid_scale = 3
        #     x = np.linspace(-grid_scale, grid_scale, 20)
        #     y = np.linspace(grid_scale, -grid_scale, 20)
        #     xv, yv = np.meshgrid(x, y)
        #     np_constant_prior_grid_sample = np.concatenate((xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]), axis=1)
        #     self.constant_grid_prior_latent_code = tf.constant(np.asarray(np_constant_prior_grid_sample), dtype=np.float32)

        #     self.constant_grid_pre_prior_latent_code, _ = self.pre_flow_object.inverse_transform(self.constant_grid_prior_latent_code, tf.zeros(shape=(self.batch_size_tf, 1)))
        #     self.constant_grid_transformed_prior_latent_code, _ = self.flow_object.transform(self.constant_grid_pre_prior_latent_code, tf.zeros(shape=(self.batch_size_tf, 1)))

        #     self.constant_grid_obs_sample_param = {'flat': None, 'image': tf.reshape(self.constant_grid_transformed_prior_latent_code, [-1, 1, *batch['observed']['properties']['image'][0]['size'][2:]])}
        #     self.constant_grid_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_grid_obs_sample_param)
        #     self.constant_grid_obs_sample = self.constant_grid_obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # ENCODER
        # Stochastic encoders get a Gaussian noise input; deterministic gets None.
        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] == 'Gaussian' or self.config[
                'encoder_mode'] == 'GaussianLeastVariance' or self.config[
                    'encoder_mode'] == 'UnivApprox' or 'UnivApproxNoSpatial' in self.config[
                        'encoder_mode'] or self.config[
                            'encoder_mode'] == 'UnivApproxSine':
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.pre_posterior_latent_code_expanded, self.pre_posterior_latent_code_det_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        # Drop the singleton time axis.
        self.pre_posterior_latent_code = self.pre_posterior_latent_code_expanded[:,
                                                                                 0, :]

        # Posterior path: pre-posterior -> pre-flow -> prior log-pdf; the
        # pre-posterior log-pdf follows by change of variables.
        self.posterior_latent_code, self.posterior_delta_log_pdf = self.pre_flow_object.transform(
            self.pre_posterior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))
        self.posterior_log_pdf = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.pre_posterior_log_pdf = self.posterior_log_pdf - self.posterior_delta_log_pdf

        # self.pre_posterior_latent_code, self.pre_posterior_log_pdf = self.pre_flow_object.inverse_transform(self.posterior_latent_code, self.posterior_log_pdf)
        self.transformed_pre_posterior_latent_code, self.transformed_pre_posterior_log_pdf = self.flow_object.transform(
            self.pre_posterior_latent_code, self.pre_posterior_log_pdf)

        self.reconst_param = {
            'flat':
            None,
            'image':
            tf.reshape(self.transformed_pre_posterior_latent_code, [
                -1, 1, *batch['observed']['properties']['image'][0]['size'][2:]
            ])
        }
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        # self.interpolated_pre_posterior_latent_code, _ = self.pre_flow_object.inverse_transform(tf.reshape(self.interpolated_posterior_latent_code, [-1, self.interpolated_posterior_latent_code.get_shape().as_list()[-1]]), tf.zeros(shape=(self.batch_size_tf, 1)))
        # self.interpolated_transformed_posterior_latent_code, _ = self.flow_object.transform(self.interpolated_pre_posterior_latent_code, tf.zeros(shape=(self.batch_size_tf, 1)))
        # self.interpolated_obs = {'flat': None, 'image': tf.reshape(self.interpolated_transformed_posterior_latent_code, [-1, 10, *batch['observed']['properties']['image'][0]['size'][2:]])}
        # Placeholder interpolation output: tiles the first half of the batch
        # along the time axis (the decoded interpolation path is disabled above).
        self.interpolated_obs = {
            'flat':
            None,
            'image':
            tf.tile(
                self.input_sample['image'][:self.batch_size_tf //
                                           2, :, :, :, :], [1, 10, 1, 1, 1])
        }

        self.enc_reg_cost = -tf.reduce_mean(
            self.transformed_pre_posterior_log_pdf)
        self.cri_reg_cost = -tf.reduce_mean(self.pre_posterior_log_pdf)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        # Optional warm-up: ramp the encoder regularizer in over
        # config['timers']['0'] epochs (hardstep clamps to [0, 1]).
        if '0' in self.config['timers']:
            lambda_t = helper.hardstep(
                (self.epoch - float(self.config['timers']['0']['start'])) /
                float(self.config['timers']['0']['timescale']) + 1e-5)
            overall_cost = self.mean_OT_primal + lambda_t * self.config[
                'enc_reg_strength'] * self.enc_reg_cost
        else:
            overall_cost = self.mean_OT_primal + self.config[
                'enc_reg_strength'] * self.enc_reg_cost

        # self.cri_cost = self.cri_reg_cost
        # self.cri_cost = self.config['enc_reg_strength']*self.enc_reg_cost
        self.cri_cost = self.enc_reg_cost
        self.enc_cost = overall_cost
        self.gen_cost = overall_cost
Exemplo n.º 7
0
    def inference(self, batch, additional_inputs_tf):
        """Build the full WAE/WGAN-GP inference graph: generator sampling path,
        encoder posterior path, MMD / divergence regularizers, critic, and
        the encoder / discriminator / generator objectives.

        Args:
            batch: dict with batch['observed']['properties'] and
                batch['observed']['data'].
            additional_inputs_tf: sequence of tensors; [0] is the epoch
                counter, [1] an identity/blend flag.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic ('dirac').
        # NOTE: this mutates the shared batch['observed']['properties'] dicts.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        # Probe the 'flat' modality first, fall back to 'image'. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(params = self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.prior_latent_code_expanded = tf.reshape(self.prior_latent_code, [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.obs_sample_param = self.Decoder.forward(self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior samples cached on disk so visualizations are comparable
        # across runs; keyed by latent dimensionality.
        if os.path.exists('./np_constant_prior_sample_'+str(self.prior_latent_code.get_shape().as_list()[-1])+'.npz'):
            np_constant_prior_sample = np.load('./np_constant_prior_sample_'+str(self.prior_latent_code.get_shape().as_list()[-1])+'.npz')
        else:
            np_constant_prior_sample = np.random.normal(loc=0., scale=1., size=[400, self.prior_latent_code.get_shape().as_list()[-1]])
            # Write through a file handle: np.save on a string path appends
            # '.npy' when the name does not end in '.npy', so the '.npz' file
            # checked above was never actually created and the cache never hit.
            with open('./np_constant_prior_sample_'+str(self.prior_latent_code.get_shape().as_list()[-1])+'.npz', 'wb') as f:
                np.save(f, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(self.constant_prior_latent_code, [-1, 1, *self.constant_prior_latent_code.get_shape().as_list()[1:]])

        self.constant_obs_sample_param = self.Decoder.forward(self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # ENCODER

        b_deterministic = False
        if b_deterministic:
            self.epsilon = None
        else:
            self.epsilon_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
            # Fixed: was built from self.prior_param, leaving self.epsilon_param
            # unused. Both come from the same PriorMap.forward call, so the
            # sampled noise is unchanged; this just uses the intended tensor.
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(params = self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded = self.EncoderSampler.forward(self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,0,:]
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        self.interpolated_obs = self.Decoder.forward(self.interpolated_posterior_latent_code)

        self.reconst_param = self.Decoder.forward(self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        # MMD between aggregate posterior and prior (inverse multiquadratics kernel).
        self.kernel_function = self.inv_multiquadratics_kernel
        self.k_post_prior = tf.reduce_mean(self.kernel_function(self.posterior_latent_code, self.prior_dist.sample()))
        self.k_post_post = tf.reduce_mean(self.kernel_function(self.posterior_latent_code))
        self.k_prior_prior = tf.reduce_mean(self.kernel_function(self.prior_dist.sample()))
        self.MMD = self.k_prior_prior+self.k_post_post-2*self.k_post_prior

        # GAN-style latent divergence (unused in the final objectives below).
        # NOTE(review): 10e-7 equals 1e-6 -- possibly intended as 1e-7; confirm.
        self.disc_posterior = self.DivergenceLatent.forward(self.posterior_latent_code_expanded)
        self.disc_prior = self.DivergenceLatent.forward(self.prior_latent_code_expanded)
        self.mean_z_divergence = tf.reduce_mean(tf.log(10e-7+self.disc_prior))+tf.reduce_mean(tf.log(10e-7+1-self.disc_posterior))
        self.mean_z_divergence_minimizer = tf.reduce_mean(tf.log(10e-7+1-self.disc_posterior))

        #############################################################################
        # REGULARIZER

        # Convex mixtures of generated and real samples for the gradient penalty.
        self.uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        self.uniform_sample = self.uniform_dist.sample()

        self.reg_trivial_sample_param = {'image': None, 'flat': None}
        try:
            self.reg_trivial_sample_param['image'] = self.uniform_sample[:, np.newaxis, :, np.newaxis, np.newaxis]*self.obs_sample['image']+\
                                                    (1-self.uniform_sample[:, np.newaxis, :, np.newaxis, np.newaxis])*self.input_sample['image']
        except Exception:
            self.reg_trivial_sample_param['flat'] = self.uniform_sample[:, np.newaxis, :]*self.obs_sample['flat']+\
                                                   (1-self.uniform_sample[:, np.newaxis, :])*self.input_sample['flat']
        self.reg_trivial_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reg_trivial_sample_param)
        self.reg_trivial_sample = self.reg_trivial_dist.sample(b_mode=True)

        self.reg_target_dist = self.obs_sample_dist
        self.reg_target_sample = self.obs_sample
        self.reg_dist = self.reg_trivial_dist
        self.reg_sample = self.reg_trivial_sample

        #############################################################################
        # CRITIC

        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg_trivial = self.Discriminator.forward(self.reg_trivial_sample)
        self.critic_reg = self.critic_reg_trivial

        # Gradient of the critic at the mixture points, for the WGAN-GP
        # unit-gradient-norm penalty.
        try:
            self.trivial_grad = tf.gradients(self.critic_reg_trivial, [self.reg_trivial_sample['image']])[0]
            self.trivial_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_grad**2, axis=[-1,-2,-3], keep_dims=False)[:,:,np.newaxis])
        except Exception:
            self.trivial_grad = tf.gradients(self.critic_reg_trivial, [self.reg_trivial_sample['flat']])[0]
            self.trivial_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_grad**2, axis=[-1,], keep_dims=True))

        self.trivial_grad_norm_1_penalties = ((self.trivial_grad_norm-1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
        self.mean_OT_dual = self.mean_critic_real-self.mean_critic_gen

        #############################################################################
        # OBJECTIVES
        self.enc_reg_strength, self.disc_reg_strength = 100, 10

        self.real_reconst_distances_sq = self.metric_distance_sq(self.input_sample, self.reconst_sample)
        self.OT_primal = tf.reduce_mean(helper.safe_tf_sqrt(self.real_reconst_distances_sq))
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        self.enc_reg_cost = self.MMD # self.mean_z_divergence_minimizer
        self.mean_POT_primal = self.mean_OT_primal+self.enc_reg_strength*self.enc_reg_cost

        self.disc_reg_cost = tf.reduce_mean(self.trivial_grad_norm_1_penalties)

        # WGAN-GP
        # self.div_cost = -self.mean_z_divergence
        self.enc_cost = self.mean_POT_primal
        self.disc_cost = -self.mean_OT_dual+self.disc_reg_strength*self.disc_reg_cost
        self.gen_cost = -self.mean_critic_gen
Exemplo n.º 8
0
    def inference(self, batch, additional_inputs_tf):
        """Build the VAE-style inference graph: prior, decoder samples,
        amortized posterior, reconstruction, and the autoencoding objective.

        Args:
            batch: dict; batch['observed']['properties'] holds the 'flat' and
                'image' property lists, batch['observed']['data'] the inputs.
            additional_inputs_tf: sequence; [0] is the epoch tensor, [1] the
                identity-mode flag tensor.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Deep-copy the property specs and force every observed variable to a
        # dirac (point-mass) distribution, leaving the originals untouched.
        properties_dirac_dict = copy.deepcopy(batch['observed']['properties'])
        if len(properties_dirac_dict['flat']) > 0:
            for e in properties_dirac_dict['flat']:
                e['dist'] = 'dirac'
        else:
            for e in properties_dirac_dict['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=properties_dirac_dict, params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        # Observed data may live under 'flat' or 'image'; fall back from one
        # to the other. NOTE(review): the bare excepts also hide unrelated
        # errors (KeyError vs. TF shape errors are indistinguishable here).
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Diagonal Gaussian prior p(z); PriorMap maps a dummy zero input to
        # the (fixed) prior parameters.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        # Negative entropy estimate: log p(z) evaluated at the prior samples.
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        # Per-example U(0, 1) mixing weights; presumably for convex
        # interpolation elsewhere -- not consumed in this method (verify
        # against callers).
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ],
                             axis=1))
        self.convex_mix = self.uniform_dist.sample()

        # Add a time axis: [batch, latent] -> [batch, 1, latent].
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)

        # self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        # self.obs_sample = self.obs_sample_dist.sample()

        # Sample from the decoder's output distribution, then wrap the sample
        # itself in a dirac product distribution.
        self.obs_sample_dist_pre = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist_pre.sample()
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=properties_dirac_dict, params=self.obs_sample)

        #############################################################################

        # Amortized posterior q(z|x); drop the time axis from encoder output.
        self.posterior_param_expanded = self.Encoder.forward(self.input_sample)
        self.posterior_param = self.posterior_param_expanded[:, 0, :]
        self.posterior_dist = distributions.DiagonalGaussianDistribution(
            params=self.posterior_param)
        self.posterior_latent_code = self.posterior_dist.sample()
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.
                                                                         newaxis, :]

        # Reconstruction: decode the posterior sample. b_mode=True takes the
        # distribution mode; the second sample() adds observation noise.
        self.reconst_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)
        self.reconst_sample_noised = self.reconst_dist.sample()

        # Log-likelihood of the real input under the reconstruction dist.
        self.real_log_pdf = self.reconst_dist.log_pdf(self.input_sample)

        # Regularizer targets reuse the reconstruction; reg_dist wraps the
        # noised reconstruction as a dirac product distribution.
        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=properties_dirac_dict,
            params=self.reconst_sample_noised)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        # Monte-Carlo KL(q(z|x) || p(z)) assembled from the cross-entropy and
        # entropy terms below.
        self.neg_cross_ent_posterior = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_cross_ent_posterior = tf.reduce_mean(
            self.neg_cross_ent_posterior)
        self.neg_ent_posterior = self.posterior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_ent_posterior = tf.reduce_mean(self.neg_ent_posterior)

        # self.kl_posterior_prior = distributions.KLDivDiagGaussianVsDiagGaussian().forward(self.posterior_dist, self.prior_dist)
        # self.mean_kl_posterior_prior = tf.reduce_mean(self.kl_posterior_prior)

        self.mean_kl_posterior_prior = -self.mean_neg_cross_ent_posterior + self.mean_neg_ent_posterior
        self.mean_real_log_pdf = tf.reduce_mean(self.real_log_pdf)
        #############################################################################

        # Pure autoencoding objective; the KL term is currently disabled.
        self.mean_autoencode_cost = -self.mean_real_log_pdf  #+self.mean_kl_posterior_prior
        self.generator_cost = self.mean_autoencode_cost
Exemplo n.º 9
0
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph: Gaussian prior, generator samples,
        fixed/grid samples for visualization, encoder decomposition, and the
        encoder reconstruction objective.

        Args:
            batch: dict; batch['observed']['properties'] holds the 'flat' and
                'image' property lists, batch['observed']['data'] the inputs.
            additional_inputs_tf: sequence; [0] is the epoch tensor, [1] the
                identity-mode flag tensor.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Force every observed variable to a dirac (point-mass) distribution.
        # NOTE: mutates batch['observed']['properties'] in place.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        # Observed data may be 'flat' or 'image'; fall back from one to the
        # other (bare except kept to preserve the original fallback contract).
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        # Diagonal Gaussian prior p(z); PriorMap maps a dummy zero input to
        # the (fixed) prior parameters.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        # Add a time axis: [batch, latent] -> [batch, 1, latent].
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Cache a fixed prior sample on disk so visualizations reuse the same
        # latent codes across runs.
        # BUG FIX: np.save appends '.npy' to names lacking it, so saving to a
        # '.npz' path wrote '<name>.npz.npy' while the existence check and
        # np.load looked for '<name>.npz' -- the cache never hit and a fresh
        # "fixed" sample was drawn every run. Use a '.npy' path throughout.
        fixed_samples_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        if not os.path.exists(fixed_samples_dir):
            os.makedirs(fixed_samples_dir)
        n_latent = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_sample_path = (fixed_samples_dir + 'np_constant_prior_sample_' +
                             str(n_latent) + '.npy')
        if os.path.exists(fixed_sample_path):
            np_constant_prior_sample = np.load(fixed_sample_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, n_latent])
            np.save(fixed_sample_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(
            self.constant_prior_latent_code, [
                -1, 1,
                *self.constant_prior_latent_code.get_shape().as_list()[1:]
            ])

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # For 2-D latents, also decode a fixed 20x20 grid over [-3, 3]^2
        # (y descending so images render top-to-bottom).
        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_prior_grid_latent_code_expanded = tf.reshape(
                self.constant_prior_grid_latent_code, [
                    -1, 1, *self.constant_prior_grid_latent_code.get_shape().
                    as_list()[1:]
                ])

            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code_expanded)
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        # The encoder decomposes the image input; reconstruction_cost is the
        # encoder's own internal reconstruction objective.
        self.decomposition_out, self.deterioration_out, self.decomposition_concat_out, self.deterioration_concat_out, self.reconstruction_cost = self.Encoder.forward(
            self.input_sample['image'], noise=None)

        self.interpolated_obs = {
            'flat':
            None,
            'image':
            tf.concat([self.deterioration_out, self.decomposition_out], axis=1)
        }

        # No amortized posterior in this variant: the "posterior" codes are
        # simply the prior samples.
        self.posterior_latent_code = self.prior_latent_code
        self.posterior_latent_code_expanded = self.prior_latent_code_expanded
        self.interpolated_posterior_latent_code = self.prior_latent_code
        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES
        ### Divergence

        ### Encoder
        self.enc_cost = self.reconstruction_cost
Exemplo n.º 10
0
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph for the flow-based model: a Beta
        pre-prior pushed through a normalizing flow, generator samples, a
        uniform-ball encoder posterior, KL-to-prior divergences, and primal
        optimal-transport (reconstruction) objectives.

        Args:
            batch: dict; batch['observed']['properties'] holds the 'flat' and
                'image' property lists, batch['observed']['data'] the inputs.
            additional_inputs_tf: sequence; [0] is the epoch tensor, [1] the
                identity-mode flag tensor.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Force every observed variable to a dirac (point-mass) distribution.
        # NOTE: mutates batch['observed']['properties'] in place.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        # Observed data may be 'flat' or 'image'; fall back from one to the
        # other (bare except kept to preserve the original fallback contract).
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        # Beta pre-prior on the open unit hypercube; the flow maps it to the
        # model prior.
        self.pre_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.pre_prior_dist = distributions.DiagonalBetaDistribution(
            params=self.pre_prior_param)
        self.pre_prior_latent_code = self.pre_prior_dist.sample()
        self.pre_prior_latent_code_log_pdf = self.pre_prior_dist.log_pdf(
            self.pre_prior_latent_code)

        # Inverted-Beta prior used as the ambient/hollow reference measure.
        self.ambient_prior_param = self.PriorMapBetaInverted.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.ambient_prior_dist = distributions.DiagonalBetaDistribution(
            params=self.ambient_prior_param)

        self.flow_param_list = self.FlowMap.forward()
        # Alternating RealNVP couplings and fixed dimension permutations,
        # sandwiched between maps from/to the open unit interval. (Earlier
        # variants used NonLinearIARFlow or HouseholdRotationFlow layers in
        # place of the permutations.)
        self.flow_object = transforms.SerialFlow([
            transforms.InverseOpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
            transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[0]),
            transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']),
            transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[1]),
            transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']),
            transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[2]),
            transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']),
            transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[3]),
            transforms.OpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
        ])
        self.gen_flow_object = self.flow_object
        self.prior_latent_code, self.prior_latent_code_log_pdf = \
            self.flow_object.inverse_transform(self.pre_prior_latent_code,
                                               self.pre_prior_latent_code_log_pdf)

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Cache a fixed Uniform(0, 1) pre-prior sample so visualizations
        # reuse the same latent codes across runs.
        # BUG FIX: np.save appends '.npy' to names lacking it, so saving to a
        # '.npz' path wrote '<name>.npz.npy' while the existence check and
        # np.load looked for '<name>.npz' -- the cache never hit and a fresh
        # "fixed" sample was drawn every run. Use a '.npy' path throughout.
        fixed_samples_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        if not os.path.exists(fixed_samples_dir):
            os.makedirs(fixed_samples_dir)
        n_latent = self.pre_prior_latent_code.get_shape().as_list()[-1]
        fixed_sample_path = (fixed_samples_dir +
                             'np_constant_uniform01_prior_sample_' +
                             str(n_latent) + '.npy')
        if os.path.exists(fixed_sample_path):
            np_constant_prior_sample = np.load(fixed_sample_path)
        else:
            np_constant_prior_sample = np.random.uniform(
                low=0., high=1., size=[400, n_latent])
            np.save(fixed_sample_path, np_constant_prior_sample)

        self.constant_pre_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code, _ = self.flow_object.inverse_transform(
            self.constant_pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # For 2-D latents, decode a fixed 20x20 grid over [0, 1]^2
        # (y descending so images render top-to-bottom).
        if self.config['n_latent'] == 2:
            x = np.linspace(0, 1, 20)
            y = np.linspace(1, 0, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]),
                axis=1)
            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)

            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)

        # Deterministic encoders get no noise; all stochastic modes share the
        # same Gaussian epsilon.
        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] in ('Gaussian', 'UnivApprox',
                                           'UnivApproxNoSpatial',
                                           'UnivApproxSine'):
            self.epsilon = self.epsilon_dist.sample()

        # Squash the encoder output into [0.1, 0.9] so the uniform ball of
        # radius 0.05 around it stays inside the flow's (0, 1) domain.
        self.pre_posterior_latent_code = 0.8*tf.nn.sigmoid(self.Encoder.forward(self.input_sample, noise=self.epsilon))[:,0,:]+0.1
        self.nball_param = tf.concat([self.pre_posterior_latent_code, 0.05*tf.ones(shape=(self.batch_size_tf, 1))], axis=1)
        self.nball_dist = distributions.UniformBallDistribution(params=self.nball_param)
        self.posterior_latent_code = self.nball_dist.sample()
        # -log(50000): constant normalization applied to the ball log-density
        # (presumably the dataset size -- TODO confirm).
        self.posterior_latent_code_log_pdf = -np.log(50000)+self.nball_dist.log_pdf(self.posterior_latent_code)
        self.transformed_posterior_latent_code, self.transformed_posterior_latent_code_log_pdf = self.flow_object.transform(self.posterior_latent_code, self.posterior_latent_code_log_pdf)

        # Hollow ball (inner radius 0.05, outer 0.1) around the same center.
        self.hollow_nball_param = tf.concat([self.pre_posterior_latent_code, 0.05*tf.ones(shape=(self.batch_size_tf, 1)), 0.1*tf.ones(shape=(self.batch_size_tf, 1))], axis=1)
        self.hollow_nball_dist = distributions.UniformHollowBallDistribution(params=self.hollow_nball_param)
        self.hollow_posterior_latent_code = self.hollow_nball_dist.sample()
        self.hollow_posterior_latent_code_log_pdf = -np.log(50000)+self.hollow_nball_dist.log_pdf(self.hollow_posterior_latent_code)
        self.transformed_hollow_posterior_latent_code, self.transformed_hollow_posterior_latent_code_log_pdf = self.flow_object.transform(self.hollow_posterior_latent_code, self.hollow_posterior_latent_code_log_pdf)

        # Ambient uniform reference pushed through the same flow.
        self.ambient_param = self.AmbientMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.ambient_dist = distributions.UniformDistribution(params=self.ambient_param)
        self.ambient_latent_code = self.ambient_dist.sample()
        self.ambient_latent_code_log_pdf = self.ambient_dist.log_pdf(self.ambient_latent_code)
        self.transformed_ambient_latent_code, self.transformed_ambient_latent_code_log_pdf = self.flow_object.transform(self.ambient_latent_code, self.ambient_latent_code_log_pdf)

        # Per-sample KL estimates between the transformed codes and priors.
        self.KL_transformed_prior_per = self.transformed_posterior_latent_code_log_pdf-self.pre_prior_dist.log_pdf(self.transformed_posterior_latent_code)
        self.KL_transformed_hollow_prior_per = self.transformed_hollow_posterior_latent_code_log_pdf-self.ambient_prior_dist.log_pdf(self.transformed_hollow_posterior_latent_code)
        self.KL_transformed_ambient_prior_per = self.transformed_ambient_latent_code_log_pdf-self.ambient_prior_dist.log_pdf(self.transformed_ambient_latent_code)
        self.KL_transformed_prior = tf.reduce_mean(self.KL_transformed_prior_per)
        self.KL_transformed_hollow_prior = tf.reduce_mean(self.KL_transformed_hollow_prior_per)
        self.KL_transformed_ambient_prior = tf.reduce_mean(self.KL_transformed_ambient_prior_per)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        self.interpolated_obs = self.Generator.forward(self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        # Divergence: match the flowed posterior(s) to the priors. The
        # ambient-prior KL term is intentionally excluded here.
        self.div_cost = (self.KL_transformed_prior+self.KL_transformed_hollow_prior)

        # Encoder / generator: primal optimal-transport reconstruction cost.
        self.OT_primal = self.sample_distance_function(self.input_sample, self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        self.enc_cost = self.mean_OT_primal
        self.gen_cost = self.mean_OT_primal
# ---- Exemplo n.º 11 (0) ----
    def inference(self, batch, additional_inputs_tf):
        """Build the basic OT/WAE inference graph (no latent regularizer).

        Wires the prior -> generator sampling path, a disk-cached set of
        fixed prior samples (plus a fixed 2-D grid when n_latent == 2) for
        reproducible visualization, the encoder posterior and its
        reconstruction, and the plain OT-primal encoder/generator costs.

        Args:
            batch: dict; batch['observed']['properties'] holds per-variable
                metadata lists for the 'flat' and 'image' groups, and
                batch['observed']['data'] holds the observed tensors.
            additional_inputs_tf: sequence whose element 0 is the epoch
                tensor and element 1 the b_identity flag tensor.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # All observed variables are modeled as deterministic (dirac).
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # The observed modality is either 'flat' or 'image'; fall back to
        # 'image' when the 'flat' lookup fails.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Cache a fixed set of prior samples on disk so visualizations reuse
        # the same latent points across runs.
        fixed_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        if not os.path.exists(fixed_dir):
            os.makedirs(fixed_dir)
        n_latent_dim = self.prior_latent_code.get_shape().as_list()[-1]
        # BUG FIX: np.save appends '.npy' to any path that does not already
        # end in '.npy', so saving to '...npz' really wrote '...npz.npy' and
        # the exists()/np.load() pair below never found it -- the "constant"
        # sample was silently regenerated with fresh randomness on every run.
        # Using a '.npy' path makes exists/save/load agree on one file.
        fixed_path = (fixed_dir + 'np_constant_prior_sample_' +
                      str(n_latent_dim) + '.npy')
        if os.path.exists(fixed_path):
            np_constant_prior_sample = np.load(fixed_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, n_latent_dim])
            np.save(fixed_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # For 2-D latent spaces, additionally decode a fixed 20x20 grid
        # spanning [-grid_scale, grid_scale]^2 for visualization.
        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ], axis=1))

        if self.config['encoder_mode'] == 'Deterministic':
            # Deterministic encoders use no noise input.  (Previously this
            # branch hit a leftover pdb.set_trace() breakpoint and then
            # crashed below because self.epsilon was never assigned; the
            # sibling implementation sets it to None.)
            self.epsilon = None
        if self.config['encoder_mode'] in ('Gaussian', 'UnivApprox',
                                           'UnivApproxNoSpatial',
                                           'UnivApproxSine'):
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)),))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded, self.posterior_latent_code_det_expanded = \
            self.Encoder.forward(self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        self.posterior_latent_code_det = self.posterior_latent_code_det_expanded[:, 0, :]

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        ### Encoder: plain OT primal (reconstruction distance), no regularizer.
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        self.enc_cost = self.mean_OT_primal

        ### Generator
        self.gen_cost = self.mean_OT_primal
# ---- Exemplo n.º 12 (0) ----
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph with an MMD latent regularizer and an
        epsilon-recovery critic.

        Wires the prior -> generator sampling path, a disk-cached set of
        fixed prior samples (plus a fixed 2-D grid when n_latent == 2), the
        noise-driven encoder posterior and reconstruction, the MMD between
        the aggregate posterior and the prior, the InfoMap critic that tries
        to recover the encoder noise epsilon, and the timer-ramped
        encoder/critic/generator training costs.

        Args:
            batch: dict; batch['observed']['properties'] holds per-variable
                metadata lists for the 'flat' and 'image' groups, and
                batch['observed']['data'] holds the observed tensors.
            additional_inputs_tf: sequence whose element 0 is the epoch
                tensor and element 1 the b_identity flag tensor.

        Raises:
            NotImplementedError: if config['encoder_mode'] is
                'Deterministic' (this variant requires encoder noise).
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # All observed variables are modeled as deterministic (dirac).
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # The observed modality is either 'flat' or 'image'; fall back to
        # 'image' when the 'flat' lookup fails.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except Exception:
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Cache a fixed set of prior samples on disk so visualizations reuse
        # the same latent points across runs.
        fixed_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        if not os.path.exists(fixed_dir):
            os.makedirs(fixed_dir)
        n_latent_dim = self.prior_latent_code.get_shape().as_list()[-1]
        # BUG FIX: np.save appends '.npy' to any path that does not already
        # end in '.npy', so saving to '...npz' really wrote '...npz.npy' and
        # the exists()/np.load() pair below never found it -- the "constant"
        # sample was silently regenerated with fresh randomness on every run.
        # Using a '.npy' path makes exists/save/load agree on one file.
        fixed_path = (fixed_dir + 'np_constant_prior_sample_' +
                      str(n_latent_dim) + '.npy')
        if os.path.exists(fixed_path):
            np_constant_prior_sample = np.load(fixed_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, n_latent_dim])
            np.save(fixed_path, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # For 2-D latent spaces, additionally decode a fixed 20x20 grid
        # spanning [-grid_scale, grid_scale]^2 for visualization.
        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ], axis=1))
        self.uniform_sample = self.uniform_dist.sample()

        if self.config['encoder_mode'] == 'Deterministic':
            # This variant needs encoder noise (the InfoMap critic below
            # recovers epsilon), so a deterministic encoder cannot be used.
            # Previously this branch was a leftover pdb.set_trace() breakpoint.
            raise NotImplementedError(
                'Deterministic encoder_mode is not supported by this model.')
        if self.config['encoder_mode'] in ('Gaussian', 'UnivApprox',
                                           'UnivApproxNoSpatial',
                                           'UnivApproxSine'):
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)),))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded, self.posterior_latent_code_det_expanded = \
            self.Encoder.forward(self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        self.posterior_latent_code_det = self.posterior_latent_code_det_expanded[:, 0, :]

        # Critic ("InfoMap") models epsilon given the posterior code; its
        # log-likelihood of the true epsilon drives the info cost below.
        self.info_param = self.InfoMap.forward(self.posterior_latent_code)
        self.info_dist = distributions.DiagonalGaussianDistribution(
            params=self.info_param)
        self.epsilon_info_log_pdf = self.info_dist.log_pdf(self.epsilon)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        # MMD between the aggregate posterior and a fresh prior sample, and
        # the negative mean epsilon log-likelihood under the critic.
        self.MMD = helper.compute_MMD(self.posterior_latent_code,
                                      self.prior_dist.sample())
        self.info_cost = -tf.reduce_mean(self.epsilon_info_log_pdf)

        # Diverger outputs are built for adversarial divergence modes; they
        # are not part of the active costs below but remain available.
        if self.config['divergence_mode'] in ('GAN', 'NS-GAN', 'WGAN-GP'):
            self.div_posterior = self.Diverger.forward(
                self.posterior_latent_code[:, np.newaxis, :])
            self.div_prior = self.Diverger.forward(
                self.prior_latent_code[:, np.newaxis, :])

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        # Ramp the latent regularizer in over `timescale` epochs, starting
        # at epoch `starttime`.
        b_use_timer, timescale, starttime = True, 100, 5

        self.enc_reg_cost = self.MMD
        self.cri_reg_cost = self.info_cost

        ### Critic
        self.cri_cost = self.cri_reg_cost

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        if b_use_timer:
            self.enc_overall_cost = helper.hardstep(
                (self.epoch - float(starttime)) / float(timescale) + 0.0001) * (
                    self.config['enc_reg_strength'] * self.enc_reg_cost +
                    self.config['enc_inv_MMD_strength'] * self.info_cost)
        else:
            self.enc_overall_cost = (
                self.config['enc_reg_strength'] * self.enc_reg_cost +
                self.config['enc_inv_MMD_strength'] * self.info_cost)
        self.enc_cost = self.mean_OT_primal + self.enc_overall_cost

        ### Generator
        self.gen_cost = self.mean_OT_primal
# ---- Exemplo n.º 13 (0) ----
    def inference(self, batch, additional_inputs_tf):
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        # self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        # self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        if not os.path.exists('./fixed_samples/'):
            os.makedirs('./fixed_samples/')
        if os.path.exists(
                './fixed_samples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                './np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[400,
                      self.prior_latent_code.get_shape().as_list()[-1]])
            np.save(
                './fixed_samples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz',
                np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code_expanded = tf.reshape(
            self.constant_prior_latent_code, [
                -1, 1,
                *self.constant_prior_latent_code.get_shape().as_list()[1:]
            ])

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        #############################################################################
        # ENCODER

        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] == 'Gaussian' or self.config[
                'encoder_mode'] == 'UnivApprox' or self.config[
                    'encoder_mode'] == 'UnivApproxNoSpatial' or self.config[
                        'encoder_mode'] == 'UnivApproxSine':
            self.epsilon_param = self.PriorMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,
                                                                         0, :]
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        if self.config['divergence_mode'] == 'MMD':
            self.MMD = self.compute_MMD(self.posterior_latent_code,
                                        self.prior_dist.sample())
            self.enc_reg_cost = self.MMD
        if self.config['divergence_mode'] == 'INV-MMD':
            batch_rand_vectors = tf.random_normal(shape=[
                self.config['enc_inv_MMD_n_trans'], self.config['n_latent']
            ])
            batch_rand_dirs = batch_rand_vectors / helper.safe_tf_sqrt(
                tf.reduce_sum((batch_rand_vectors**2), axis=1, keep_dims=True))
            self.Inv_MMD = self.stable_div(self.compute_MMD,
                                           self.posterior_latent_code,
                                           batch_rand_dirs)
            self.MMD = self.compute_MMD(self.posterior_latent_code,
                                        self.prior_dist.sample())
            self.enc_reg_cost = self.MMD + self.config[
                'enc_inv_MMD_strength'] * self.Inv_MMD
        elif self.config['divergence_mode'] == 'GAN' or self.config[
                'divergence_mode'] == 'NS-GAN':
            self.div_posterior = self.Diverger.forward(
                self.posterior_latent_code_expanded)
            self.div_prior = self.Diverger.forward(
                self.prior_latent_code_expanded)
            self.mean_z_divergence = tf.reduce_mean(
                tf.log(10e-7 + self.div_prior)) + tf.reduce_mean(
                    tf.log(10e-7 + 1 - self.div_posterior))
            if self.config['divergence_mode'] == 'NS-GAN':
                self.enc_reg_cost = -tf.reduce_mean(
                    tf.log(10e-7 + self.div_posterior))
            elif self.config['divergence_mode'] == 'GAN':
                self.enc_reg_cost = tf.reduce_mean(
                    tf.log(10e-7 + 1 - self.div_posterior))

        #############################################################################
        # REGULARIZER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ],
                             axis=1))

        ### Visualized regularization sample
        self.reg_target_dist = self.reconst_dist
        self.reg_dist = None

        ### Uniform Sample From Geodesic of Trivial Coupling Pairs
        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
           'Trivial Lipschitz' in self.config['critic_reg_mode']:
            self.uniform_sample_b = self.uniform_dist.sample()
            self.trivial_line_sample_param = {'image': None, 'flat': None}
            try:
                self.uniform_sample_b_expanded = self.uniform_sample_b[:, np.
                                                                       newaxis, :,
                                                                       np.
                                                                       newaxis,
                                                                       np.
                                                                       newaxis]
                self.trivial_line_sample_param['image'] = self.uniform_sample_b_expanded*self.obs_sample['image']+\
                                                          (1-self.uniform_sample_b_expanded)*self.input_sample['image']
            except:
                self.uniform_sample_b_expanded = self.uniform_sample_b[:, np.
                                                                       newaxis, :]
                self.trivial_line_sample_param['flat'] = self.uniform_sample_b_expanded*self.obs_sample['flat']+\
                                                         (1-self.uniform_sample_b_expanded)*self.input_sample['flat']
            self.trivial_line_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.trivial_line_sample_param)
            self.trivial_line_sample = self.trivial_line_dist.sample(
                b_mode=True)
            if self.reg_dist is None: self.reg_dist = self.trivial_line_dist

        #############################################################################
        # CRITIC

        self.critic_real = self.Critic.forward(self.input_sample)
        self.critic_gen = self.Critic.forward(self.obs_sample)
        self.critic_rec = self.Critic.forward(self.reconst_sample)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_rec = tf.reduce_mean(self.critic_rec)

        self.mean_critic_reg = None

        if self.compute_all_regularizers or \
           'Trivial Gradient Norm' in self.config['critic_reg_mode'] or \
           'Trivial Lipschitz' in self.config['critic_reg_mode']:
            self.critic_trivial_line = self.Critic.forward(
                self.trivial_line_sample)
            if self.mean_critic_reg is None:
                self.mean_critic_reg = tf.reduce_mean(self.critic_trivial_line)
            try:
                self.trivial_line_grad = tf.gradients(
                    self.critic_trivial_line,
                    [self.trivial_line_sample['image']])[0]
                self.trivial_line_grad_norm = tf.sqrt(
                    tf.reduce_sum(tf.square(self.trivial_line_grad),
                                  axis=[-1, -2, -3],
                                  keep_dims=False))[:, np.newaxis, :]
            except:
                self.trivial_line_grad = tf.gradients(
                    self.critic_trivial_line,
                    [self.trivial_line_sample['flat']])[0]
                self.trivial_line_grad_norm = helper.safe_tf_sqrt(
                    tf.reduce_sum(self.trivial_line_grad**2,
                                  axis=[
                                      -1,
                                  ],
                                  keep_dims=True))

        self.cri_reg_cost = 0

        if self.compute_all_regularizers or 'Trivial Gradient Norm' in self.config[
                'critic_reg_mode']:
            self.trivial_line_grad_norm_1_penalties = ((
                self.trivial_line_grad_norm - 1.)**2)
            self.mean_trivial_line_grad_norm_1_penalties = tf.reduce_mean(
                self.trivial_line_grad_norm_1_penalties)
            if 'Trivial Gradient Norm' in self.config['critic_reg_mode']:
                print('Adding Trivial Gradient Norm Penalty')
                self.cri_reg_cost += self.mean_trivial_line_grad_norm_1_penalties

        if len(self.config['critic_reg_mode']) > 0:
            self.cri_reg_cost /= len(self.config['critic_reg_mode'])

        #############################################################################

        # OBJECTIVES
        ### Divergence
        if self.config['divergence_mode'] == 'GAN' or self.config[
                'divergence_mode'] == 'NS-GAN':
            self.div_cost = -self.config[
                'enc_reg_strength'] * self.mean_z_divergence

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        # self.mean_OT_primal = helper.tf_print(self.mean_OT_primal, [self.mean_OT_primal])

        self.mean_POT_primal = self.mean_OT_primal  #+self.config['enc_reg_strength']*self.enc_reg_cost
        self.enc_cost = self.mean_POT_primal

        ### Generator
        if self.config['dual_dist_mode'] == 'Coupling':
            self.gen_cost = -self.mean_critic_rec
        elif self.config['dual_dist_mode'] == 'Prior':
            self.gen_cost = -self.mean_critic_gen
        elif self.config['dual_dist_mode'] == 'CouplingAndPrior':
            self.gen_cost = -0.5 * (self.mean_critic_rec +
                                    self.mean_critic_gen)

        ### Critic
        if self.config['dual_dist_mode'] == 'Coupling':
            self.mean_OT_dual = self.mean_critic_real - self.mean_critic_rec
        elif self.config['dual_dist_mode'] == 'Prior':
            self.mean_OT_dual = self.mean_critic_real - self.mean_critic_gen
        elif self.config['dual_dist_mode'] == 'CouplingAndPrior':
            self.mean_OT_dual = self.mean_critic_real - 0.5 * (
                self.mean_critic_rec + self.mean_critic_gen)

        self.cri_cost = -self.mean_OT_dual + self.config[
            'cri_reg_strength'] * self.cri_reg_cost
# Exemplo n.º 14
# 0
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph for a VAE-style model.

        Constructs, in order: a diagonal-Gaussian prior and a prior sample,
        generator outputs for that sample, a fixed set of 400 "constant"
        prior samples (cached on disk for run-to-run comparable visuals),
        an optional 2-D latent grid decode, the encoder posterior and
        reconstruction, and the encoder/generator objectives (an annealed
        mix of negative log-likelihood and the primal OT cost).

        Args:
            batch: dict; ``batch['observed']['properties']`` holds
                per-modality metadata lists under ``'flat'`` and ``'image'``,
                and ``batch['observed']['data']`` the matching tensors.
            additional_inputs_tf: indexable; ``[0]`` is the epoch tensor,
                ``[1]`` an identity-flag tensor (stored, not used here).

        Side effects: sets many ``self.*`` graph tensors read elsewhere, and
        may create/write files under ~/ExperimentalResults/FixedSamples/.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Deep-copy the properties so forcing point-mass ('dirac')
        # observation distributions does not mutate the caller's batch.
        empirical_observed_properties = copy.deepcopy(
            batch['observed']['properties'])
        for e in empirical_observed_properties['flat']:
            e['dist'] = 'dirac'
        for e in empirical_observed_properties['image']:
            e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=empirical_observed_properties,
            params=self.input_sample)

        # Lazily build the network modules on first use.
        if not self.bModules: self.generate_modules(batch)
        # NOTE(review): bare except is used as flat-vs-image dispatch; it
        # will also silently mask unrelated KeyError/IndexError bugs.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        # Prior parameters come from PriorMap applied to a zero input,
        # yielding a batch-sized diagonal Gaussian.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()

        # Generator expects a singleton time axis; insert it at axis 1.
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample()

        # Log-pdf of the observed inputs under the generator's output
        # distribution (parameterized by the prior sample).
        self.obs_log_pdf = self.obs_sample_dist.log_pdf(self.input_sample)

        # Cache a fixed 400-sample prior draw on disk so generated samples
        # are comparable across runs (keyed by latent dimensionality).
        if not os.path.exists(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/'):
            os.makedirs(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        if os.path.exists(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[400,
                      self.prior_latent_code.get_shape().as_list()[-1]])
            # NOTE(review): np.save appends '.npy' to names lacking that
            # extension, so this writes '..._D.npz.npy' and the exists()
            # check above never finds it -- the cache is effectively
            # write-only and a fresh draw happens every run. Confirm and
            # switch to np.savez/np.save with a '.npy' name if unintended.
            np.save(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.prior_latent_code.get_shape().as_list()[-1]) + '.npz',
                np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample()

        # For 2-D latents, additionally decode a fixed 20x20 grid over
        # [-grid_scale, grid_scale]^2 for latent-space visualization.
        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
            )

        #############################################################################
        # ENCODER

        # Gaussian encoders get an auxiliary noise input; otherwise the
        # encoder is deterministic (noise=None).
        if self.config['encoder_mode'] == 'Gaussian':
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()
        else:
            self.epsilon = None

        self.posterior_latent_code_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        # Drop the singleton time axis added by the encoder.
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,
                                                                         0, :]

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        # b_mode=True: take the distribution's mode rather than a draw.
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)
        self.reconst_log_pdf = self.reconst_dist.log_pdf(self.input_sample)

        # Interpolate between posterior codes of half-batch pairs and decode
        # the interpolants (for qualitative latent-traversal outputs).
        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        ## Encoder
        self.mean_neg_log_pdf = -tf.reduce_mean(self.reconst_log_pdf)

        # Primal optimal-transport cost between inputs and reconstructions.
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        # overall_cost = self.mean_neg_log_pdf
        # Anneal from the OT cost toward the NLL: tradeoff ramps (via
        # helper.hardstep) from min_tradeoff to 1-min_tradeoff starting at
        # epoch `start_time` over `timescale` epochs.
        timescale, start_time, min_tradeoff = 5, 10, 0.000001
        tradeoff = (1 - 2 * min_tradeoff) * helper.hardstep(
            (self.epoch - float(start_time)) / float(timescale)) + min_tradeoff
        overall_cost = tradeoff * self.mean_neg_log_pdf + (
            1 - tradeoff) * self.mean_OT_primal

        self.enc_cost = overall_cost

        ### Generator
        # Encoder and generator share the same annealed objective.
        self.gen_cost = overall_cost
# Exemplo n.º 15
# 0
    def inference(self, batch, additional_inputs_tf):
        """Build the inference graph for a flow-augmented autoencoder.

        A diagonal-Gaussian "pre-prior" is pushed through a learned
        HouseholdRotationFlow to obtain the prior; the generator decodes
        prior/posterior codes, and the objective is the primal OT cost
        between inputs and reconstructions (shared by encoder and
        generator). Also prepares per-dimension statistics of the
        flow-transformed posterior used by (currently commented-out)
        regularizers, plus fixed/cached prior samples for visualization.

        Args:
            batch: dict; ``batch['observed']['properties']`` holds
                per-modality metadata lists under ``'flat'``/``'image'``,
                ``batch['observed']['data']`` the matching tensors.
                NOTE(review): unlike the deep-copy variant of this method,
                the 'dist' fields here are mutated in place.
            additional_inputs_tf: indexable; ``[0]`` epoch tensor, ``[1]``
                identity-flag tensor (stored, not used here).
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Force point-mass ('dirac') observation distributions on whichever
        # modality is present (flat preferred; mutates batch in place).
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        # Lazily build the network modules on first use.
        if not self.bModules: self.generate_modules(batch)
        # NOTE(review): bare except used as flat-vs-image dispatch -- also
        # masks unrelated KeyError/IndexError bugs.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################

        # Pre-prior: diagonal Gaussian parameterized by PriorMap from zeros.
        self.pre_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.pre_prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.pre_prior_param)
        # self.pre_prior_latent_template = tf.concat([tf.zeros(shape=[1, self.config['n_latent']-48]), tf.ones(shape=[1, 48])], axis=1)
        # self.pre_prior_latent_template = tf.concat([tf.ones(shape=[1, 48]), tf.zeros(shape=[1, self.config['n_latent']-48])], axis=1)
        self.pre_prior_latent_code = self.pre_prior_dist.sample()
        # Identity reassignment; the template masking experiment is disabled.
        self.pre_prior_latent_code = self.pre_prior_latent_code  #*self.pre_prior_latent_template
        self.pre_prior_latent_code_expanded = self.pre_prior_latent_code[:, np.
                                                                         newaxis, :]

        # GENERATOR
        # The prior is the pre-prior pushed through a learned rotation flow;
        # the log-det term is discarded here (zeros passed, result ignored).
        self.flow_param = self.FlowMap.forward()
        self.flow_object = transforms.HouseholdRotationFlow(
            self.flow_param, self.config['n_latent'])
        self.prior_latent_code, _ = self.flow_object.inverse_transform(
            self.pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))
        self.prior_latent_code_expanded = self.prior_latent_code[:,
                                                                 np.newaxis, :]

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        # b_mode=True: take the distribution's mode rather than a draw.
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Cache a fixed 400-sample pre-prior draw on disk so generated
        # samples are comparable across runs (keyed by latent dim).
        if not os.path.exists(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/'):
            os.makedirs(
                str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        if os.path.exists(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.pre_prior_latent_code.get_shape().as_list()[-1]) +
                '.npz'):
            np_constant_prior_sample = np.load(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.pre_prior_latent_code.get_shape().as_list()[-1]) +
                '.npz')
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0.,
                scale=1.,
                size=[
                    400,
                    self.pre_prior_latent_code.get_shape().as_list()[-1]
                ])
            # NOTE(review): np.save appends '.npy' to names lacking that
            # extension, so this writes '..._D.npz.npy' and the exists()
            # check above never finds it -- the cache is effectively
            # write-only; a fresh draw happens every run.
            np.save(
                str(Path.home()) +
                '/ExperimentalResults/FixedSamples/np_constant_prior_sample_' +
                str(self.pre_prior_latent_code.get_shape().as_list()[-1]) +
                '.npz', np_constant_prior_sample)

        self.constant_pre_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_pre_prior_latent_code_expanded = self.constant_pre_prior_latent_code[:,
                                                                                           np
                                                                                           .
                                                                                           newaxis, :]

        # NOTE(review): the zeros log-det input is batch-sized while the
        # constant code has 400 rows -- confirm inverse_transform tolerates
        # the mismatch (the second output is discarded either way).
        self.constant_prior_latent_code, _ = self.flow_object.inverse_transform(
            self.constant_pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))
        self.constant_prior_latent_code_expanded = self.constant_prior_latent_code[:,
                                                                                   np
                                                                                   .
                                                                                   newaxis, :]

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # For 2-D latents, additionally decode a fixed 20x20 grid over
        # [-grid_scale, grid_scale]^2 for latent-space visualization.
        if self.config['n_latent'] == 2:
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_prior_grid_latent_code_expanded = self.constant_prior_grid_latent_code[:,
                                                                                                 np
                                                                                                 .
                                                                                                 newaxis, :]

            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code_expanded)
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        # Encoder noise source; note it reuses PriorMap (not EpsilonMap).
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)

        # NOTE(review): if encoder_mode matches neither branch,
        # self.epsilon is left undefined (no else clause).
        if self.config['encoder_mode'] == 'Deterministic':
            self.epsilon = None
        if self.config['encoder_mode'] == 'Gaussian' or self.config[
                'encoder_mode'] == 'UnivApprox' or self.config[
                    'encoder_mode'] == 'UnivApproxNoSpatial' or self.config[
                        'encoder_mode'] == 'UnivApproxSine':
            self.epsilon = self.epsilon_dist.sample()

        self.posterior_latent_code_expanded, _ = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        # Drop the singleton time axis added by the encoder.
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,
                                                                         0, :]
        # Map the posterior back through the flow (forward direction); the
        # log-det accumulator is again zeros and the second output dropped.
        self.transformed_posterior_latent_code, _ = self.flow_object.transform(
            self.posterior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))

        # self.shouldbezero = self.transformed_posterior_latent_code[:, :self.config['n_latent']-48]
        # self.shouldbenormal = self.transformed_posterior_latent_code[:, self.config['n_latent']-48:]
        # self.isnormal = self.pre_prior_latent_code[:, self.config['n_latent']-48:]
        # self.shouldbezero = self.transformed_posterior_latent_code[:, 48:]
        # self.shouldbenormal = self.transformed_posterior_latent_code[:, :48]
        # self.isnormal = self.pre_prior_latent_code[:, :48]

        # Per-dimension |.| and squared statistics of the transformed
        # posterior; the "weighted" variants scale dimension d by
        # (d+1)/n_latent, penalizing later dimensions more.
        self.transformed_posterior_latent_code_abs = helper.relu_abs(
            self.transformed_posterior_latent_code)
        self.transformed_posterior_latent_code_sq = self.transformed_posterior_latent_code**2
        self.transformed_posterior_latent_code_abs_means = tf.reduce_mean(
            self.transformed_posterior_latent_code_abs, axis=0)
        self.transformed_posterior_latent_code_sq_means = tf.reduce_mean(
            self.transformed_posterior_latent_code_sq, axis=0)
        self.transformed_posterior_latent_code_abs_weighted = (
            tf.range(1, self.config['n_latent'] + 1, 1,
                     dtype=tf.float32)[np.newaxis, :] / self.config['n_latent']
        ) * self.transformed_posterior_latent_code_abs
        self.transformed_posterior_latent_code_sq_weighted = (
            tf.range(1, self.config['n_latent'] + 1, 1,
                     dtype=tf.float32)[np.newaxis, :] / self.config['n_latent']
        ) * self.transformed_posterior_latent_code_sq
        self.transformed_posterior_latent_code_abs_weighted_means = tf.reduce_mean(
            self.transformed_posterior_latent_code_abs_weighted, axis=0)
        self.transformed_posterior_latent_code_sq_weighted_means = tf.reduce_mean(
            self.transformed_posterior_latent_code_sq_weighted, axis=0)

        # keep_dims is the TF1 spelling of keepdims (deprecated upstream).
        self.transformed_posterior_latent_code_norm = tf.reduce_sum(
            self.transformed_posterior_latent_code_sq, axis=1, keep_dims=True)
        self.transformed_posterior_latent_code_norm_weighted = tf.reduce_sum(
            self.transformed_posterior_latent_code_sq_weighted,
            axis=1,
            keep_dims=True)
        # Defined but only consumed by the commented-out critic objectives.
        self.shouldbezero_cost = tf.reduce_mean(
            self.transformed_posterior_latent_code_norm_weighted)

        # self.shouldbezero_cost = tf.reduce_mean(self.transformed_posterior_latent_code_norm)
        # self.shouldbenormal_cost = self.compute_MMD(self.isnormal, self.shouldbenormal)

        # Interpolate between posterior codes of half-batch pairs and decode
        # the interpolants (for qualitative latent-traversal outputs).
        self.interpolated_posterior_latent_code = self.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        # self.enc_reg_cost = self.compute_MMD(self.pre_prior_latent_code, self.pre_posterior_latent_code)
        # self.div_posterior = self.Diverger.forward(self.pre_posterior_latent_code_expanded)
        # self.div_prior = self.Diverger.forward(self.prior_latent_code_expanded)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################

        # # OBJECTIVES
        # # Divergence
        # if self.config['divergence_mode'] == 'GAN' or self.config['divergence_mode'] == 'NS-GAN':
        #     self.div_cost = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_posterior)+10e-7))+tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7)))
        # if self.config['divergence_mode'] == 'WGAN-GP':
        #     uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        #     uniform_w = uniform_dist.sample()
        #     self.trivial_line = uniform_w[:,np.newaxis,:]*self.pre_posterior_latent_code_expanded+(1-uniform_w[:,np.newaxis,:])*self.prior_latent_code_expanded
        #     self.div_trivial_line = self.Diverger.forward(self.trivial_line)
        #     self.trivial_line_grad = tf.gradients(tf.reduce_sum(self.div_trivial_line), [self.trivial_line])[0]
        #     self.trivial_line_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_line_grad**2, axis=-1, keep_dims=False)[:,:,np.newaxis])
        #     self.trivial_line_grad_norm_1_penalties = ((self.trivial_line_grad_norm-1)**2)
        #     self.div_reg_cost = tf.reduce_mean(self.trivial_line_grad_norm_1_penalties)
        #     self.div_cost = -(tf.reduce_mean(self.div_posterior)-tf.reduce_mean(self.div_prior))+10*self.div_reg_cost

        # ### Encoder
        # b_use_timer, timescale, starttime = False, 10, 5
        # Primal optimal-transport cost between inputs and reconstructions;
        # currently the sole active objective for encoder and generator.
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        # if b_use_timer:
        #     self.mean_POT_primal = self.mean_OT_primal+helper.hardstep((self.epoch-float(starttime))/float(timescale))*self.config['enc_reg_strength']*self.enc_reg_cost
        # else:
        #     self.mean_POT_primal = self.mean_OT_primal+self.config['enc_reg_strength']*self.enc_reg_cost
        # self.enc_cost = self.mean_POT_primal
        self.enc_cost = self.mean_OT_primal

        # ### Critic
        # # self.cri_cost = self.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # if self.config['divergence_mode'] == 'NS-GAN':
        #     self.cri_cost = -tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*self.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'GAN':
        #     self.cri_cost = tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*self.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'WGAN-GP':
        #     self.cri_cost = -tf.reduce_mean(self.div_prior)+self.config['enc_reg_strength']*self.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # self.cri_reg_cost = self.shouldbezero_cost
        # self.cri_cost = self.shouldbenormal_cost+self.cri_reg_cost

        # self.cri_cost = self.shouldbezero_cost

        ### Generator
        self.gen_cost = self.mean_OT_primal
# Exemplo n.º 16
# 0
    def inference(self, batch, additional_inputs_tf):
        """Build a WGAN-style graph with a finite-difference slope penalty.

        Decodes prior samples, encodes inputs to a posterior (optionally
        stochastic via an epsilon noise channel), forms convex mixtures of
        reconstructions and inputs as regularizer points, and defines
        discriminator/generator/transporter costs where the Lipschitz-type
        penalty uses critic-value differences over a metric distance rather
        than tf.gradients.

        Args:
            batch: dict; ``batch['observed']['properties']`` holds
                per-modality metadata lists under ``'flat'``/``'image'``,
                ``batch['observed']['data']`` the matching tensors (the
                'dist' fields are mutated to 'dirac' in place).
            additional_inputs_tf: indexable; ``[0]`` epoch tensor, ``[1]``
                identity-flag tensor (stored, not used here).
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Force point-mass ('dirac') observation distributions on whichever
        # modality is present (flat preferred).
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        # Lazily build the network modules on first use.
        if not self.bModules: self.generate_modules(batch)
        # NOTE(review): bare except used as flat-vs-image dispatch -- also
        # masks unrelated KeyError/IndexError bugs.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        # Unlike sibling variants, the batch size is a local here, not an
        # attribute.
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Diagonal-Gaussian prior parameterized by PriorMap from zeros.
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.prior_latent_code_logpdf = self.prior_dist.log_pdf(
            self.prior_latent_code)
        self.prior_logpdf = tf.reduce_mean(self.prior_latent_code_logpdf)

        # U(0,1) mixing weights for the convex combination below.
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ],
                             axis=1))
        self.convex_mix = self.uniform_dist.sample()

        # Insert a singleton time axis before decoding.
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        # b_mode=True: take the distribution's mode rather than a draw.
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # Hard-coded switch: the stochastic branch is always taken.
        b_stochastic_encoder = True
        self.epsilon_param = self.EpsilonMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)
        self.epsilon = self.epsilon_dist.sample()
        if b_stochastic_encoder:
            self.posterior_latent_code_expanded = self.EncodingPlan.forward(
                self.input_sample, epsilon_sample=self.epsilon)
        else:
            self.posterior_latent_code_expanded = self.EncodingPlan.forward(
                self.input_sample)
        # Drop the singleton time axis added by the encoder.
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,
                                                                         0, :]
        # Posterior codes scored under the *prior* density.
        self.posterior_latent_code_logpdf = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.posterior_logpdf = tf.reduce_mean(
            self.posterior_latent_code_logpdf)
        self.reg_target_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reg_target_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_target_param)
        self.reg_target_sample = self.reg_target_dist.sample(b_mode=True)

        # #################################################################################

        # Regularizer points: per-example convex mixtures of reconstruction
        # and input (bare except again dispatches image vs flat).
        self.reg_sample_param = {'image': None, 'flat': None}
        try:
            self.reg_sample_param['image'] = self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis]*self.reg_target_sample['image']+\
                                            (1-self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis])*self.input_sample['image']
        except:
            self.reg_sample_param['flat'] = self.convex_mix[:, np.newaxis, :]*self.reg_target_sample['flat']+\
                                            (1-self.convex_mix[:, np.newaxis, :])*self.input_sample['flat']
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_sample_param)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        # Critic values on real inputs, prior decodes, and mixture points.
        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg = self.Discriminator.forward(self.reg_sample)

        # Finite-difference slope of the critic along the real->mixture
        # chord: (f(x)-f(x'))^2 / d(x,x')^2, with 1e-7 guarding division
        # by zero; penalized toward slope 1 (Lipschitz-style).
        self.real_reg_distances_sq = self.metric_distance_sq(
            self.input_sample, self.reg_sample)
        self.real_reg_slopes_sq = ((self.critic_real - self.critic_reg)**
                                   2) / (self.real_reg_distances_sq + 1e-7)
        self.slope_penalties = ((self.real_reg_slopes_sq - 1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)

        self.mean_real_reg_slopes_sq = tf.reduce_mean(self.real_reg_slopes_sq)
        self.mean_slope_penalty = tf.reduce_mean(self.slope_penalties)

        # WGAN-GP-style costs: fixed penalty weight 10; the transporter
        # (encoder/decoder plan) maximizes the mean squared slope.
        self.regularizer_cost = 10 * self.mean_slope_penalty
        self.discriminator_cost = -self.mean_critic_real + self.mean_critic_gen + self.regularizer_cost
        self.generator_cost = -self.mean_critic_gen
        self.transporter_cost = -self.mean_real_reg_slopes_sq
# Exemplo n.º 17
# 0
    def inference(self, batch, additional_inputs_tf):
        """Build the full training graph: VAE-style autoencoding costs plus a
        WGAN-GP-style critic with a gradient penalty on convex mixtures of
        real and reconstructed/generated samples.

        Args:
            batch: dict with batch['observed']['data'] (tensors keyed 'flat'
                and/or 'image') and batch['observed']['properties'] (per-modality
                property dicts; each gets its 'dist' forced to 'dirac' below).
            additional_inputs_tf: sequence whose first element is the epoch
                tensor and second a boolean-like "identity" flag tensor.

        Side effects: defines many graph tensors on self, ending with the four
        scalar costs (discriminator/generator/transporter/regularizer).
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Treat the observed modality as deterministic: mark every property of
        # whichever modality is present ('flat' preferred, else 'image') as dirac.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        # Lazily build submodules (Encoder/Decoder/Discriminator/...) once.
        if not self.bModules: self.generate_modules(batch)
        # Bare excepts below intentionally fall through from the 'flat'
        # modality to 'image' when 'flat' is absent/unusable in this batch.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except:
            self.n_time = batch['observed']['properties']['image'][0]['size'][
                1]
        try:
            batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except:
            batch_size_tf = tf.shape(self.input_sample['image'])[0]

        # Prior p(z): diagonal Gaussian whose parameters come from PriorMap
        # applied to a zero vector (i.e. a fixed, input-independent prior).
        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.neg_ent_prior = self.prior_dist.log_pdf(self.prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        # Per-example mixing coefficient in [0, 1] for the convex combination
        # used by the gradient-penalty interpolates (params are [low, high]).
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(batch_size_tf, 1)),
                tf.ones(shape=(batch_size_tf, 1))
            ],
                             axis=1))
        self.convex_mix = self.uniform_dist.sample()

        # Insert a singleton time axis before decoding: (B, D) -> (B, 1, D).
        self.prior_latent_code_expanded = tf.reshape(
            self.prior_latent_code,
            [-1, 1, *self.prior_latent_code.get_shape().as_list()[1:]])
        self.obs_sample_param = self.Decoder.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        # b_mode=True: take the distribution's mode rather than a random draw.
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        #############################################################################
        # If True, the gradient penalty interpolates between real inputs and
        # their reconstructions; otherwise between real inputs and prior samples.
        b_use_reconst_as_reg = True
        # Posterior q(z|x): encoder output has a singleton time axis at dim 1.
        self.posterior_param_expanded = self.Encoder.forward(self.input_sample)
        self.posterior_param = self.posterior_param_expanded[:, 0, :]
        self.posterior_dist = distributions.DiagonalGaussianDistribution(
            params=self.posterior_param)
        self.posterior_latent_code = self.posterior_dist.sample()
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.
                                                                         newaxis, :]

        # Reconstruction x_hat = Decoder(q-sample), again taken at the mode.
        self.reconst_param = self.Decoder.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        if b_use_reconst_as_reg:
            self.reg_target_dist = self.reconst_dist
            self.reg_target_sample = self.reconst_sample
        else:
            self.reg_target_dist = self.obs_sample_dist
            self.reg_target_sample = self.obs_sample

        # KL(q || p) computed from its entropy decomposition:
        # KL = E_q[log q(z)] - E_q[log p(z)] (Monte Carlo, one sample).
        self.neg_cross_ent_posterior = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_cross_ent_posterior = tf.reduce_mean(
            self.neg_cross_ent_posterior)
        self.neg_ent_posterior = self.posterior_dist.log_pdf(
            self.posterior_latent_code)
        self.mean_neg_ent_posterior = tf.reduce_mean(self.neg_ent_posterior)

        # self.kl_posterior_prior = distributions.KLDivDiagGaussianVsDiagGaussian().forward(self.posterior_dist, self.prior_dist)
        # self.mean_kl_posterior_prior = tf.reduce_mean(self.kl_posterior_prior)

        self.mean_kl_posterior_prior = -self.mean_neg_cross_ent_posterior + self.mean_neg_ent_posterior
        #############################################################################

        # Convex interpolation mix*target + (1-mix)*real, per modality; the
        # try/except again picks 'image' first, falling back to 'flat'.
        self.reg_sample_param = {'image': None, 'flat': None}
        try:
            self.reg_sample_param['image'] = self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis]*self.reg_target_sample['image']+\
                                            (1-self.convex_mix[:, np.newaxis, :, np.newaxis, np.newaxis])*self.input_sample['image']
        except:
            self.reg_sample_param['flat'] = self.convex_mix[:, np.newaxis, :]*self.reg_target_sample['flat']+\
                                            (1-self.convex_mix[:, np.newaxis, :])*self.input_sample['flat']
        self.reg_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reg_sample_param)
        self.reg_sample = self.reg_dist.sample(b_mode=True)

        # Critic scores on real, generated, and interpolated samples.
        self.critic_real = self.Discriminator.forward(self.input_sample)
        self.critic_gen = self.Discriminator.forward(self.obs_sample)
        self.critic_reg = self.Discriminator.forward(self.reg_sample)

        # Autoencoding objective: squared metric distance / 2 plus KL weight.
        lambda_t = 1
        self.real_reconst_distances_sq = self.metric_distance_sq(
            self.input_sample, self.reconst_sample)
        self.autoencode_costs = self.real_reconst_distances_sq / 2 + lambda_t * self.mean_kl_posterior_prior

        # WGAN-GP gradient penalty: push ||∇_x critic(x_interp)|| toward 1.
        # Image branch reduces over (C, H, W)-like trailing axes; flat branch
        # over the last axis only. keep_dims is the TF1 spelling of keepdims.
        try:
            self.convex_grad = tf.gradients(self.critic_reg,
                                            [self.reg_sample['image']])[0]
            self.convex_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.convex_grad**2,
                              axis=[-1, -2, -3],
                              keep_dims=False)[:, :, np.newaxis])
        except:
            self.convex_grad = tf.gradients(self.critic_reg,
                                            [self.reg_sample['flat']])[0]
            self.convex_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.convex_grad**2, axis=[-1], keep_dims=True))
        self.gradient_penalties = ((self.convex_grad_norm - 1)**2)

        self.mean_critic_real = tf.reduce_mean(self.critic_real)
        self.mean_critic_gen = tf.reduce_mean(self.critic_gen)
        self.mean_critic_reg = tf.reduce_mean(self.critic_reg)
        self.mean_autoencode_cost = tf.reduce_mean(self.autoencode_costs)
        self.mean_gradient_penalty = tf.reduce_mean(self.gradient_penalties)

        # Final scalar objectives (10 is the standard WGAN-GP penalty weight).
        self.regularizer_cost = 10 * self.mean_gradient_penalty
        self.discriminator_cost = -self.mean_critic_real + self.mean_critic_gen + self.regularizer_cost
        self.generator_cost = -self.mean_critic_gen
        self.transporter_cost = self.mean_autoencode_cost
Exemplo n.º 18
0
def test_cgmm(num_points,
              components,
              g_dimensions,
              g_rank,
              c_dimensions,
              tolerance,
              training_steps,
              cluster=None):
    """Fit a Gaussian/categorical product-mixture model to synthetic data.

    Generates a coupled categorical+Gaussian dataset, initializes one
    product component (sparse-covariance Gaussian x categorical) per
    mixture component, trains via EM, and plots the first two Gaussian
    dimensions of the fit against the ground truth.

    Args:
        num_points: number of synthetic data points to generate.
        components: number of mixture components.
        g_dimensions: dimensionality of the Gaussian part.
        g_rank: rank of the sparse (low-rank) covariance factor.
        c_dimensions: dimensionality of the categorical part.
        tolerance: EM convergence tolerance.
        training_steps: maximum number of EM steps.
        cluster: optional cluster spec forwarded to the mixture model.

    Returns:
        The result tuple produced by MixtureModel.train().
    """
    print("Generating data...")
    (c_data, g_data, c_counts, c_means, g_means, g_covariances,
     true_weights, responsibilities) = utils.generate_cgmm_data(
         num_points, components, c_dimensions, g_dimensions, seed=20)

    print("Computing avg. covariance...")
    # Average per-component, per-dimension variance of the Gaussian data,
    # used as the baseline for the sparse covariance initialization.
    baseline_variance = np.var(g_data,
                               axis=0).sum() / components / g_dimensions

    print("Initializing components...")
    # One product component per mixture component; each Gaussian mean is
    # seeded with a distinct data point. (Isotropic/Diagonal/Full covariance
    # variants were also explored; sparse low-rank is the active choice.)
    mixture_components = []
    for comp in range(components):
        gaussian = distributions.GaussianDistribution(
            dims=g_dimensions,
            mean=g_data[comp],
            covariance=covariances.SparseCovariance(g_dimensions,
                                                    g_rank,
                                                    baseline=baseline_variance,
                                                    prior={
                                                        "alpha": 1.0,
                                                        "beta": 1.0
                                                    }))
        categorical = distributions.CategoricalDistribution(c_counts)
        mixture_components.append(
            distributions.ProductDistribution([gaussian, categorical]))

    print("Initializing model...")
    model = models.MixtureModel([g_data, c_data],
                                mixture_components,
                                cluster=cluster)

    print("Training model...\n")
    result = model.train(tolerance=tolerance,
                         max_steps=training_steps,
                         feedback=feedback_sub)

    # result[2][i][0] holds (mean, covariance) for component i's Gaussian.
    final_g_means = np.stack([result[2][i][0][0] for i in range(components)])
    final_g_covariances = np.stack(
        [result[2][i][0][1] for i in range(components)])

    # Plot only the first two Gaussian dimensions for visualization.
    utils.plot_fitted_data(g_data[:, :2], final_g_means[:, :2],
                           final_g_covariances[:, :2, :2], g_means[:, :2],
                           g_covariances[:, :2, :2])

    return result