def kernel_params_log_prob(param_0bar, param_1bar):
    """Summed log-probability used when sampling the two kernel parameters.

    Combines the latent prior ``self.params.latents.prior`` (evaluated at the
    current latent state with the candidate parameters) with each parameter's
    own prior from ``self.params.kernel_params.prior``.

    Args:
        param_0bar, param_1bar: candidate kernel parameters. They are passed
            unchanged to the latent prior; their ``logit``-mapped values are
            what the kernel-parameter priors score.

    Returns:
        Scalar tensor holding the total log-probability.
    """
    priors = self.params.kernel_params.prior
    # `nan_replace` presumably substitutes each prior's `b` for NaN outputs
    # of `logit` -- behaviour lives in `logit`, confirm there.
    param_0 = logit(param_0bar, nan_replace=priors[0].b)
    param_1 = logit(param_1bar, nan_replace=priors[1].b)
    current_latents = all_states[self.state_indices['latents']]
    log_density = tf.reduce_sum(
        self.params.latents.prior(current_latents, param_0bar, param_1bar))
    for prior, value in ((priors[0], param_0), (priors[1], param_1)):
        log_density += prior.log_prob(value)
    return tf.reduce_sum(log_density)
def normal_sampler_fn(seed):
    """Sample latent functions from a zero-mean multivariate normal.

    The covariance is produced by ``self.kernel_selector()`` at the current
    kernel parameters (both mapped through ``logit``). The mean returned by
    the kernel selector is discarded in favour of zeros.

    Args:
        seed: random seed forwarded to ``MultivariateNormalTriL.sample``.

    Returns:
        A sample with batch shape ``(num_replicates, num_tfs)`` and event
        size ``N_p`` (from the shape of the zero mean below).
    """
    p1, p2 = all_states[self.state_indices['kernel_params']]
    _, K = self.kernel_selector()(logit(p1), logit(p2))
    m = tf.zeros((self.num_replicates, self.num_tfs, self.N_p), dtype='float64')
    # One copy of the covariance per replicate, so the batch shape matches
    # `m`. This was previously hard-coded to 3, which broke (or silently
    # over-sampled) whenever num_replicates != 3.
    K = tf.stack([K] * self.num_replicates, axis=0)
    # Small diagonal jitter keeps the Cholesky factorisation numerically stable.
    jitter = tf.linalg.diag(1e-8 * tf.ones(self.N_p, dtype='float64'))
    z = tfd.MultivariateNormalTriL(
        loc=m, scale_tril=tf.linalg.cholesky(K + jitter)).sample(seed=seed)
    return z
def results(self, burnin=0):
    """Collect the recorded chains, minus burn-in, into a ``SampleResults``.

    Args:
        burnin: number of leading samples dropped from each chain that is
            sliced with ``[burnin:]`` below.

    Returns:
        ``SampleResults(self.options, fbar, kbar, k_fbar, Δ, kernel_params,
        wbar, w_0bar, σ2_m, σ2_f)``. ``k_fbar``, ``Δ`` and ``σ2_f`` stay
        ``None`` when translation, delays, or non-preprocessed variance
        (respectively) are not in use.
    """
    Δ = σ2_f = k_fbar = None
    σ2_m = self.samples[self.state_indices['σ2_m']][burnin:]
    if self.options.preprocessing_variance:
        # In this mode σ2_m is mapped through `logit` and no σ2_f chain is read.
        σ2_m = logit(σ2_m)
    else:
        σ2_f = self.samples[self.state_indices['σ2_f']][burnin:]
    # The 'kinetics' state is read positionally: kbar first, then k_fbar
    # (translation), then wbar and w_0bar (weights).
    nuts_index = 0
    kbar = self.samples[self.state_indices['kinetics']][nuts_index].numpy()[burnin:]
    fbar = self.samples[self.state_indices['latents']]
    if self.options.translation:
        nuts_index += 1
        k_fbar = self.samples[self.state_indices['kinetics']][nuts_index].numpy()[burnin:]
        if k_fbar.ndim < 3:
            # Add a trailing singleton axis so k_fbar is at least 3-D.
            k_fbar = np.expand_dims(k_fbar, 2)
    if not self.options.joint_latent:
        kernel_params = self.samples[self.state_indices['kernel_params']][burnin:]
        # NOTE(review): source indentation was lost; `fbar = fbar[0][burnin:]`
        # is reconstructed into the joint branch. On this path fbar is then
        # not burn-in trimmed, unlike every other chain, and the default
        # wbar/w_0bar below are sized from fbar.shape[0] -- confirm intent.
    else:
        # Joint mode: the latents state is indexed as [fbar, kernel param 0,
        # kernel param 1].
        kernel_params = [fbar[1][burnin:], fbar[2][burnin:]]
        fbar = fbar[0][burnin:]
    # Fixed defaults used when weights are not sampled: one constant copy
    # per retained sample.
    wbar = tf.stack([logistic(1*tf.ones((self.num_genes, self.num_tfs), dtype='float64'))
                     for _ in range(fbar.shape[0])], axis=0)
    w_0bar = tf.stack([0.5*tf.ones(self.num_genes, dtype='float64')
                       for _ in range(fbar.shape[0])], axis=0)
    if self.options.weights:
        nuts_index += 1
        wbar = self.samples[self.state_indices['kinetics']][nuts_index][burnin:]
        w_0bar = self.samples[self.state_indices['kinetics']][nuts_index+1][burnin:]
    if self.options.delays:
        Δ = self.samples[self.state_indices['Δ']][burnin:]
    return SampleResults(self.options, fbar, kbar, k_fbar, Δ, kernel_params,
                         wbar, w_0bar, σ2_m, σ2_f)
