Exemple #1
0
    def inference(self, batch, additional_inputs_tf):
        """Build the sampling path, reconstruction path and training costs.

        All results are stored as attributes on ``self``; nothing is returned.
        The method wires up, in order:

        * a generator path: prior sample -> inverse of ``pre_flow_object``
          -> ``flow_object`` -> observation-shaped sample,
        * the same path for a fixed set of 400 prior samples that is cached
          on disk under ``~/ExperimentalResults/FixedSamples/``,
        * an encoder path: ``Encoder`` output -> ``flow_object`` ->
          reconstruction,
        * the objectives ``enc_cost``, ``gen_cost`` and ``cri_cost``.

        Args:
            batch: Nested dict. This method reads
                ``batch['observed']['properties']`` (with ``'flat'`` and
                ``'image'`` lists of dicts carrying ``'size'`` and ``'dist'``)
                and ``batch['observed']['data']`` (indexed with ``'flat'`` /
                ``'image'``). The ``'dist'`` entries are overwritten in place.
            additional_inputs_tf: Indexable; element 0 is stored as
                ``self.epoch`` and element 1 as ``self.b_identity``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Mark the observed modality as 'dirac'; 'flat' takes precedence over
        # 'image' when it is non-empty.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # Fall back to the 'image' entry when no usable 'flat' entry exists.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except (KeyError, IndexError, TypeError):
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            # The exact exception raised by tf.shape on an unusable 'flat'
            # entry depends on the TensorFlow version, so this stays broad
            # (but no longer swallows KeyboardInterrupt/SystemExit).
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR
        self.flow_param_list = self.FlowMap.forward(batch)
        self.wolf_param_list = self.WolfMap.forward(batch)

        # Trailing dimensions of the first image entry; their product is the
        # output dimension of the flow and they are reused for every reshape.
        image_tail_shape = batch['observed']['properties']['image'][0]['size'][2:]
        n_output = np.prod(image_tail_shape)
        n_latent = self.config['n_latent']

        # Euclidean_flow_class = transforms.PiecewisePlanarScalingFlow
        # Euclidean_flow_class = transforms.RealNVPFlow
        Euclidean_flow_class = transforms.NonLinearIARFlow
        self.pre_flow_object = transforms.SerialFlow([
            Euclidean_flow_class(input_dim=n_latent, parameters=self.wolf_param_list[0]),
            transforms.SpecificRotationFlow(input_dim=n_latent),
            Euclidean_flow_class(input_dim=n_latent, parameters=self.wolf_param_list[1]),
            transforms.SpecificRotationFlow(input_dim=n_latent),
            Euclidean_flow_class(input_dim=n_latent, parameters=self.wolf_param_list[2]),
        ])

        self.flow_object = transforms.SerialFlow([
            Euclidean_flow_class(input_dim=n_latent, parameters=self.flow_param_list[0]),
            transforms.SpecificRotationFlow(input_dim=n_latent),
            Euclidean_flow_class(input_dim=n_latent, parameters=self.flow_param_list[1]),
            transforms.RiemannianFlow(
                input_dim=n_latent,
                output_dim=n_output,
                n_input_NOM=self.config['rnf_prop']['n_input_NOM'],
                n_output_NOM=self.config['rnf_prop']['n_output_NOM'],
                parameters=self.flow_param_list[-2]),
            transforms.CompoundRotationFlow(input_dim=n_output, parameters=self.flow_param_list[-1]),
        ])

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()

        # The second return value of the flow calls is discarded here.
        self.pre_prior_latent_code, _ = self.pre_flow_object.inverse_transform(
            self.prior_latent_code, tf.zeros(shape=(self.batch_size_tf, 1)))
        self.transformed_prior_latent_code, _ = self.flow_object.transform(
            self.pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))

        self.obs_sample_param = {
            'flat': None,
            'image': tf.reshape(self.transformed_prior_latent_code,
                                [-1, 1, *image_tail_shape]),
        }
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior samples, cached on disk per latent dimensionality.
        prior_dim = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        fixed_path = (fixed_dir + 'np_constant_prior_sample_' +
                      str(prior_dim) + '.npz')
        os.makedirs(fixed_dir, exist_ok=True)
        if os.path.exists(fixed_path):
            np_constant_prior_sample = np.load(fixed_path)
            if hasattr(np_constant_prior_sample, 'files'):
                # A genuine .npz archive was found: use its first array.
                with np_constant_prior_sample as archive:
                    np_constant_prior_sample = archive[archive.files[0]]
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, prior_dim])
            # np.save appends '.npy' to a *string* path that lacks it, which
            # previously produced '<name>.npz.npy' and made the existence
            # check above fail on every run. Writing through a file handle
            # keeps the exact path; np.load detects the format from the file
            # contents, not the extension.
            with open(fixed_path, 'wb') as sample_file:
                np.save(sample_file, np_constant_prior_sample)
        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)

        # NOTE(review): the zero tensors below have batch_size_tf rows while
        # the constant code has 400 rows — confirm the flows tolerate this.
        self.constant_pre_prior_latent_code, _ = self.pre_flow_object.inverse_transform(
            self.constant_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))
        self.constant_transformed_prior_latent_code, _ = self.flow_object.transform(
            self.constant_pre_prior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))

        self.constant_obs_sample_param = {
            'flat': None,
            'image': tf.reshape(self.constant_transformed_prior_latent_code,
                                [-1, 1, *image_tail_shape]),
        }
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        # The 2-D latent grid visualisation (constant_grid_* attributes) is
        # disabled in this variant, so none of those attributes are created.

        #############################################################################
        # ENCODER
        encoder_mode = self.config['encoder_mode']
        if encoder_mode == 'Deterministic':
            self.epsilon = None
        elif (encoder_mode in ('Gaussian', 'GaussianLeastVariance',
                               'UnivApprox', 'UnivApproxSine')
              or 'UnivApproxNoSpatial' in encoder_mode):
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()
        # Any other mode leaves self.epsilon untouched, as before.

        (self.pre_posterior_latent_code_expanded,
         self.pre_posterior_latent_code_det_expanded) = self.Encoder.forward(
             self.input_sample, noise=self.epsilon)
        self.pre_posterior_latent_code = self.pre_posterior_latent_code_expanded[:, 0, :]

        self.posterior_latent_code, self.posterior_delta_log_pdf = self.pre_flow_object.transform(
            self.pre_posterior_latent_code,
            tf.zeros(shape=(self.batch_size_tf, 1)))
        self.posterior_log_pdf = self.prior_dist.log_pdf(
            self.posterior_latent_code)
        self.pre_posterior_log_pdf = self.posterior_log_pdf - self.posterior_delta_log_pdf

        (self.transformed_pre_posterior_latent_code,
         self.transformed_pre_posterior_log_pdf) = self.flow_object.transform(
             self.pre_posterior_latent_code, self.pre_posterior_log_pdf)

        self.reconst_param = {
            'flat': None,
            'image': tf.reshape(self.transformed_pre_posterior_latent_code,
                                [-1, 1, *image_tail_shape]),
        }
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        half_batch = self.batch_size_tf // 2
        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=half_batch)
        # The interpolated codes are not decoded here; interpolated_obs is the
        # first half of the input images tiled 10 times along axis 1.
        self.interpolated_obs = {
            'flat': None,
            'image': tf.tile(
                self.input_sample['image'][:half_batch, :, :, :, :],
                [1, 10, 1, 1, 1]),
        }

        self.enc_reg_cost = -tf.reduce_mean(
            self.transformed_pre_posterior_log_pdf)
        self.cri_reg_cost = -tf.reduce_mean(self.pre_posterior_log_pdf)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        if '0' in self.config['timers']:
            # Scale the regulariser by a hard step of the epoch, driven by
            # the start/timescale of timer '0'.
            timer = self.config['timers']['0']
            lambda_t = helper.hardstep(
                (self.epoch - float(timer['start'])) /
                float(timer['timescale']) + 1e-5)
            overall_cost = self.mean_OT_primal + lambda_t * self.config[
                'enc_reg_strength'] * self.enc_reg_cost
        else:
            overall_cost = self.mean_OT_primal + self.config[
                'enc_reg_strength'] * self.enc_reg_cost

        # self.cri_cost = self.cri_reg_cost
        # self.cri_cost = self.config['enc_reg_strength']*self.enc_reg_cost
        self.cri_cost = self.enc_reg_cost
        self.enc_cost = overall_cost
        self.gen_cost = overall_cost
Exemple #2
0
    def inference(self, batch, additional_inputs_tf):
        """Build the generator, encoder, divergence terms and training costs.

        All results are stored as attributes on ``self``; nothing is returned.
        The method wires up, in order:

        * a generator path: pre-prior sample -> ``PriorExpandMap`` ->
          ``Generator`` -> observation sample,
        * the same path for a fixed set of 400 samples cached on disk under
          ``~/ExperimentalResults/FixedSamples/`` (plus a 20x20 latent grid
          when ``config['n_latent'] == 2``),
        * an encoder path: ``Encoder`` -> ``Generator`` -> reconstruction,
        * divergence terms from ``Diverger`` and the costs ``enc_cost`` and
          ``gen_cost``. No ``cri_cost`` is assigned by this method, and
          ``div_cost`` is only assigned for the 'GAN' / 'NS-GAN' modes.

        Args:
            batch: Nested dict. This method reads
                ``batch['observed']['properties']`` (with ``'flat'`` and
                ``'image'`` lists of dicts carrying ``'size'`` and ``'dist'``)
                and ``batch['observed']['data']`` (indexed with ``'flat'`` /
                ``'image'``). The ``'dist'`` entries are overwritten in place.
            additional_inputs_tf: Indexable; element 0 is stored as
                ``self.epoch`` and element 1 as ``self.b_identity``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Mark the observed modality as 'dirac'; 'flat' takes precedence over
        # 'image' when it is non-empty.
        if len(batch['observed']['properties']['flat']) > 0:
            for e in batch['observed']['properties']['flat']:
                e['dist'] = 'dirac'
        else:
            for e in batch['observed']['properties']['image']:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # Fall back to the 'image' entry when no usable 'flat' entry exists.
        try:
            self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except (KeyError, IndexError, TypeError):
            self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            # The exact exception raised by tf.shape on an unusable 'flat'
            # entry depends on the TensorFlow version, so this stays broad
            # (but no longer swallows KeyboardInterrupt/SystemExit).
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.pre_prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.pre_prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.pre_prior_param)

        self.pre_prior_latent_code = self.pre_prior_dist.sample()
        self.pre_prior_latent_code_expanded = self.pre_prior_latent_code[:, np.newaxis, :]
        self.neg_ent_prior = self.pre_prior_dist.log_pdf(
            self.pre_prior_latent_code)
        self.mean_neg_ent_prior = tf.reduce_mean(self.neg_ent_prior)

        self.prior_latent_code = self.PriorExpandMap.forward(
            self.pre_prior_latent_code_expanded)
        self.prior_latent_code_expanded = self.prior_latent_code[:, np.newaxis, :]

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code_expanded)
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior samples, cached on disk per latent dimensionality.
        pre_prior_dim = self.pre_prior_latent_code.get_shape().as_list()[-1]
        fixed_dir = str(Path.home()) + '/ExperimentalResults/FixedSamples/'
        fixed_path = (fixed_dir + 'np_constant_prior_sample_' +
                      str(pre_prior_dim) + '.npz')
        os.makedirs(fixed_dir, exist_ok=True)
        if os.path.exists(fixed_path):
            np_constant_prior_sample = np.load(fixed_path)
            if hasattr(np_constant_prior_sample, 'files'):
                # A genuine .npz archive was found: use its first array.
                with np_constant_prior_sample as archive:
                    np_constant_prior_sample = archive[archive.files[0]]
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, pre_prior_dim])
            # np.save appends '.npy' to a *string* path that lacks it, which
            # previously produced '<name>.npz.npy' and made the existence
            # check above fail on every run. Writing through a file handle
            # keeps the exact path; np.load detects the format from the file
            # contents, not the extension.
            with open(fixed_path, 'wb') as sample_file:
                np.save(sample_file, np_constant_prior_sample)

        self.constant_pre_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_pre_prior_latent_code_expanded = self.constant_pre_prior_latent_code[:, np.newaxis, :]

        self.constant_prior_latent_code = self.PriorExpandMap.forward(
            self.constant_pre_prior_latent_code_expanded)
        self.constant_prior_latent_code_expanded = self.constant_prior_latent_code[:, np.newaxis, :]

        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code_expanded)
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        if self.config['n_latent'] == 2:
            # 20x20 grid over [-3, 3]^2 (y runs from +3 down to -3), fed
            # straight to the Generator without going through PriorExpandMap.
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_prior_grid_latent_code_expanded = self.constant_prior_grid_latent_code[:, np.newaxis, :]

            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code_expanded)
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        # NOTE(review): the encoder noise parameters come from PriorMap here
        # — confirm that sharing it with the prior is intended.
        self.epsilon_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(
            params=self.epsilon_param)

        encoder_mode = self.config['encoder_mode']
        if encoder_mode == 'Deterministic':
            self.epsilon = None
        elif encoder_mode in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial',
                              'UnivApproxSine'):
            self.epsilon = self.epsilon_dist.sample()
        # Any other mode leaves self.epsilon untouched, as before.

        self.pre_posterior_latent_code_expanded, _ = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        self.pre_posterior_latent_code = self.pre_posterior_latent_code_expanded[:, 0, :]

        # Ball parameters: the encoder code with a constant 0.1 appended as
        # the last column.
        self.nball_param = tf.concat([
            self.pre_posterior_latent_code,
            0.1 * tf.ones(shape=(self.batch_size_tf, 1))
        ], axis=1)
        # Fix: the original passed self.tiny_perturb_param, which is never
        # assigned in this method, while self.nball_param built just above
        # was left unused.
        self.nball_dist = distributions.UniformBallDistribution(
            params=self.nball_param)
        self.posterior_latent_code = self.nball_dist.sample()
        self.posterior_latent_code_log_pdf = self.nball_dist.log_pdf(
            self.posterior_latent_code)
        # self.posterior_latent_code = self.pre_posterior_latent_code
        self.posterior_latent_code_expanded = self.posterior_latent_code[:, np.newaxis, :]

        # Interpolate between posterior codes, push every interpolated code
        # through PriorExpandMap one at a time, then restore the trailing two
        # dimensions before decoding.
        self.interpolated_pre_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        interp_static_shape = self.interpolated_pre_posterior_latent_code.get_shape().as_list()
        self.interpolated_posterior_latent_code_collapsed = self.PriorExpandMap.forward(
            tf.reshape(self.interpolated_pre_posterior_latent_code,
                       [-1, 1, interp_static_shape[-1]]))
        self.interpolated_posterior_latent_code = tf.reshape(
            self.interpolated_posterior_latent_code_collapsed,
            [-1, *interp_static_shape[-2:]])
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        # The reconstruction decodes the unperturbed encoder output.
        self.reconst_param = self.Generator.forward(
            self.pre_posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        self.enc_reg_cost = helper.compute_MMD(self.pre_prior_latent_code,
                                               self.pre_posterior_latent_code)
        self.div_posterior = self.Diverger.forward(
            self.pre_posterior_latent_code_expanded)
        self.div_prior = self.Diverger.forward(self.prior_latent_code_expanded)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################

        # OBJECTIVES
        # Divergence
        divergence_mode = self.config['divergence_mode']
        if divergence_mode == 'GAN' or divergence_mode == 'NS-GAN':
            # 10e-7 (= 1e-6) keeps the log arguments away from zero.
            self.div_cost = -(tf.reduce_mean(
                tf.log(tf.nn.sigmoid(self.div_posterior) + 10e-7)) +
                              tf.reduce_mean(
                                  tf.log(1 - tf.nn.sigmoid(self.div_prior) +
                                         10e-7)))
        if divergence_mode == 'WGAN-GP':
            # Gradient penalty evaluated on random convex combinations of
            # encoder codes and expanded prior codes.
            uniform_dist = distributions.UniformDistribution(
                params=tf.concat([
                    tf.zeros(shape=(self.batch_size_tf, 1)),
                    tf.ones(shape=(self.batch_size_tf, 1))
                ], axis=1))
            uniform_w = uniform_dist.sample()
            mix_weight = uniform_w[:, np.newaxis, :]
            self.trivial_line = (
                mix_weight * self.pre_posterior_latent_code_expanded +
                (1 - mix_weight) * self.prior_latent_code_expanded)
            self.div_trivial_line = self.Diverger.forward(self.trivial_line)
            self.trivial_line_grad = tf.gradients(
                tf.reduce_sum(self.div_trivial_line), [self.trivial_line])[0]
            self.trivial_line_grad_norm = helper.safe_tf_sqrt(
                tf.reduce_sum(self.trivial_line_grad**2,
                              axis=-1,
                              keep_dims=False)[:, :, np.newaxis])
            self.trivial_line_grad_norm_1_penalties = ((
                self.trivial_line_grad_norm - 1)**2)
            self.div_reg_cost = tf.reduce_mean(
                self.trivial_line_grad_norm_1_penalties)
            # self.div_cost = -(tf.reduce_mean(self.div_posterior)-tf.reduce_mean(self.div_prior))+10*self.div_reg_cost

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        # The MMD term (enc_reg_cost) is computed above but is not added to
        # the encoder cost here: the penalised cost equals the plain OT cost.
        self.mean_POT_primal = self.mean_OT_primal
        self.enc_cost = self.mean_POT_primal

        ### Generator
        self.gen_cost = self.mean_OT_primal
Exemple #3
0
    def inference(self, batch, additional_inputs_tf):
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]
        
        if len(batch['observed']['properties']['flat'])>0:
            for e in batch['observed']['properties']['flat']: e['dist']='dirac'
        else:
            for e in batch['observed']['properties']['image']: e['dist']='dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        try: self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except: self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try: self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except: self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR 

        self.pre_prior_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))      
        # self.pre_prior_dist = distributions.UniformDistribution(params = self.pre_prior_param)        
        self.pre_prior_dist = distributions.DiagonalBetaDistribution(params = self.pre_prior_param)        
        self.pre_prior_latent_code = self.pre_prior_dist.sample()
        self.pre_prior_latent_code_log_pdf = self.pre_prior_dist.log_pdf(self.pre_prior_latent_code)

        self.ambient_prior_param = self.PriorMapBetaInverted.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))      
        self.ambient_prior_dist = distributions.DiagonalBetaDistribution(params = self.ambient_prior_param)        

        self.flow_param_list = self.FlowMap.forward()
        # self.flow_object = transforms.SerialFlow([ \
        #                                           transforms.InverseOpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
        #                                           transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[0]), 
        #                                           transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
        #                                           transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[1]),
        #                                           transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
        #                                           transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[2]),
        #                                           transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
        #                                           transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[3]),
        #                                           # transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
        #                                           # transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[4]),
        #                                           # transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
        #                                           # transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[5]),
        #                                           # transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
        #                                           # transforms.NonLinearIARFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[6]),
        #                                           transforms.OpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
        #                                           ])
        
        # self.flow_object = transforms.SerialFlow([ \
        #                                           transforms.InverseOpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[0]), 
        #                                           transforms.HouseholdRotationFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[7]), 
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[1]),
        #                                           transforms.HouseholdRotationFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[8]), 
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[2]),
        #                                           transforms.HouseholdRotationFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[9]), 
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[3]),
        #                                           transforms.HouseholdRotationFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[10]), 
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[4]),
        #                                           transforms.HouseholdRotationFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[11]), 
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[5]),
        #                                           transforms.HouseholdRotationFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[12]), 
        #                                           transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[6]),
        #                                           transforms.OpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
        #                                           ])
        
        self.flow_object = transforms.SerialFlow([ \
                                                  transforms.InverseOpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
                                                  transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[0]), 
                                                  transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
                                                  transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[1]),
                                                  transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
                                                  transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[2]),
                                                  transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
                                                  transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[3]),
                                                  # transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
                                                  # transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[4]),
                                                  # transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
                                                  # transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[5]),
                                                  # transforms.SpecificOrderDimensionFlow(input_dim=self.config['n_latent']), 
                                                  # transforms.RealNVPFlow(input_dim=self.config['n_latent'], parameters=self.flow_param_list[6]),
                                                  transforms.OpenIntervalDimensionFlow(input_dim=self.config['n_latent']),
                                                  ])
        self.gen_flow_object = self.flow_object
        self.prior_latent_code, self.prior_latent_code_log_pdf = self.flow_object.inverse_transform(self.pre_prior_latent_code, self.pre_prior_latent_code_log_pdf)

        self.obs_sample_param = self.Generator.forward(self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)        

        if not os.path.exists(str(Path.home())+'/ExperimentalResults/FixedSamples/'): os.makedirs(str(Path.home())+'/ExperimentalResults/FixedSamples/')
        # if os.path.exists(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_uniform11_prior_sample_'+str(self.pre_prior_latent_code.get_shape().as_list()[-1])+'.npz'): 
        if os.path.exists(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_uniform01_prior_sample_'+str(self.pre_prior_latent_code.get_shape().as_list()[-1])+'.npz'): 
            # np_constant_prior_sample = np.load(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_uniform11_prior_sample_'+str(self.pre_prior_latent_code.get_shape().as_list()[-1])+'.npz')
            np_constant_prior_sample = np.load(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_uniform01_prior_sample_'+str(self.pre_prior_latent_code.get_shape().as_list()[-1])+'.npz')
        else:
            # np_constant_prior_sample = np.random.uniform(low=-1., high=1., size=[400, self.pre_prior_latent_code.get_shape().as_list()[-1]])
            np_constant_prior_sample = np.random.uniform(low=0., high=1., size=[400, self.pre_prior_latent_code.get_shape().as_list()[-1]])
            # np.save(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_uniform11_prior_sample_'+str(self.pre_prior_latent_code.get_shape().as_list()[-1])+'.npz', np_constant_prior_sample)    
            np.save(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_uniform01_prior_sample_'+str(self.pre_prior_latent_code.get_shape().as_list()[-1])+'.npz', np_constant_prior_sample)    

        self.constant_pre_prior_latent_code = tf.constant(np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_prior_latent_code, _ = self.flow_object.inverse_transform(self.constant_pre_prior_latent_code, tf.zeros(shape=(self.batch_size_tf, 1)))

        self.constant_obs_sample_param = self.Generator.forward(self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)
    
        if self.config['n_latent'] == 2: 
            # x = np.linspace(-1, 1, 20)
            # y = np.linspace(1, -1, 20)
            x = np.linspace(0, 1, 20)
            y = np.linspace(1, 0, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate((xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]), axis=1)        
            self.constant_prior_grid_latent_code = tf.constant(np.asarray(np_constant_prior_grid_sample), dtype=np.float32)

            self.constant_obs_grid_sample_param = self.Generator.forward(self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(b_mode=True)
                
        #############################################################################
        # ENCODER 
        self.epsilon_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.epsilon_dist = distributions.DiagonalGaussianDistribution(params=self.epsilon_param) 
        
        if self.config['encoder_mode'] == 'Deterministic': 
            self.epsilon = None
        if self.config['encoder_mode'] == 'Gaussian' or self.config['encoder_mode'] == 'UnivApprox' or self.config['encoder_mode'] == 'UnivApproxNoSpatial' or self.config['encoder_mode'] == 'UnivApproxSine':         
            self.epsilon = self.epsilon_dist.sample()

        # self.pre_posterior_latent_code = 0.95*tf.nn.tanh(self.Encoder.forward(self.input_sample, noise=self.epsilon))[:,0,:]
        # self.pre_posterior_latent_code = 0.9*tf.nn.sigmoid(self.Encoder.forward(self.input_sample, noise=self.epsilon))[:,0,:]+0.05
        self.pre_posterior_latent_code = 0.8*tf.nn.sigmoid(self.Encoder.forward(self.input_sample, noise=self.epsilon))[:,0,:]+0.1
        self.nball_param = tf.concat([self.pre_posterior_latent_code, 0.05*tf.ones(shape=(self.batch_size_tf, 1))], axis=1)
        self.nball_dist = distributions.UniformBallDistribution(params=self.nball_param) 
        self.posterior_latent_code = self.nball_dist.sample()
        self.posterior_latent_code_log_pdf = -np.log(50000)+self.nball_dist.log_pdf(self.posterior_latent_code)
        self.transformed_posterior_latent_code, self.transformed_posterior_latent_code_log_pdf = self.flow_object.transform(self.posterior_latent_code, self.posterior_latent_code_log_pdf)

        self.hollow_nball_param = tf.concat([self.pre_posterior_latent_code, 0.05*tf.ones(shape=(self.batch_size_tf, 1)), 0.1*tf.ones(shape=(self.batch_size_tf, 1))], axis=1)
        self.hollow_nball_dist = distributions.UniformHollowBallDistribution(params=self.hollow_nball_param) 
        self.hollow_posterior_latent_code = self.hollow_nball_dist.sample()
        self.hollow_posterior_latent_code_log_pdf = -np.log(50000)+self.hollow_nball_dist.log_pdf(self.hollow_posterior_latent_code)
        self.transformed_hollow_posterior_latent_code, self.transformed_hollow_posterior_latent_code_log_pdf = self.flow_object.transform(self.hollow_posterior_latent_code, self.hollow_posterior_latent_code_log_pdf)

        self.ambient_param = self.AmbientMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.ambient_dist = distributions.UniformDistribution(params = self.ambient_param)  
        self.ambient_latent_code = self.ambient_dist.sample()
        self.ambient_latent_code_log_pdf = self.ambient_dist.log_pdf(self.ambient_latent_code)
        self.transformed_ambient_latent_code, self.transformed_ambient_latent_code_log_pdf = self.flow_object.transform(self.ambient_latent_code, self.ambient_latent_code_log_pdf)

        self.KL_transformed_prior_per = self.transformed_posterior_latent_code_log_pdf-self.pre_prior_dist.log_pdf(self.transformed_posterior_latent_code)
        self.KL_transformed_hollow_prior_per = self.transformed_hollow_posterior_latent_code_log_pdf-self.ambient_prior_dist.log_pdf(self.transformed_hollow_posterior_latent_code)
        self.KL_transformed_ambient_prior_per = self.transformed_ambient_latent_code_log_pdf-self.ambient_prior_dist.log_pdf(self.transformed_ambient_latent_code)
        self.KL_transformed_prior = tf.reduce_mean(self.KL_transformed_prior_per)
        self.KL_transformed_hollow_prior = tf.reduce_mean(self.KL_transformed_hollow_prior_per)
        self.KL_transformed_ambient_prior = tf.reduce_mean(self.KL_transformed_ambient_prior_per)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        self.interpolated_obs = self.Generator.forward(self.interpolated_posterior_latent_code) 

        self.reconst_param = self.Generator.forward(self.posterior_latent_code[:, np.newaxis, :]) 
        self.reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        ### Primal Penalty
        
        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample
        
        #############################################################################

        # OBJECTIVES
        # Divergence
        # if self.config['divergence_mode'] == 'GAN' or self.config['divergence_mode'] == 'NS-GAN':
        #     self.div_cost = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_posterior)+10e-7))+tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7)))
        # if self.config['divergence_mode'] == 'WGAN-GP':
        #     uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        #     uniform_w = uniform_dist.sample()
        #     self.trivial_line = uniform_w[:,np.newaxis,:]*self.pre_posterior_latent_code_expanded+(1-uniform_w[:,np.newaxis,:])*self.prior_latent_code_expanded
        #     self.div_trivial_line = self.Diverger.forward(self.trivial_line)
        #     self.trivial_line_grad = tf.gradients(tf.reduce_sum(self.div_trivial_line), [self.trivial_line])[0]
        #     self.trivial_line_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_line_grad**2, axis=-1, keep_dims=False)[:,:,np.newaxis])
        #     self.trivial_line_grad_norm_1_penalties = ((self.trivial_line_grad_norm-1)**2)
        #     self.div_reg_cost = tf.reduce_mean(self.trivial_line_grad_norm_1_penalties)
        #     # self.div_cost = -(tf.reduce_mean(self.div_posterior)-tf.reduce_mean(self.div_prior))+10*self.div_reg_cost
        # self.div_cost = (self.KL_transformed_prior+self.KL_transformed_ambient_prior+self.KL_transformed_hollow_prior)
        self.div_cost = (self.KL_transformed_prior+self.KL_transformed_hollow_prior)
       
        # ### Encoder
        # b_use_timer, timescale, starttime = False, 10, 5
        self.OT_primal = self.sample_distance_function(self.input_sample, self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        # if b_use_timer:
        #     self.mean_POT_primal = self.mean_OT_primal+helper.hardstep((self.epoch-float(starttime))/float(timescale))*self.config['enc_reg_strength']*self.enc_reg_cost
        # else:
        #     self.mean_POT_primal = self.mean_OT_primal+self.config['enc_reg_strength']*self.enc_reg_cost
        # self.enc_cost = self.mean_POT_primal
        self.enc_cost = self.mean_OT_primal

        # ### Critic
        # # self.cri_cost = helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # if self.config['divergence_mode'] == 'NS-GAN': 
        #     self.cri_cost = -tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'GAN': 
        #     self.cri_cost = tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+10e-7))+self.config['enc_reg_strength']*helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)
        # elif self.config['divergence_mode'] == 'WGAN-GP': 
        #     self.cri_cost = -tf.reduce_mean(self.div_prior)+self.config['enc_reg_strength']*helper.compute_MMD(self.pre_prior_latent_code, self.prior_latent_code)

        ### Generator
        self.gen_cost = self.mean_OT_primal
Exemple #4
0
    def inference(self, batch, additional_inputs_tf):
        """Build the sampling, reconstruction and cost tensors of the model.

        Everything is stored on ``self``: prior samples decoded by the
        generator (a random draw, a fixed draw cached on disk, and -- when
        ``config['n_latent'] == 2`` -- a regular 20x20 grid), the encoder
        posterior code with its interpolation and reconstruction, and the
        encoder / generator costs (both the mean of
        ``sample_distance_function(input, reconstruction)``).

        Args:
            batch: nested dict. This method reads ``batch['observed']['data']``
                and ``batch['observed']['properties']`` (via their ``'flat'``
                and ``'image'`` entries). The ``'dist'`` field of the property
                entries is overwritten in place with ``'dirac'``.
            additional_inputs_tf: indexable; items 0 and 1 are stored as
                ``self.epoch`` and ``self.b_identity``.

        Returns:
            None.

        Side effects:
            Creates ``~/ExperimentalResults/FixedSamples/`` if needed and,
            when no cached sample exists for the current latent size, writes
            ``np_constant_prior_sample_<n>.npz`` into it.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        properties = batch['observed']['properties']

        # Use the 'flat' entries when there are any, otherwise the 'image'
        # entries, and force their distribution type to 'dirac'.
        if len(properties['flat']) > 0:
            active_properties = properties['flat']
        else:
            active_properties = properties['image']
        for entry in active_properties:
            entry['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=properties, params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # Re-read after generate_modules(batch) so later uses see whatever
        # object is stored in the batch at that point.
        properties = batch['observed']['properties']

        # Prefer the 'flat' entry; fall back to 'image' when it is missing
        # or empty.
        try:
            self.n_time = properties['flat'][0]['size'][1]
        except (KeyError, IndexError, TypeError):
            self.n_time = properties['image'][0]['size'][1]

        # The exception raised for an unusable 'flat' input depends on the
        # TensorFlow version, so this stays broad (but no longer swallows
        # KeyboardInterrupt / SystemExit like a bare ``except:`` would).
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()
        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=properties, params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)

        # Fixed prior sample: loaded from disk when cached, otherwise drawn
        # once and cached, keyed by the static size of the last latent axis.
        n_latent_dim = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_sample_dir = (
            str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        fixed_sample_path = (fixed_sample_dir + 'np_constant_prior_sample_' +
                             str(n_latent_dim) + '.npz')

        os.makedirs(fixed_sample_dir, exist_ok=True)
        if os.path.exists(fixed_sample_path):
            # np.load picks the format from the file content, so the '.npz'
            # suffix on a plain array file is harmless here.
            np_constant_prior_sample = np.load(fixed_sample_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, n_latent_dim])
            # np.save() appends '.npy' to any *path* lacking that suffix,
            # which would write '<...>.npz.npy' and make the existence check
            # above fail forever. Saving through an open file handle keeps
            # the exact file name.
            with open(fixed_sample_path, 'wb') as sample_file:
                np.save(sample_file, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=properties,
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(
            b_mode=True)

        if self.config['n_latent'] == 2:
            # 20x20 grid over [-grid_scale, grid_scale]^2; y runs from
            # +grid_scale down to -grid_scale.
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.column_stack(
                (xv.flatten(), yv.flatten()))

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=properties,
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(
                b_mode=True)

        #############################################################################
        # ENCODER
        self.uniform_dist = distributions.UniformDistribution(
            params=tf.concat([
                tf.zeros(shape=(self.batch_size_tf, 1)),
                tf.ones(shape=(self.batch_size_tf, 1))
            ],
                             axis=1))

        encoder_mode = self.config['encoder_mode']
        if encoder_mode == 'Deterministic':
            # NOTE(review): debugger trap kept from the original. This method
            # does not assign self.epsilon for 'Deterministic', so the encoder
            # call below relies on it being set elsewhere -- confirm the
            # intended handling before replacing this with an exception.
            pdb.set_trace()
        if encoder_mode in ('Gaussian', 'UnivApprox', 'UnivApproxNoSpatial',
                            'UnivApproxSine'):
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            # self.epsilon_dist = distributions.BernoulliDistribution(params = self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()

        # The encoder returns two tensors; index 0 of axis 1 is taken from
        # each to obtain the codes used below.
        (self.posterior_latent_code_expanded,
         self.posterior_latent_code_det_expanded) = self.Encoder.forward(
             self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:, 0, :]
        self.posterior_latent_code_det = self.posterior_latent_code_det_expanded[:, 0, :]

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code[:, np.newaxis, :])
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=properties, params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        #############################################################################
        # REGULARIZER
        # Both the regularizer target and the regularized quantity are the
        # reconstruction.

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        self.enc_cost = self.mean_OT_primal

        ### Generator
        self.gen_cost = self.mean_OT_primal
Exemple #5
0
    def inference(self, batch, additional_inputs_tf):
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]
        
        if len(batch['observed']['properties']['flat'])>0:
            for e in batch['observed']['properties']['flat']: e['dist']='dirac'
        else:
            for e in batch['observed']['properties']['image']: e['dist']='dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.input_sample)

        if not self.bModules: self.generate_modules(batch)
        try: self.n_time = batch['observed']['properties']['flat'][0]['size'][1]
        except: self.n_time = batch['observed']['properties']['image'][0]['size'][1]
        try: self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except: self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR 

        self.prior_param = self.PriorMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
        self.prior_dist = distributions.DiagonalGaussianDistribution(params = self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()        
        self.obs_sample_param = self.Generator.forward(self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample(b_mode=True)
        
        if not os.path.exists(str(Path.home())+'/ExperimentalResults/FixedSamples/'): os.makedirs(str(Path.home())+'/ExperimentalResults/FixedSamples/')
        if os.path.exists(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_prior_sample_'+str(self.prior_latent_code.get_shape().as_list()[-1])+'.npz'): 
            np_constant_prior_sample = np.load(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_prior_sample_'+str(self.prior_latent_code.get_shape().as_list()[-1])+'.npz')
        else:
            np_constant_prior_sample = np.random.normal(loc=0., scale=1., size=[400, self.prior_latent_code.get_shape().as_list()[-1]])
            np.save(str(Path.home())+'/ExperimentalResults/FixedSamples/np_constant_prior_sample_'+str(self.prior_latent_code.get_shape().as_list()[-1])+'.npz', np_constant_prior_sample)    
        
        self.constant_prior_latent_code = tf.constant(np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample(b_mode=True)
                
        if self.config['n_latent'] == 2: 
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate((xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis][:]), axis=1)
        
            self.constant_prior_grid_latent_code = tf.constant(np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = self.constant_obs_grid_sample_dist.sample(b_mode=True)
                
        #############################################################################
        # ENCODER 
        self.uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        self.uniform_sample = self.uniform_dist.sample()

        if self.config['encoder_mode'] == 'Deterministic': 
            pdb.set_trace()
        if self.config['encoder_mode'] == 'Gaussian' or self.config['encoder_mode'] == 'UnivApprox' or self.config['encoder_mode'] == 'UnivApproxNoSpatial' or self.config['encoder_mode'] == 'UnivApproxSine': 
            self.epsilon_param = self.EpsilonMap.forward((tf.zeros(shape=(self.batch_size_tf, 1)),))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(params = self.epsilon_param)        
            # self.epsilon_dist = distributions.BernoulliDistribution(params = self.epsilon_param)        
            self.epsilon = self.epsilon_dist.sample()
            

        self.posterior_latent_code_expanded, self.posterior_latent_code_det_expanded = self.Encoder.forward(self.input_sample, noise=self.epsilon)
        self.posterior_latent_code = self.posterior_latent_code_expanded[:,0,:]
        self.posterior_latent_code_det = self.posterior_latent_code_det_expanded[:,0,:]

        # self.flow_param_list = self.FlowMap.forward()
        # self.flow_object = transforms.SerialFlow([\
        #                                           transforms.NonLinearIARFlow(input_dim=2*self.config['n_latent'], parameters=self.flow_param_list[0]), 
        #                                           transforms.CustomSpecificOrderDimensionFlow(input_dim=2*self.config['n_latent']), 
        #                                           transforms.NonLinearIARFlow(input_dim=2*self.config['n_latent'], parameters=self.flow_param_list[1]),
        #                                           # transforms.CustomSpecificOrderDimensionFlow(input_dim=2*self.config['n_latent']), 
        #                                           # transforms.NonLinearIARFlow(input_dim=2*self.config['n_latent'], parameters=self.flow_param_list[2]),
        #                                           # transforms.CustomSpecificOrderDimensionFlow(input_dim=2*self.config['n_latent']), 
        #                                           # transforms.NonLinearIARFlow(input_dim=2*self.config['n_latent'], parameters=self.flow_param_list[3]),
        #                                           ])
        
        # self.transformed_posterior_latent_code, _ = self.flow_object.transform(tf.concat([self.posterior_latent_code_det, self.epsilon], axis=1), tf.zeros(shape=(self.batch_size_tf, 1)))
        # self.posterior_latent_code = self.transformed_posterior_latent_code[:, self.config['n_latent']:]

        self.info_param = self.InfoMap.forward(self.posterior_latent_code)
        self.info_dist = distributions.DiagonalGaussianDistribution(params = self.info_param)
        # self.info_dist = distributions.BernoulliDistribution(params = self.info_param)
        self.epsilon_info_log_pdf = self.info_dist.log_pdf(self.epsilon)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(self.posterior_latent_code, size=self.batch_size_tf//2)
        self.interpolated_obs = self.Generator.forward(self.interpolated_posterior_latent_code) 

        self.reconst_param = self.Generator.forward(self.posterior_latent_code[:, np.newaxis, :]) 
        self.reconst_dist = distributions.ProductDistribution(sample_properties = batch['observed']['properties'], params = self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)

        self.MMD = helper.compute_MMD(self.posterior_latent_code, self.prior_dist.sample())
        self.info_cost = -tf.reduce_mean(self.epsilon_info_log_pdf)

        if self.config['divergence_mode'] == 'GAN' or self.config['divergence_mode'] == 'NS-GAN' or self.config['divergence_mode'] == 'WGAN-GP':
            self.div_posterior = self.Diverger.forward(self.posterior_latent_code[:, np.newaxis, :])        
            self.div_prior = self.Diverger.forward(self.prior_latent_code[:, np.newaxis, :])
            
            # self.mean_z_divergence = tf.reduce_mean(tf.log(1e-7+self.div_prior))+tf.reduce_mean(tf.log(1e-7+1-self.div_posterior))
            # if self.config['divergence_mode'] == 'NS-GAN': 
            #     self.enc_reg_cost = -tf.reduce_mean(tf.log(1e-7+self.div_posterior))
            # elif self.config['divergence_mode'] == 'GAN': 
            #     self.enc_reg_cost = tf.reduce_mean(tf.log(1e-7+1-self.div_posterior))

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample
        
        #############################################################################
        # OBJECTIVES

        # ## Divergence
        # if self.config['divergence_mode'] == 'GAN' or self.config['divergence_mode'] == 'NS-GAN':
        #     if b_use_timer:
        #         self.div_cost = helper.hardstep((self.epoch-float(starttime))/float(timescale)+0.0001)*(self.config['enc_reg_strength']*(-self.mean_z_divergence))
        #     else:
        #         self.div_cost = self.config['enc_reg_strength']*(-self.mean_z_divergence)

        # ### Encoder
        # b_use_timer, timescale, starttime = True, 10, 10
        # self.OT_primal = self.sample_distance_function(self.input_sample, self.reconst_sample)
        # self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        # if b_use_timer:
        #     self.enc_overall_cost = helper.hardstep((self.epoch-float(starttime))/float(timescale)+0.0001)*(self.config['enc_reg_strength']*self.enc_reg_cost+self.config['enc_inv_MMD_strength']*self.cri_reg_cost)
        # else:
        #     self.enc_overall_cost = (self.config['enc_reg_strength']*self.enc_reg_cost+self.config['enc_inv_MMD_strength']*self.cri_reg_cost)
        # self.enc_cost = self.mean_OT_primal + self.enc_overall_cost
        
        # ### Critic
        # self.cri_cost = self.config['enc_inv_MMD_strength']*self.cri_reg_cost

        # ### Generator
        # self.gen_cost = self.mean_OT_primal

        # b_use_timer, timescale, starttime = True, 5, 5 # MNIST
        b_use_timer, timescale, starttime = True, 100, 5
        # # ## Divergence
        # if self.config['divergence_mode'] == 'GAN' or self.config['divergence_mode'] == 'NS-GAN':
        #     self.div_cost = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_posterior)+1e-7))+tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+1e-7)))
        # if self.config['divergence_mode'] == 'WGAN-GP':
        #     uniform_dist = distributions.UniformDistribution(params = tf.concat([tf.zeros(shape=(self.batch_size_tf, 1)), tf.ones(shape=(self.batch_size_tf, 1))], axis=1))
        #     uniform_w = uniform_dist.sample()
        #     self.trivial_line = uniform_w[:,np.newaxis,:]*self.posterior_latent_code[:,np.newaxis,:]+(1-uniform_w[:,np.newaxis,:])*self.prior_latent_code[:,np.newaxis,:]
        #     self.div_trivial_line = self.Diverger.forward(self.trivial_line)
        #     self.trivial_line_grad = tf.gradients(tf.reduce_sum(self.div_trivial_line), [self.trivial_line])[0]
        #     self.trivial_line_grad_norm = helper.safe_tf_sqrt(tf.reduce_sum(self.trivial_line_grad**2, axis=-1, keep_dims=False)[:,:,np.newaxis])
        #     self.trivial_line_grad_norm_1_penalties = ((self.trivial_line_grad_norm-1)**2)
        #     self.div_reg_cost = tf.reduce_mean(self.trivial_line_grad_norm_1_penalties)
        #     self.div_cost = -(tf.reduce_mean(self.div_posterior)-tf.reduce_mean(self.div_prior))+10*self.div_reg_cost

        # if self.config['divergence_mode'] == 'NS-GAN': 
        #     self.enc_reg_cost = -tf.reduce_mean(tf.log(tf.nn.sigmoid(self.div_prior)+1e-7))+1*self.MMD
        # elif self.config['divergence_mode'] == 'GAN': 
        #     self.enc_reg_cost = tf.reduce_mean(tf.log(1-tf.nn.sigmoid(self.div_prior)+1e-7))+1*self.MMD
        # elif self.config['divergence_mode'] == 'WGAN-GP': 
        #     self.enc_reg_cost = -tf.reduce_mean(self.div_prior)+1*self.MMD
        
        self.enc_reg_cost = self.MMD
        self.cri_reg_cost = self.info_cost

        ### Critic
        # self.cri_cost = self.config['enc_inv_MMD_strength']*self.cri_reg_cost
        self.cri_cost = self.cri_reg_cost


        ### Encoder
        self.OT_primal = self.sample_distance_function(self.input_sample, self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)
        if b_use_timer:
            self.enc_overall_cost = helper.hardstep((self.epoch-float(starttime))/float(timescale)+0.0001)*(self.config['enc_reg_strength']*self.enc_reg_cost + self.config['enc_inv_MMD_strength']*self.info_cost)
        else:
            self.enc_overall_cost = self.config['enc_reg_strength']*self.enc_reg_cost + self.config['enc_inv_MMD_strength']*self.info_cost
        self.enc_cost = self.mean_OT_primal + self.enc_overall_cost 

        # if b_use_timer:
        #     self.enc_overall_cost = helper.hardstep((self.epoch-float(starttime))/float(timescale)+0.0001)*(self.config['enc_reg_strength']*self.enc_reg_cost)+self.config['enc_inv_MMD_strength']*self.cri_reg_cost
        # else:
        #     self.enc_overall_cost = (self.config['enc_reg_strength']*self.enc_reg_cost+self.config['enc_inv_MMD_strength']*self.cri_reg_cost)
        # self.enc_cost = self.mean_OT_primal + self.enc_overall_cost
        
        ### Critic
        # self.cri_cost = self.config['enc_inv_MMD_strength']*self.cri_reg_cost

        ### Generator
        self.gen_cost = self.mean_OT_primal
Exemple #6
0
    def inference(self, batch, additional_inputs_tf):
        """Build the generator, encoder and objective tensors for a batch.

        All results are stored as attributes on ``self`` (for example
        ``obs_sample``, ``reconst_sample``, ``enc_cost`` and ``gen_cost``).

        Args:
            batch: Nested mapping. This method reads
                ``batch['observed']['properties']`` (with ``'flat'`` and
                ``'image'`` lists of property dicts) and
                ``batch['observed']['data']``.
            additional_inputs_tf: Indexable; element 0 is stored as
                ``self.epoch`` and element 1 as ``self.b_identity``.
        """
        self.epoch = additional_inputs_tf[0]
        self.b_identity = additional_inputs_tf[1]

        # Mark every observed modality as 'dirac' on a deep copy so the
        # caller's property dicts are left untouched.
        empirical_observed_properties = copy.deepcopy(
            batch['observed']['properties'])
        for modality in ('flat', 'image'):
            for e in empirical_observed_properties[modality]:
                e['dist'] = 'dirac'

        self.input_sample = batch['observed']['data']
        self.input_dist = distributions.ProductDistribution(
            sample_properties=empirical_observed_properties,
            params=self.input_sample)

        if not self.bModules:
            self.generate_modules(batch)

        # Prefer the 'flat' modality and fall back to 'image' when the flat
        # entry is missing or empty.
        observed_properties = batch['observed']['properties']
        try:
            self.n_time = observed_properties['flat'][0]['size'][1]
        except (KeyError, IndexError, TypeError):
            self.n_time = observed_properties['image'][0]['size'][1]
        try:
            self.batch_size_tf = tf.shape(self.input_sample['flat'])[0]
        except Exception:
            # NOTE(review): the exact exception raised for an absent flat
            # input depends on the TF version, so this stays broad.
            self.batch_size_tf = tf.shape(self.input_sample['image'])[0]

        #############################################################################
        # GENERATOR

        self.prior_param = self.PriorMap.forward(
            (tf.zeros(shape=(self.batch_size_tf, 1)), ))
        self.prior_dist = distributions.DiagonalGaussianDistribution(
            params=self.prior_param)
        self.prior_latent_code = self.prior_dist.sample()

        self.obs_sample_param = self.Generator.forward(
            self.prior_latent_code[:, np.newaxis, :])
        self.obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.obs_sample_param)
        self.obs_sample = self.obs_sample_dist.sample()

        self.obs_log_pdf = self.obs_sample_dist.log_pdf(self.input_sample)

        # Fixed latent samples, cached on disk so they can be reused across
        # runs with the same latent dimensionality.
        latent_dim = self.prior_latent_code.get_shape().as_list()[-1]
        fixed_samples_dir = (
            str(Path.home()) + '/ExperimentalResults/FixedSamples/')
        fixed_samples_path = (
            fixed_samples_dir + 'np_constant_prior_sample_' +
            str(latent_dim) + '.npz')
        os.makedirs(fixed_samples_dir, exist_ok=True)
        if os.path.exists(fixed_samples_path):
            np_constant_prior_sample = np.load(fixed_samples_path)
        else:
            np_constant_prior_sample = np.random.normal(
                loc=0., scale=1., size=[400, latent_dim])
            # Write through a file object: given a filename that does not
            # end in '.npy', np.save appends '.npy', so the file would never
            # be found by the existence check above and the samples would be
            # regenerated on every run. np.load detects the format from the
            # file contents, not from the extension.
            with open(fixed_samples_path, 'wb') as sample_file:
                np.save(sample_file, np_constant_prior_sample)

        self.constant_prior_latent_code = tf.constant(
            np.asarray(np_constant_prior_sample), dtype=np.float32)
        self.constant_obs_sample_param = self.Generator.forward(
            self.constant_prior_latent_code[:, np.newaxis, :])
        self.constant_obs_sample_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.constant_obs_sample_param)
        self.constant_obs_sample = self.constant_obs_sample_dist.sample()

        if self.config['n_latent'] == 2:
            # Regular 20x20 grid over [-grid_scale, grid_scale]^2; y runs
            # from +grid_scale down to -grid_scale.
            grid_scale = 3
            x = np.linspace(-grid_scale, grid_scale, 20)
            y = np.linspace(grid_scale, -grid_scale, 20)
            xv, yv = np.meshgrid(x, y)
            np_constant_prior_grid_sample = np.concatenate(
                (xv.flatten()[:, np.newaxis], yv.flatten()[:, np.newaxis]),
                axis=1)

            self.constant_prior_grid_latent_code = tf.constant(
                np.asarray(np_constant_prior_grid_sample), dtype=np.float32)
            self.constant_obs_grid_sample_param = self.Generator.forward(
                self.constant_prior_grid_latent_code[:, np.newaxis, :])
            self.constant_obs_grid_sample_dist = distributions.ProductDistribution(
                sample_properties=batch['observed']['properties'],
                params=self.constant_obs_grid_sample_param)
            self.constant_obs_grid_sample = (
                self.constant_obs_grid_sample_dist.sample())

        #############################################################################
        # ENCODER

        # Only the 'Gaussian' encoder mode receives a sampled noise input.
        if self.config['encoder_mode'] == 'Gaussian':
            self.epsilon_param = self.EpsilonMap.forward(
                (tf.zeros(shape=(self.batch_size_tf, 1)), ))
            self.epsilon_dist = distributions.DiagonalGaussianDistribution(
                params=self.epsilon_param)
            self.epsilon = self.epsilon_dist.sample()
        else:
            self.epsilon = None

        self.posterior_latent_code_expanded = self.Encoder.forward(
            self.input_sample, noise=self.epsilon)
        # Drop the middle axis by taking its first entry.
        self.posterior_latent_code = (
            self.posterior_latent_code_expanded[:, 0, :])

        self.reconst_param = self.Generator.forward(
            self.posterior_latent_code_expanded)
        self.reconst_dist = distributions.ProductDistribution(
            sample_properties=batch['observed']['properties'],
            params=self.reconst_param)
        self.reconst_sample = self.reconst_dist.sample(b_mode=True)
        self.reconst_log_pdf = self.reconst_dist.log_pdf(self.input_sample)

        self.interpolated_posterior_latent_code = helper.interpolate_latent_codes(
            self.posterior_latent_code, size=self.batch_size_tf // 2)
        self.interpolated_obs = self.Generator.forward(
            self.interpolated_posterior_latent_code)

        #############################################################################
        # REGULARIZER

        self.reg_target_dist = self.reconst_dist
        self.reg_target_sample = self.reconst_sample
        self.reg_dist = self.reconst_dist
        self.reg_sample = self.reconst_sample

        #############################################################################
        # OBJECTIVES

        ## Encoder
        self.mean_neg_log_pdf = -tf.reduce_mean(self.reconst_log_pdf)

        self.OT_primal = self.sample_distance_function(self.input_sample,
                                                       self.reconst_sample)
        self.mean_OT_primal = tf.reduce_mean(self.OT_primal)

        # Blend the two reconstruction losses with an epoch-dependent weight.
        # Presumably helper.hardstep ramps from 0 to 1, which would keep the
        # weight within [min_tradeoff, 1 - min_tradeoff] -- TODO confirm.
        timescale, start_time, min_tradeoff = 5, 10, 0.000001
        tradeoff = (1 - 2 * min_tradeoff) * helper.hardstep(
            (self.epoch - float(start_time)) / float(timescale)) + min_tradeoff
        overall_cost = tradeoff * self.mean_neg_log_pdf + (
            1 - tradeoff) * self.mean_OT_primal

        self.enc_cost = overall_cost

        ### Generator
        self.gen_cost = overall_cost