def integrate(cls, integrand, axes):
    # Check if we need to use the regular integrate
    if (not isinstance(integrand, fbsData)):
        return GraphHMM.integrate(integrand, axes)

    # Need adjusted axes because the relative axes in integrand change as we reduce
    # over each axis
    assert isinstance(axes, Iterable)
    if (len(axes) == 0):
        return integrand

    integrand, fbs_axis = (integrand.data, integrand.fbs_axis)

    assert max(axes) < integrand.ndim
    axes = np.array(axes)
    axes[axes < 0] = integrand.ndim + axes[axes < 0]
    adjusted_axes = np.array(sorted(axes)) - np.arange(len(axes))
    for ax in adjusted_axes:
        integrand = logsumexp(integrand, axis=ax)

    if (fbs_axis > -1):
        fbs_axis -= len(adjusted_axes)
        # assert fbs_axis > -1, adjusted_axes

    return fbsData(integrand, fbs_axis)
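# Illustrative sketch (not from the original code): why the adjusted axes work.
# After each logsumexp reduction the array loses a dimension, so every remaining
# axis to the right shifts down by one; subtracting np.arange(len(axes)) from the
# sorted axes accounts for that.  The helper name and shapes below are placeholders
# introduced only for this example.
def _adjusted_axes_sketch():
    import numpy as np
    from scipy.special import logsumexp
    x = np.random.rand(2, 3, 4, 5)
    axes = [1, 3]
    adjusted_axes = np.array(sorted(axes)) - np.arange(len(axes))  # -> [1, 2]
    out = x
    for ax in adjusted_axes:
        out = logsumexp(out, axis=ax)
    # Sequential reduction over the adjusted axes matches reducing both axes at once
    assert np.allclose(out, logsumexp(x, axis=(1, 3)))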
def categoricalLossGrad(self, prediction_logits):
    # Normalize the logits so that they are log-probabilities
    norm_factor = logsumexp(prediction_logits)
    prediction_logits_norm = prediction_logits - norm_factor
    # The loss itself is -np.sum(self.current_label_one_hot * prediction_logits_norm).
    # Its gradient with respect to the logits is softmax(logits) - label
    return np.exp(prediction_logits_norm) - self.current_label_one_hot
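# Illustrative sketch (not from the original code): a finite-difference check of
# the categorical loss gradient above.  The standalone loss/grad expressions here
# are assumptions introduced only for this example; the method itself reads its
# label from self.current_label_one_hot.
def _categorical_grad_check():
    import numpy as np
    from scipy.special import logsumexp
    rng = np.random.RandomState(0)
    logits = rng.randn(4)
    label = np.eye(4)[2]
    loss = lambda l: -np.sum(label * (l - logsumexp(l)))
    grad = np.exp(logits - logsumexp(logits)) - label
    eps = 1e-6
    num_grad = np.array([
        (loss(logits + eps * np.eye(4)[i]) - loss(logits - eps * np.eye(4)[i])) / (2 * eps)
        for i in range(4)
    ])
    # Analytic and numerical gradients should agree
    assert np.allclose(grad, num_grad, atol=1e-5)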
def inheritancePatternPrediction(self, model_parameters, graph=None, mc_samples=10):
    prediction_logits = self.inheritancePatternLogits(
        model_parameters, graph=graph, mc_samples=mc_samples)
    # Normalize so that the returned values are log-probabilities over the patterns
    prediction_log_probs = prediction_logits - logsumexp(prediction_logits)
    return prediction_log_probs
def initialProb(self, node):
    pi = np.copy(self.pi0)
    if (int(node) in self.possible_latent_states):
        states = self.possible_latent_states[int(node)]
        # Mask out (in log space) the impossible latent states, then renormalize
        impossible_states = np.setdiff1d(np.arange(pi.shape[-1]), states)
        pi[impossible_states] = np.NINF
        pi[states] -= logsumexp(pi)
    return pi
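# Illustrative sketch (not from the original code): the log-space masking used in
# initialProb.  Impossible states are set to -inf and the remaining entries are
# renormalized, so exponentiating gives a distribution supported only on the
# possible states.  The values below are placeholders chosen for this example.
def _log_mask_sketch():
    import numpy as np
    from scipy.special import logsumexp
    pi = np.log(np.array([0.1, 0.2, 0.3, 0.4]))
    states = np.array([1, 3])
    impossible_states = np.setdiff1d(np.arange(pi.shape[-1]), states)
    pi[impossible_states] = np.NINF
    pi[states] -= logsumexp(pi)
    # The possible states now carry all of the probability mass
    assert np.isclose(np.sum(np.exp(pi[states])), 1.0)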
def initialProb(self, node, is_partial_graph_index=False):
    pi = np.copy(self.pi0)
    node_full = self.partialGraphIndexToFullGraphIndex(
        node) if is_partial_graph_index == True else node
    if (int(node_full) in self.possible_latent_states):
        states = self.possible_latent_states[int(node_full)]
        # Mask out (in log space) the impossible latent states, then renormalize
        impossible_states = np.setdiff1d(np.arange(pi.shape[-1]), states)
        pi[impossible_states] = np.NINF
        pi[states] -= logsumexp(pi)
    return fbsData(pi, -1)
def recognize(self, y, cond, recognizer_params=None, inheritance_pattern=None):
    assert recognizer_params is not None
    assert inheritance_pattern is not None

    if (inheritance_pattern == 'AD'):
        mendel_vec = self.recognizeAD(y, cond)
    elif (inheritance_pattern == 'AR'):
        mendel_vec = self.recognizeAR(y, cond)
    else:
        mendel_vec = self.recognizeXL(y, cond)

    # "Soft log" the vector.  Compute the masks first so that entries set to 0
    # in the first assignment aren't then overwritten by the second one.
    ones_mask = mendel_vec == 1
    zeros_mask = mendel_vec == 0
    mendel_vec[ones_mask] = 0
    mendel_vec[zeros_mask] = -3

    sex, age, affected, n_above, n_below, keyword_vec = cond

    # Age one hot
    age_one_hot = np.array([1, 0]) if age > 20 else np.array([0, 1])

    # n-above and n-below features: True if the corresponding count is > 0
    n_above_one_hot = np.array([1, 0]) if n_above > 0 else np.array([0, 1])
    n_below_one_hot = np.array([1, 0]) if n_below > 0 else np.array([0, 1])

    # Take a multilinear combination and a linear combination of the data to produce an output
    W1, b1 = recognizer_params[0]
    l1 = np.einsum('abcde,a,b,c,d->e', W1, mendel_vec, age_one_hot,
                   n_above_one_hot, n_below_one_hot) + b1
    l1 = np.tanh(l1)

    W2, b2 = recognizer_params[1]
    l2 = np.einsum(
        'ab,b->a', W2,
        np.hstack([
            mendel_vec, age_one_hot, n_above_one_hot, n_below_one_hot,
            keyword_vec
        ])) + b2
    l2 = np.tanh(l2)

    W3, b3 = recognizer_params[2]
    l3 = np.einsum('abc,a,b->c', W3, l1, l2) + b3
    l3 = l3 - logsumexp(l3)

    # Basically use Mendelian genetics but with a small adjustment based on node features
    return l3 + mendel_vec
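# Illustrative sketch (not from the original code): the parameter shapes that the
# recognizer einsums above imply.  d_mendel, d_keyword, d_h1 and d_h2 are
# placeholder sizes introduced only for this example; the last dimension of W3
# matches d_mendel because the method returns l3 + mendel_vec.
def _recognizer_param_shapes(d_mendel, d_keyword, d_h1, d_h2):
    import numpy as np
    rng = np.random.RandomState(0)
    # 'abcde,a,b,c,d->e': multilinear layer over (mendel_vec, age, n_above, n_below)
    W1, b1 = rng.randn(d_mendel, 2, 2, 2, d_h1), np.zeros(d_h1)
    # 'ab,b->a': linear layer over the concatenated feature vector
    W2, b2 = rng.randn(d_h2, d_mendel + 2 + 2 + 2 + d_keyword), np.zeros(d_h2)
    # 'abc,a,b->c': bilinear layer combining the two hidden layers
    W3, b3 = rng.randn(d_h1, d_h2, d_mendel), np.zeros(d_mendel)
    return [(W1, b1), (W2, b2), (W3, b3)]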
def integrate(cls, integrand, axes):
    # Need adjusted axes because the relative axes in integrand change as we reduce
    # over each axis
    assert isinstance(axes, Iterable)
    if (len(axes) == 0):
        return integrand

    assert max(axes) < integrand.ndim
    axes = np.array(axes)
    axes[axes < 0] = integrand.ndim + axes[axes < 0]
    adjusted_axes = np.array(sorted(axes)) - np.arange(len(axes))
    for ax in adjusted_axes:
        integrand = logsumexp(integrand, axis=ax)

    return integrand
def log_likelihood(self, x, y, cond, generative_params):
    sex, age, affected, n_above, n_below = cond

    # This is ok because x is an array of logits
    last_layer = np.exp(x)
    last_layer = np.hstack((last_layer, age, n_above, n_below))

    for W, b in generative_params[:-1]:
        last_layer = np.tanh(np.einsum('ij,j->i', W, last_layer) + b)

    W, b = generative_params[-1]
    last_layer = np.einsum('ij,j->i', W, last_layer) + b

    # Normalize to log-probabilities over the d_out output classes
    logits = last_layer - logsumexp(last_layer, axis=0)

    # Sum the log-probability of the observed class for every element of y
    y_one_hot = np.zeros((y.shape[0], self.d_out))
    y_one_hot[np.arange(y.shape[0]), y] = 1.0
    return np.einsum('ti,i->', y_one_hot, logits)
def inheritancePatternLogits(self, model_parameters, graph=None, mc_samples=10):
    if (graph is not None):
        self.updateCurrentGraphAndLabel(graph=graph)

    ad_emission_params, ar_emission_params, xl_emission_params = model_parameters

    all_logits = []
    for i in range(mc_samples):
        ad_loss = self.ad_model.opt.marginalLoss(ad_emission_params, 0)
        ar_loss = self.ar_model.opt.marginalLoss(ar_emission_params, 0)
        xl_loss = self.xl_model.opt.marginalLoss(xl_emission_params, 0)
        prediction_logits = np.array([ad_loss, ar_loss, xl_loss])
        all_logits.append(prediction_logits)

    # logsumexp over the Monte Carlo samples.  The constant log(mc_samples) offset
    # cancels when the logits are normalized in inheritancePatternPrediction.
    return logsumexp(np.array(all_logits), axis=0)
def reparametrizedSample(cls, params=None, nat_params=None, size=1, temp=0.1, return_log=False):
    # Use the Gumbel-Softmax reparametrization trick
    # https://arxiv.org/pdf/1611.01144.pdf
    assert (params is None) ^ (nat_params is None)
    (n, ) = nat_params if nat_params is not None else cls.standardToNat(
        *params)

    g = np.random.gumbel(size=n.shape[0])
    p = (n + g) / temp
    if (return_log == False):
        unnorm = np.exp(p)
        return unnorm / np.sum(unnorm)
    return p - logsumexp(p)
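# Illustrative sketch (not from the original code): how the Gumbel-Softmax sample
# above relates to the underlying categorical distribution.  The argmax of the
# relaxed sample follows the categorical distribution defined by the log-space
# natural parameters, regardless of temperature.  The helper name, probabilities
# and sample counts below are placeholders introduced only for this example.
def _gumbel_softmax_sketch(temp=0.1, n_samples=5000):
    import numpy as np
    probs = np.array([0.6, 0.3, 0.1])
    n = np.log(probs)  # natural parameters, as expected by reparametrizedSample
    counts = np.zeros_like(probs)
    for _ in range(n_samples):
        g = np.random.gumbel(size=n.shape[0])
        sample = np.exp((n + g) / temp)
        sample /= np.sum(sample)
        counts[np.argmax(sample)] += 1
    # Empirical argmax frequencies should roughly match probs
    return counts / n_samples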
def transitionProb(self, child):
    parents, parent_order = self.getParents(child, get_order=True)
    ndim = len(parents) + 1
    pi = np.copy(self.pis[ndim])

    # If we know the latent state for child, then ensure that we transition there.
    # Also make sure we're only using the possible parent latent states!
    modified = False
    for parent, order in zip(parents, parent_order):
        if (int(parent) in self.possible_latent_states):
            parent_states = self.possible_latent_states[int(parent)]
            impossible_parent_axes = np.setdiff1d(
                np.arange(pi.shape[order]), parent_states)
            index = [slice(0, s) for s in pi.shape]
            index[order] = impossible_parent_axes
            pi[tuple(index)] = np.NINF
            modified = True

    if (int(child) in self.possible_latent_states):
        child_states = self.possible_latent_states[int(child)]
        impossible_child_axes = np.setdiff1d(np.arange(pi.shape[-1]),
                                             child_states)
        pi[..., impossible_child_axes] = np.NINF
        modified = True

    if (modified == True):
        with np.errstate(invalid='ignore'):
            pi[..., :] -= logsumexp(pi, axis=-1)[..., None]

        # In case entire rows summed to -inf
        pi[np.isnan(pi)] = np.NINF

    # Reshape pi's axes to match parent order
    assert len(parents) + 1 == pi.ndim
    assert parent_order.shape[0] == parents.shape[0]
    pi = np.moveaxis(pi, np.arange(ndim), np.hstack((parent_order, ndim - 1)))

    return pi
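# Illustrative sketch (not from the original code): what the final np.moveaxis
# call does.  Axis i of the stored transition tensor is sent to position
# parent_order[i], and the last (child) axis stays last.  The shapes and the
# parent order below are placeholders chosen only for this example.
def _moveaxis_sketch():
    import numpy as np
    pi = np.zeros((2, 3, 4))           # two parent axes plus the child axis
    parent_order = np.array([1, 0])    # parents are stored in the opposite order
    ndim = pi.ndim
    reordered = np.moveaxis(pi, np.arange(ndim),
                            np.hstack((parent_order, ndim - 1)))
    # The parent axes have been swapped; the child axis is untouched
    assert reordered.shape == (3, 2, 4)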