def kbar_log_prob_fn(*args):
    """Summed log-probability for the kinetic-parameter state block.

    ``args`` are positional state parts:
    - ``kbar`` always;
    - then ``k_fbar`` when ``options.translation``;
    - then ``wbar`` and ``w_0bar`` when ``options.weights``.

    Each part is scored under its prior, and the gene likelihood is
    evaluated with these parts passed as keyword overrides.

    Returns:
        Scalar tensor holding the total log-probability.
    """
    slot = 0
    kbar = args[slot]
    kinetic_priors = self.params.kinetics.prior
    k_m = logit(kbar)
    if self.options.kinetic_exponential:
        k_m = tf.exp(k_m)
    lik_args = {'kbar': kbar}
    log_density = 0
    log_density += tf.reduce_sum(kinetic_priors[slot].log_prob(k_m))

    # NOTE(review): the two checks below read the enclosing-scope `options`,
    # while the check above reads `self.options`; confirm they are the same.
    if options.translation:
        slot += 1
        k_fbar = args[slot]
        lik_args['k_fbar'] = k_fbar
        log_density += tf.reduce_sum(kinetic_priors[slot].log_prob(logit(k_fbar)))

    if options.weights:
        slot += 1
        wbar, w_0bar = args[slot], args[slot + 1]
        weight_priors = self.params.weights.prior
        log_density += tf.reduce_sum(weight_priors[0].log_prob(wbar))
        log_density += tf.reduce_sum(weight_priors[1].log_prob(w_0bar))
        lik_args['wbar'] = wbar
        lik_args['w_0bar'] = w_0bar

    log_density += tf.reduce_sum(self.likelihood.genes(
        all_states=all_states,
        state_indices=self.state_indices,
        **lik_args,
    ))
    return tf.reduce_sum(log_density)
def calculate_protein(self, fbar, k_fbar, Δ):
    """Compute protein levels p_i(t) on the grid ``self.data.τ``.

    Evaluates p_i(t) = exp(-δ_i t) * ∫_0^t exp(δ_i s) f_i(s) ds with the
    trapezoid rule. Only ``τ[1] - τ[0]`` is used as the step, so the grid is
    assumed to be uniformly spaced.

    Args:
        fbar: latent functions, mapped through ``inverse_positivity``.
            Indexed below as (replicate, tf, time).
        k_fbar: rates δ_i, mapped through ``logit`` and reshaped to (-1, 1).
        Δ: integer-castable delays, used only when ``self.options.delays``.

    Returns:
        Tensor of shape (num_replicates, num_tfs, len(τ)). The first time
        point is always 0.
    """
    τ = self.data.τ
    f_i = inverse_positivity(fbar)
    δ_i = tf.reshape(logit(k_fbar), (-1, 1))
    if self.options.delays:
        Δ = tf.cast(Δ, 'int32')
        # The mask that zeroes the first Δ time points after shifting does not
        # depend on the replicate, so build it once.
        keep = ~tf.sequence_mask(Δ, f_i.shape[2])
        # Shift each replicate and restack. The previous masked blend
        # `(1 - mask) * f_i + mask * f_ir` multiplied the other replicates'
        # values by 0, turning any inf into NaN across replicates.
        f_i = tf.stack(
            [tf.where(keep, rotate(f_i[r], -Δ), 0)
             for r in range(self.num_replicates)],
            axis=0)

    # Approximate the integral with the trapezoid rule.
    resolution = τ[1] - τ[0]
    sum_term = tfm.multiply(tfm.exp(δ_i * τ), f_i)
    cumsum = 0.5 * resolution * tfm.cumsum(
        sum_term[:, :, :-1] + sum_term[:, :, 1:], axis=2)
    integrals = tf.concat([
        tf.zeros((self.num_replicates, self.num_tfs, 1), dtype='float64'),
        cumsum,
    ], axis=2)
    exp_δt = tfm.exp(-δ_i * τ)
    p_i = exp_δt * integrals
    return p_i
