Пример #1
0
 def _init_distribution(conditions, **kwargs):
     """Create an exponential distribution whose rate is ``conditions["rate"]``.

     Any extra keyword arguments are passed through unchanged to the
     ``tfd.Exponential`` constructor.
     """
     return tfd.Exponential(rate=conditions["rate"], **kwargs)
Пример #2
0
    def __init__(self, data: DataHolder, options: Options) -> None:
        """Set up the model parameters and the MCMC kernels that update them.

        One ``KernelParameter`` is built for each model quantity:
        - the latent functions (``latents``);
        - the GP kernel hyperparameters (``kernel_params``, only when
          ``options.joint_latent`` is false);
        - the interaction weights (``weights``);
        - the kinetic parameters (``kinetics``);
        - the delays (``Δ``);
        - the noise variances ``σ2_m`` and ``σ2_f`` (``σ2_f`` only when
          ``options.preprocessing_variance`` is false).

        The parameters are bundled into ``self.params``. The subset that is
        actually sampled is listed in ``self.active_params``.

        ``self.state_indices`` maps each active parameter's name to its
        position in ``self.active_params``. The same dict object is handed to
        the kernels and read by the log-prob closures built here. It is only
        populated at the very end of this method, so it must be updated in
        place (never reassigned).

        The log-prob closures also read ``self.params``, which is assigned
        near the end of this method. They must not be invoked before
        construction finishes.

        Args:
            data: Observed data. This method reads ``data.τ``, ``data.t``,
                ``data.f_obs`` and ``data.m_obs``. The closures additionally
                read ``data.common_indices``.
            options: Model configuration. The following are read here or by
                the closures built here: ``initial_step_sizes``,
                ``joint_latent``, ``preprocessing_variance``,
                ``kinetic_exponential``, ``translation``, ``weights``,
                ``initial_conditions`` and ``delays``.
        """
        self.data = data
        self.samples = None
        # Length of the τ grid. It is the last axis of the latents state built
        # below (presumably the latent-function time grid; TODO confirm).
        self.N_p = data.τ.shape[0]
        self.N_m = data.t.shape[0]      # Number of observations

        self.num_tfs = data.f_obs.shape[1] # Number of TFs
        self.num_genes = data.m_obs.shape[1]  # axis 1 of m_obs
        self.num_replicates = data.m_obs.shape[0]  # axis 0 of m_obs

        self.likelihood = TranscriptionLikelihood(data, options)
        self.options = options
        self.kernel_selector = GPKernelSelector(data, options)

        # Filled in by the final `.update(...)` at the end of __init__. The
        # kernels and closures created below keep a reference to this same
        # dict object.
        self.state_indices = {}
        step_sizes = self.options.initial_step_sizes
        # Base step size for the parameters that get an `hmc_log_prob`. The
        # individual parameters below scale it: kernel_params x0.1,
        # weights x10, σ2_m and kinetics x1.
        logistic_step_size = step_sizes['nuts'] if 'nuts' in step_sizes else 0.00001


        # Latent function & GP hyperparameters
        # Only used below in joint_latent mode. kernel_params re-queries
        # initial_params() further down.
        kernel_initial = self.kernel_selector.initial_params()

        f_step_size = step_sizes['latents'] if 'latents' in step_sizes else 20
        # LatentKernel receives a step-size vector of length N_p.
        latents_kernel = LatentKernel(data, options, self.likelihood, 
                                      self.kernel_selector,
                                      self.state_indices,
                                      f_step_size*tf.ones(self.N_p, dtype='float64'))
        # Initial latent state has shape (num_replicates, num_tfs, N_p).
        latents_initial = 0.3*tf.ones((self.num_replicates, self.num_tfs, self.N_p), dtype='float64')
        if self.options.joint_latent:
            # Joint mode: the kernel hyperparameters are appended to the
            # latents state, so no separate kernel_params parameter is
            # created below.
            latents_initial = [latents_initial, *kernel_initial]
        latents = KernelParameter('latents', self.fbar_prior, latents_initial,
                                kernel=latents_kernel, requires_all_states=False)

        # White noise for genes
        if not options.preprocessing_variance:
            # σ2_m is updated by a GibbsKernel with an InverseGamma(0.01, 0.01)
            # prior. The kernel is given m_sq_diff_fn.
            def m_sq_diff_fn(all_states):
                # Squared residuals between m_obs and the predicted m,
                # summed over axis 0 (the replicate axis of m_obs).
                fbar, k_fbar, kbar, wbar, w_0bar, σ2_m, Δ = self.likelihood.get_parameters_from_state(all_states, self.state_indices)
                m_pred = self.likelihood.predict_m(kbar, k_fbar, wbar, fbar, w_0bar, Δ)
                # tf.transpose with no perm reverses all axes, so this
                # gathers m_pred's LAST axis at data.common_indices.
                sq_diff = tfm.square(self.data.m_obs - tf.transpose(tf.gather(tf.transpose(m_pred),self.data.common_indices)))
                return tf.reduce_sum(sq_diff, axis=0)

            σ2_m_kernel = GibbsKernel(data, options, self.likelihood, tfd.InverseGamma(f64(0.01), f64(0.01)), 
                                      self.state_indices, m_sq_diff_fn)
            σ2_m = KernelParameter('σ2_m', None, 1e-3*tf.ones((self.num_genes, 1), dtype='float64'), kernel=σ2_m_kernel)
        else:
            # σ2_m is sampled via hmc_log_prob. The state is stored on the
            # logistic scale (initialised with logistic(5e-3)), and the
            # LogisticNormal prior is evaluated on logit(state).
            def σ2_m_log_prob(all_states):
                # Returns a closure over all_states. The closure scores a
                # candidate σ2_mstar as gene likelihood plus prior.
                def σ2_m_log_prob_fn(σ2_mstar):
                    # tf.print('starr:', logit(σ2_mstar))
                    new_prob = self.likelihood.genes(
                        all_states=all_states, 
                        state_indices=self.state_indices,
                        σ2_m=σ2_mstar 
                    ) + self.params.σ2_m.prior.log_prob(logit(σ2_mstar))
                    # tf.print('prob', tf.reduce_sum(new_prob))
                    return tf.reduce_sum(new_prob)                
                return σ2_m_log_prob_fn
            σ2_m = KernelParameter('σ2_m', LogisticNormal(f64(1e-5), f64(1e-2)), # f64(max(np.var(data.f_obs, axis=1)))
                            logistic(f64(5e-3))*tf.ones(self.num_genes, dtype='float64'), 
                            hmc_log_prob=σ2_m_log_prob, requires_all_states=True, step_size=logistic_step_size)
        # Stays None in joint_latent mode (hyperparameters live in latents).
        kernel_params = None
        if not self.options.joint_latent:
            # GP kernel
            def kernel_params_log_prob(all_states):
                # NOTE: the inner function shadows the outer name. This works
                # because the outer function simply returns the inner one.
                def kernel_params_log_prob(param_0bar, param_1bar):
                    # Map both hyperparameters back from the logistic scale.
                    # `nan_replace` is given each prior's `.b` (presumably
                    # its upper bound, substituted for NaN; confirm in
                    # LogisticNormal/logit).
                    param_0 = logit(param_0bar, nan_replace=self.params.kernel_params.prior[0].b)
                    param_1 = logit(param_1bar, nan_replace=self.params.kernel_params.prior[1].b)
                    # The latents prior receives the logistic-scale (bar)
                    # values. The hyperpriors receive the logit values.
                    new_prob = tf.reduce_sum(self.params.latents.prior(
                                all_states[self.state_indices['latents']], param_0bar, param_1bar))
                    new_prob += self.params.kernel_params.prior[0].log_prob(param_0)
                    new_prob += self.params.kernel_params.prior[1].log_prob(param_1)
                    return tf.reduce_sum(new_prob)
                return kernel_params_log_prob

            kernel_initial = self.kernel_selector.initial_params()
            kernel_ranges = self.kernel_selector.ranges()
            kernel_params = KernelParameter('kernel_params', 
                        [LogisticNormal(*kernel_ranges[0]), LogisticNormal(*kernel_ranges[1])],
                        [logistic(k) for k in kernel_initial], 
                        step_size=0.1*logistic_step_size, hmc_log_prob=kernel_params_log_prob, requires_all_states=True)

        
        # Kinetic parameters & Interaction weights
        # w: (num_genes, num_tfs) interaction weights. w_0: one bias per gene.
        # Both are initialised on the logistic scale.
        w_prior = LogisticNormal(f64(-2), f64(2))
        w_initial = logistic(1*tf.ones((self.num_genes, self.num_tfs), dtype='float64'))
        w_0_prior = LogisticNormal(f64(-0.8), f64(0.8))
        w_0_initial = logistic(0*tf.ones(self.num_genes, dtype='float64'))
        def weights_log_prob(all_states):
            # NOTE(review): unlike the kinetics, σ2_m and kernel_params
            # priors, the weight priors are evaluated on wbar / w_0bar
            # directly rather than on logit(...). Confirm this is intended.
            def weights_log_prob_fn(wbar, w_0bar):
                # tf.print((wbar))
                new_prob = tf.reduce_sum(self.params.weights.prior[0].log_prob((wbar))) 
                new_prob += tf.reduce_sum(self.params.weights.prior[1].log_prob((w_0bar)))
                
                new_prob += tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    wbar=wbar,
                    w_0bar=w_0bar
                ))
                # tf.print(new_prob)
                return new_prob
            return weights_log_prob_fn
        # NOTE(review): weights_kernel is constructed but never used (see the
        # TODO below). `weights` is configured with hmc_log_prob instead.
        weights_kernel = RWMWrapperKernel(weights_log_prob, 
                new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=0.08))
        weights = KernelParameter(
            'weights', [w_prior, w_0_prior], [w_initial, w_0_initial],
            hmc_log_prob=weights_log_prob, step_size=10*logistic_step_size, requires_all_states=True)
            #TODO kernel=weights_kernel

        # Kinetic parameters per gene: 4 when options.initial_conditions is
        # set, otherwise 3.
        num_kin = 4 if self.options.initial_conditions else 3
        kbar_initial = 0.8*tf.ones((self.num_genes, num_kin), dtype='float64')

        def kbar_log_prob(all_states):
            # The positional args follow the order of kinetics_initial below:
            # kbar, then k_fbar if options.translation, then wbar and w_0bar
            # if options.weights.
            def kbar_log_prob_fn(*args): #kbar, k_fbar, wbar, w_0bar
                index = 0
                kbar = args[index]
                new_prob = 0
                # The prior is evaluated on logit(kbar), exponentiated first
                # when options.kinetic_exponential is set.
                k_m =logit(kbar)
                if self.options.kinetic_exponential:
                    k_m = tf.exp(k_m)
                # tf.print(k_m)
                lik_args = {'kbar': kbar}
                new_prob += tf.reduce_sum(self.params.kinetics.prior[index].log_prob(k_m))
                # tf.print('kbar', new_prob)
                if options.translation:
                    index += 1
                    k_fbar = args[index]
                    lik_args['k_fbar'] = k_fbar
                    kfprob = tf.reduce_sum(self.params.kinetics.prior[index].log_prob(logit(k_fbar)))
                    new_prob += kfprob
                if options.weights:
                    # The weight priors come from the separate `weights`
                    # parameter. kinetics.prior holds only the kbar (and
                    # k_fbar) priors.
                    index += 1
                    wbar = args[index]
                    w_0bar = args[index+1]
                    new_prob += tf.reduce_sum(self.params.weights.prior[0].log_prob((wbar))) 
                    new_prob += tf.reduce_sum(self.params.weights.prior[1].log_prob((w_0bar)))
                    lik_args['wbar'] = wbar
                    lik_args['w_0bar'] = w_0bar
                new_prob += tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    **lik_args
                ))
                return tf.reduce_sum(new_prob)
            return kbar_log_prob_fn


        # k_fbar: one value per TF.
        k_fbar_initial = 0.8*tf.ones((self.num_tfs,), dtype='float64')

        kinetics_initial = [kbar_initial]
        # NOTE(review): these priors use plain Python floats, whereas the
        # other priors in this method are built with f64(...). Confirm that
        # the dtype is intended.
        kinetics_priors = [LogisticNormal(0.01, 30)]
        if options.translation:
            kinetics_initial += [k_fbar_initial]
            kinetics_priors += [LogisticNormal(0.1, 7)]
        if options.weights:
            # The weights are sampled jointly with the kinetics. No priors are
            # appended because kbar_log_prob reads self.params.weights.prior.
            kinetics_initial += [w_initial, w_0_initial]
        kinetics = KernelParameter(
            'kinetics', 
            kinetics_priors, 
            kinetics_initial,
            hmc_log_prob=kbar_log_prob, step_size=logistic_step_size, requires_all_states=True)


        # Delays Δ: one per TF, updated by a DelayKernel. The 0 and 10
        # arguments are presumably the delay bounds; confirm against
        # DelayKernel. Δ is only sampled when options.delays is set (see
        # active_params below).
        delta_kernel = DelayKernel(self.likelihood, 0, 10, self.state_indices, tfd.Exponential(f64(0.3)))
        Δ = KernelParameter('Δ', tfd.InverseGamma(f64(0.01), f64(0.01)), 0.6*tf.ones(self.num_tfs, dtype='float64'),
                        kernel=delta_kernel, requires_all_states=False)
        
        # Stays None when options.preprocessing_variance is set.
        σ2_f = None
        if not options.preprocessing_variance:
            # Same Gibbs scheme as σ2_m, using the squared residuals between
            # f_obs and inverse_positivity(...) of the latents state.
            def f_sq_diff_fn(all_states):
                # NOTE(review): `[0]` is the latent tensor in joint_latent
                # mode (where the state is a list). Otherwise it indexes the
                # first axis of the latents tensor, unless the state is
                # wrapped in a list elsewhere. Confirm.
                f_pred = inverse_positivity(all_states[self.state_indices['latents']][0])
                sq_diff = tfm.square(self.data.f_obs - tf.transpose(tf.gather(tf.transpose(f_pred),self.data.common_indices)))
                return tf.reduce_sum(sq_diff, axis=0)
            kernel = GibbsKernel(data, options, self.likelihood, tfd.InverseGamma(f64(0.01), f64(0.01)), 
                                 self.state_indices, f_sq_diff_fn)
            σ2_f = KernelParameter('σ2_f', None, 1e-4*tf.ones((self.num_tfs,1), dtype='float64'), kernel=kernel)
        
        # The closures above read self.params, so they are only usable once
        # this assignment has run.
        self.params = Params(latents, weights, kinetics, Δ, kernel_params, σ2_m, σ2_f)
        
        # The list order defines state_indices below (presumably the order
        # of the sampler's state; confirm).
        self.active_params = [
            self.params.kinetics,
            self.params.latents,
            self.params.σ2_m,
        ]
        # `weights` is not listed separately. With options.weights, the
        # weights are sampled as part of `kinetics`.
        # if options.weights:
        #     self.active_params += [self.params.weights]
        if not options.joint_latent:
            self.active_params += [self.params.kernel_params]
        if not options.preprocessing_variance:
            self.active_params += [self.params.σ2_f]
        if options.delays:
            self.active_params += [self.params.Δ]

        # Populate the shared dict in place so that the kernels and closures,
        # which already hold a reference to it, see the indices.
        self.state_indices.update({
            param.name: i for i, param in enumerate(self.active_params)
        })
Пример #3
0
 def _base_dist(self, lam: TensorLike, *args, **kwargs):
     """Return an exponential base distribution with rate ``lam``.

     ``*args`` and ``**kwargs`` are accepted but ignored, presumably so that
     the signature matches a shared ``_base_dist`` interface.
     """
     base = tfd.Exponential(rate=lam)
     return base