def get_logits(self, ltla=None, time=None, lineage=None):
    """Evaluate the time-linear logits ``B1 * t + C1``.

    Args:
        ltla: Forwarded to ``self.posterior.dist`` for both sites.
        time: Selector passed through ``make_array`` and used to pick entries
            from ``np.arange(0, self.num_time)``.
        lineage: Forwarded to ``self.posterior.dist`` for both sites.

    Returns:
        The logits after ``self._expand_dims``.
    """
    slope = self.posterior.dist(Sites.B1, ltla, None, lineage)
    # Selected time points reshaped to (1, n, 1). Presumably the middle axis of
    # the B1/C1 draws is time -- TODO confirm against posterior.dist.
    times = np.arange(0, self.num_time)[make_array(time)].reshape(1, -1, 1)
    trend = slope * times
    raw = trend + self.posterior.dist(Sites.C1, ltla, None, lineage)
    return self._expand_dims(raw)
def _indices(self, shape, *args): """Creates indices for easier access to variables.""" indices = [] for i, arg in enumerate(args): if arg is None: indices.append(np.arange(shape[i])) else: indices.append(make_array(arg)) return np.ix_(*indices)
def get_probabilities(self, ltla=None, time=None, lineage=None):
    """Return proportions as a softmax of the logits over the last axis.

    ``get_logits`` is called without a lineage filter, so the normalisation
    runs over every entry of the last axis. Only after that is the result
    restricted to ``lineage``.

    Args:
        ltla: Forwarded to ``get_logits``.
        time: Forwarded to ``get_logits``.
        lineage: Optional selector, passed through ``make_array`` and applied
            to the last axis. ``None`` keeps every entry.

    Returns:
        Probabilities with the same leading axes as the logits.
    """
    all_logits = self.get_logits(ltla, time)
    # Subtracting the log-normaliser before exponentiating keeps the softmax
    # numerically stable.
    log_probs = all_logits - logsumexp(all_logits, -1, keepdims=True)
    probs = np.exp(log_probs)
    selection = slice(None) if lineage is None else make_array(lineage)
    return probs[..., selection]
def get_growth_rate_lineage(self, ltla=None, time=None, lineage=None):
    """Return per-lineage growth rates.

    The aggregate growth rate from ``get_growth_rate`` is shifted by each
    lineage's ``B1`` relative to the proportion-weighted mean ``B1``:
    ``gr - sum_k p_k * b1_k + b1``.

    Args:
        ltla: Forwarded to ``get_probabilities``, ``posterior.dist`` and
            ``get_growth_rate``. Now optional (default ``None``), matching the
            signature of ``get_log_R_lineage``.
            NOTE(review): confirm ``get_growth_rate`` accepts ``ltla=None``.
        time: Forwarded to ``get_probabilities`` and ``get_growth_rate``.
        lineage: Optional selector, passed through ``make_array`` and applied
            to the last axis. ``None`` keeps every entry.

    Returns:
        The per-lineage growth rates, indexed by ``lineage`` on the last axis.
    """
    p = self.get_probabilities(ltla, time)
    # Insert an axis at TIME_DIM -- presumably a singleton time axis, so b1
    # broadcasts over time points (TODO confirm _expand_dims semantics).
    b1 = self._expand_dims(self.posterior.dist(Sites.B1, ltla), dim=self.TIME_DIM)
    gr = self.get_growth_rate(ltla, time)
    # Sum over the shared last index k of p * b1; l is b1's third axis.
    mean_b1 = np.einsum("mijk,milk->mijl", p, b1)
    gr_lin = gr - mean_b1 + b1
    idx = slice(None) if lineage is None else make_array(lineage)
    return gr_lin[..., idx]
def get_log_R_lineage(self, ltla=None, time=None, lineage=None):
    """Return per-lineage log reproduction numbers.

    The aggregate ``log_R`` from ``get_log_R`` is shifted by ``tau`` times
    each lineage's ``B1`` relative to the proportion-weighted mean ``B1``:
    ``(log_R - tau * sum_k p_k * b1_k) + tau * b1``.

    Args:
        ltla: Forwarded to ``get_probabilities``, ``posterior.dist`` and
            ``get_log_R``.
        time: Forwarded to ``get_probabilities`` and ``get_log_R``.
        lineage: Optional selector, passed through ``make_array`` and applied
            to the last axis. ``None`` keeps every entry.

    Returns:
        The per-lineage log R values, indexed by ``lineage`` on the last axis.
    """
    proportions = self.get_probabilities(ltla, time)
    # TODO: set this up -- the intended alternative takes the coefficients
    # from Sites.B0 (selected by lineage) and appends a column of zeros on
    # the last axis, instead of using Sites.B1 below.
    advantage = self._expand_dims(
        self.posterior.dist(Sites.B1, ltla), dim=self.TIME_DIM
    )
    log_R = self.get_log_R(ltla, time)
    # Sum over the shared last index k of proportions * advantage.
    weighted = np.einsum("mijk,milk->mijl", proportions, advantage)
    shifted = log_R - weighted * self.tau
    result = shifted + advantage * self.tau
    if lineage is None:
        return result[..., slice(None)]
    return result[..., make_array(lineage)]