def σ2_m_log_prob_fn(σ2_mstar):
    """Summed log-probability for a candidate mRNA noise variance.

    Adds the gene likelihood (with ``σ2_m`` overridden by the candidate) to
    the σ2_m prior evaluated at ``logit(σ2_mstar)``.

    Args:
        σ2_mstar: candidate value of the σ2_m state part.

    Returns:
        Scalar tensor holding the total log-probability.
    """
    likelihood_term = self.likelihood.genes(
        all_states=all_states,
        state_indices=self.state_indices,
        σ2_m=σ2_mstar,
    )
    prior_term = self.params.σ2_m.prior.log_prob(logit(σ2_mstar))
    return tf.reduce_sum(likelihood_term + prior_term)
def predict_m(self, kbar, k_fbar, wbar, fbar, w_0bar, Δ):
    """Predict mRNA levels m_j(t) on the grid ``self.data.τ``.

    Computes
        m_j(t) = b_j/d_j + s_j * exp(-d_j t) ∫_0^t exp(d_j u) G_j(u) du
                 [+ (a_j - b_j/d_j) exp(-d_j t) with initial conditions],
    where G = sigmoid(w @ log(p) + w_0) and the integral uses the trapezoid
    rule with step ``τ[1] - τ[0]``.

    Args:
        kbar: kinetic parameters, one column per parameter. The columns are
            (a, b, d, s) with ``options.initial_conditions``, else (b, d, s).
        k_fbar, Δ: forwarded to ``calculate_protein`` when translation is on.
        wbar: interaction weights, matrix-multiplied with log(p).
        fbar: latent functions.
        w_0bar: interaction biases, reshaped to a column.

    Returns:
        Predicted mRNA; its last axis runs over ``τ``.
    """
    def _kinetic_column(i):
        # Take parameter i out of log-space as a column vector.
        value = logit(kbar[:, i])
        if self.options.kinetic_exponential:
            value = tf.exp(value)
        return tf.reshape(value, (-1, 1))

    kin = [_kinetic_column(i) for i in range(kbar.shape[1])]
    if self.options.initial_conditions:
        a_j, b_j, d_j, s_j = kin
    else:
        b_j, d_j, s_j = kin
    w = wbar
    w_0 = tf.reshape(w_0bar, (-1, 1))
    τ = self.data.τ

    # Only compute the plain transform when it is actually used.
    if self.options.translation:
        p_i = self.calculate_protein(fbar, k_fbar, Δ)
    else:
        p_i = inverse_positivity(fbar)

    resolution = τ[1] - τ[0]
    # Tiny epsilon keeps the log finite where p_i is exactly zero (e.g. the
    # first time point produced by calculate_protein).
    interactions = tf.matmul(w, tfm.log(p_i + 1e-100)) + w_0
    G = tfm.sigmoid(interactions)  # TF activation function
    sum_term = G * tfm.exp(d_j * τ)
    integrals = tf.concat([
        tf.zeros((self.num_replicates, self.num_genes, 1), dtype='float64'),
        # Trapezoid rule
        0.5 * resolution * tfm.cumsum(sum_term[:, :, :-1] + sum_term[:, :, 1:], axis=2),
    ], axis=2)
    exp_dt = tfm.exp(-d_j * τ)
    integrals = tfm.multiply(exp_dt, integrals)
    m_pred = b_j / d_j + s_j * integrals
    if self.options.initial_conditions:
        m_pred += tfm.multiply((a_j - b_j / d_j), exp_dt)
    return m_pred
def _genes(self, fbar, kbar, k_fbar, wbar, w_0bar, σ2_m, Δ):
    """Gaussian log-likelihood of the observed mRNA under ``predict_m``.

    Predictions are restricted to ``self.data.common_indices`` along their
    last axis and compared with ``self.data.m_obs``.

    Args:
        σ2_m: noise variance, reshaped to a column. With
            ``self.preprocessing_variance`` it is mapped through ``logit``,
            and ``self.data.σ2_m_pre`` is added.
        The remaining arguments are forwarded to ``predict_m``.

    Returns:
        Scalar tensor: the summed log-likelihood.
    """
    m_pred = self.predict_m(kbar, k_fbar, wbar, fbar, w_0bar, Δ)
    at_observed = tf.transpose(
        tf.gather(tf.transpose(m_pred), self.data.common_indices))
    residual_sq = tfm.square(self.data.m_obs - at_observed)

    variance = tf.reshape(σ2_m, (-1, 1))
    if self.preprocessing_variance:
        # Add the per-observation preprocessing (PUMA) variance.
        variance = logit(variance) + self.data.σ2_m_pre

    pointwise = -0.5 * tfm.log(2 * PI * variance) - 0.5 * residual_sq / variance
    return tf.reduce_sum(pointwise)
def k_f(self):
    """Return ``logit(self.k_fbar)`` as a NumPy array, or None if k_fbar is None."""
    return None if self.k_fbar is None else logit(self.k_fbar).numpy()
def k(self):
    """Return ``logit(self.kbar)`` as a NumPy array.

    The result is exponentiated when ``self.options.kinetic_exponential`` is set.
    """
    values = logit(self.kbar).numpy()
    return np.exp(values) if self.options.kinetic_exponential else values