def test_posterior_getter(raw_posterior): p = Posterior(raw_posterior) a = p["a"] assert isinstance(a, np.ndarray) with pytest.raises(KeyError): p["c"]
def lineage_model(test_posterior): m = Lineage( tau=5.0, cases=test_posterior["cases"], lineages=test_posterior["lineages"], lineage_dates=test_posterior["dates"], population=test_posterior["population"], basis=test_posterior["basis"], posterior=Posterior(test_posterior), ) return m
def independent_clock_reset_model(test_posterior): m = MultiLineage( cases=test_posterior["cases"], lineages=test_posterior["lineages"], lineage_dates=test_posterior["dates"], population=test_posterior["population"], basis=test_posterior["basis"], posterior=Posterior(test_posterior), model_kwargs=dict(num_samples=100), independent_clock=True, ) return m
def __init__( self, tau, cases, lineages, lineage_dates, population, basis=None, auto_correlation=None, linearize=False, ancestor_matrix=None, posterior=None, ): self.tau = tau self.cases = cases self.lineages = lineages self.lineage_dates = lineage_dates self.population = population self.posterior = Posterior( posterior) if posterior is not None else posterior self.auto_correlation = auto_correlation self.linearize = linearize self.ancestor_matrix = ancestor_matrix if basis is None: self.knots = NowCastKnots(cases.shape[-1]) self.B = self.knots.basis # _, self.B = create_spline_basis( # np.arange(cases.shape[1]), # num_knots=int(np.ceil(cases.shape[1] / 10)), # add_intercept=False, # ) else: self.B = basis self.num_ltla = self.cases.shape[0] self.num_time = self.cases.shape[-1] self.num_lin = self.lineages.shape[-1] - 1 self.num_basis = self.B.shape[-1] self.num_ltla_lin = self.nan_idx.shape[0]
def test_posterior_mean(raw_posterior): """Posterior.mean returns the slices irrespective of whether np.arrays, lists or integers are provided.""" p = Posterior(raw_posterior) assert p.mean("a").shape == (5, 5, 5) assert p.mean("a", 1).shape == (1, 5, 5) assert p.mean("a", None, 1).shape == (5, 1, 5) assert p.mean("a", [1, 2]).shape == (2, 5, 5)
def clock_reset_model(test_posterior): test_posterior["t"] = test_posterior["t"][:, 1, :].reshape(100, 1, -1) print(test_posterior["t"].shape) m = MultiLineage( cases=test_posterior["cases"], lineages=test_posterior["lineages"], lineage_dates=test_posterior["dates"], population=test_posterior["population"], basis=test_posterior["basis"], posterior=Posterior(test_posterior), model_kwargs=dict(num_samples=100), independent_clock=False, ) return m
class Lineage(object): EPS = -1e6 SCALE = 100.0 LTLA_DIM = 1 TIME_DIM = 2 LIN_DIM = 3 """ Implements Lineage model helper. """ def __init__( self, tau, cases, lineages, lineage_dates, population, basis=None, auto_correlation=None, linearize=False, ancestor_matrix=None, posterior=None, ): self.tau = tau self.cases = cases self.lineages = lineages self.lineage_dates = lineage_dates self.population = population self.posterior = Posterior( posterior) if posterior is not None else posterior self.auto_correlation = auto_correlation self.linearize = linearize self.ancestor_matrix = ancestor_matrix if basis is None: self.knots = NowCastKnots(cases.shape[-1]) self.B = self.knots.basis # _, self.B = create_spline_basis( # np.arange(cases.shape[1]), # num_knots=int(np.ceil(cases.shape[1] / 10)), # add_intercept=False, # ) else: self.B = basis self.num_ltla = self.cases.shape[0] self.num_time = self.cases.shape[-1] self.num_lin = self.lineages.shape[-1] - 1 self.num_basis = self.B.shape[-1] self.num_ltla_lin = self.nan_idx.shape[0] # Private methods def _expand_dims(self, array, num_dim=4, dim=1): """Soft dim expansion.""" if array.ndim < num_dim: array = np.expand_dims(array, dim) return array def _expand_array(self, array: jnp.ndarray, index, shape: tuple) -> jnp.ndarray: """Creates an a zero array with shape `shape` and fills it with `array` at index.""" expanded_array = jnp.zeros(shape) expanded_array = index_update(expanded_array, index, array) return expanded_array def _pad_array(self, array: jnp.ndarray, func=jnp.zeros): """Adds an additional column to an three dimensional array.""" return jnp.concatenate([ array, func((array.shape[0], *[1 for _ in range(array.ndim - 1)])) ], -1) def _indices(self, shape, *args): """Creates indices for easier access to variables.""" indices = [] for i, arg in enumerate(args): if arg is None: indices.append(np.arange(shape[i])) else: indices.append(make_array(arg)) return np.ix_(*indices) def _is_nan_row(self, array: np.ndarray) -> np.ndarray: """ Helper function to extract the indices of rows (1st dimension) in an array that contains nan values :param array: a numpy array :returns: an array of indices """ return np.where(np.isnan( array.sum(axis=tuple(range(1, array.ndim)))))[0] @property def arma(self): if not hasattr(self, "_arma"): if self.auto_correlation is None: self._arma = jnp.eye(self.num_basis) else: Σ0 = jnp.eye(self.num_basis) for i in range(1, self.num_basis): Σ0 = index_update(Σ0, index[i, i - 1], jnp.array(self.auto_correlation)) if self.linearize: Π0 = jnp.linalg.inv(Σ0) for i in range(self.num_basis - 3, self.num_basis): Π0 = index_update(Π0, index[i, i - 2:i], jnp.array([1, -2])) Π0 = index_update( Π0, index[self.num_basis - 3, self.num_basis - 5:self.num_basis - 3], 0.5 * jnp.array([1, -2]), ) self._arma = jnp.linalg.inv(Π0) else: self._arma = Σ0 return self._arma @property def nan_idx(self): if not hasattr(self, "_nan_idx"): exclude = list( set(self._is_nan_row(self.lineages)) | set(self._is_nan_row(self.cases))) self._nan_idx = np.array( [i for i in range(self.cases.shape[0]) if i not in exclude]) return self._nan_idx @property def missing_lineages(self): if not hasattr(self, "_missing_lineages"): self._missing_lineages = (self.lineages[..., :-1].sum(1) != 0)[self.nan_idx].astype(int) return self._missing_lineages def aggregate(self, region, func, correct_zero_obs=None, *args, **kwargs): agg = [] if correct_zero_obs is not None: lin_obs = (self.lineages.sum(-1) != 0).astype(int) for i in np.sort(np.unique(region)): region_idx = np.where(region == i)[0] region_not_nan = np.isin(region_idx, self.nan_idx) region_idx = region_idx[region_not_nan] aggregate = func(int(region_idx[0]), *args, **kwargs) for r in region_idx[1:]: arr = func(int(r), *args, **kwargs) if correct_zero_obs is not None: j = np.argmax(lin_obs[r]) if j > 0: k = self.lineage_dates[j] arr = np.concatenate( [ correct_zero_obs * np.ones_like(arr)[:, :, :k], arr[:, :, k:], ], 2, ) aggregate += arr agg.append(aggregate) return np.concatenate(agg, 1) def get_logits(self, ltla=None, time=None, lineage=None): logits = self.posterior.dist( Sites.B1, ltla, None, lineage) * np.arange(0, self.num_time)[make_array(time)].reshape( 1, -1, 1) + self.posterior.dist(Sites.C1, ltla, None, lineage) logits = self._expand_dims(logits) return logits def get_probabilities(self, ltla=None, time=None, lineage=None): logits = self.get_logits(ltla, time) p = np.exp(logits - logsumexp(logits, -1, keepdims=True)) if lineage is not None: idx = make_array(lineage) else: idx = slice(None) return p[..., idx] def get_growth_rate(self, ltla=None, time=None): """ Computes the fitted growth rate. :param ltla: indices of the the LTLAs :param time: time indices :returns: the posterior distribution of the incidence with shape (num_samples, num_ltla, num_time, 1) """ beta = self._expand_dims(self.posterior.dist(Sites.BETA1, ltla), self.LIN_DIM) basis = self._expand_dims( self.B[self._indices(self.B.shape, 1, time)].T.squeeze(), self.TIME_DIM) gr = np.einsum("ijk,kl->ijl", beta, basis) gr = self._expand_dims(gr, dim=self.LIN_DIM) return gr def get_growth_rate_lineage(self, ltla, time=None, lineage=None): p = self.get_probabilities(ltla, time) b1 = self._expand_dims(self.posterior.dist(Sites.B1, ltla), dim=self.TIME_DIM) gr = self.get_growth_rate(ltla, time) gr_lin = gr - np.einsum("mijk,milk->mijl", p, b1) + b1 if lineage is not None: idx = make_array(lineage) else: idx = slice(None) return gr_lin[..., idx] def get_lambda(self, ltla=None, time=None): """ Returns the posterior distribution of the incidence. :param ltla: indices of the the LTLAs :param time: time indices :returns: the posterior distribution of the incidence with shape (num_samples, num_ltla, num_time, 1) """ beta = self.posterior.dist(Sites.BETA1, ltla) beta = self._expand_dims(beta, self.LIN_DIM) basis = self._expand_dims( self.B[self._indices(self.B.shape, 0, time)].T.squeeze(), self.TIME_DIM) lamb = self.population[self._indices( self.population.shape, ltla)].reshape(1, -1, 1) * np.exp( np.einsum("ijk,kl->ijl", beta, basis)) lamb = self._expand_dims(lamb, dim=self.LIN_DIM) return lamb def get_lambda_lineage(self, ltla=None, time=None, lineage=None): return self.get_lambda(ltla, time) * self.get_probabilities( ltla, time, lineage) def get_transmissibility(self, rebase=None): b = self.posterior.dist(Sites.B0) b = np.concatenate([b, np.zeros((b.shape[0], 1))], -1) if self.ancestor_matrix is not None: b = np.einsum("ij,lj->li", self.ancestor_matrix, b) if rebase is not None: b = b - b[..., rebase].reshape(-1, 1) return np.exp(b * self.tau) def get_log_R(self, ltla=None, time=None): log_R = (self.posterior.dist(Sites.BETA1, ltla) @ self.B[self._indices( self.B.shape, 1, time)].T.squeeze()) * self.tau if log_R.ndim == 2: if isinstance(time, int): log_R = self._expand_dims(log_R, num_dim=3, dim=self.TIME_DIM) if isinstance(ltla, int): log_R = self._expand_dims(log_R, num_dim=3, dim=self.LIN_DIM) return self._expand_dims(log_R, dim=self.LIN_DIM) def get_log_R_lineage(self, ltla=None, time=None, lineage=None): p = self.get_probabilities(ltla, time) # TODO: set this up # b = self.posterior.dist(Sites.B0, lineage) # b1 = np.concatenate([b, np.zeros((b.shape[0], 1))], -1) b1 = self._expand_dims(self.posterior.dist(Sites.B1, ltla), dim=self.TIME_DIM) log_R = self.get_log_R(ltla, time) log_R_lineage = (log_R - (np.einsum("mijk,milk->mijl", p, b1) * self.tau)) + ( b1 * self.tau) if lineage is not None: idx = make_array(lineage) else: idx = slice(None) return log_R_lineage[..., idx] def get_other_log_R(self, exclude, ltla=None, time=None): p = self.get_probabilities(ltla, time, exclude) b = self.posterior.dist(Sites.B0) b = np.concatenate([b, np.zeros((b.shape[0], 1))], -1) log_R = self.get_log_R(ltla, time) log_R0 = log_R - (b[..., exclude].reshape(-1, 1, 1, 1) * p) * self.tau return self._expand_dims(log_R0, dim=self.LIN_DIM) def aggregate_lambda(self, region, time=None): return self.aggregate(region, self.get_lambda, time) def aggregate_growth_rate(self, region, time=None): return self.aggregate(region, self.get_growth_rate, time) def aggregate_lambda_lineage(self, region, time=None, lineage=None): """ Aggregates lambda lineage by an indicator array. :param region: an indicator array, e.g. np.array([0, 0, 0, 1, 1]) :param time: index array containing indices of time :param lineage: index array containing indices of lineage of interest :return: a numpy array containing aggregated incidence due to each lineage """ return self.aggregate(region, self.get_lambda_lineage, time=time, lineage=lineage) def aggregate_probabilities(self, region, time=None, lineage=None): lambda_lin = self.aggregate_lambda_lineage(region, time=time, lineage=lineage) return lambda_lin / lambda_lin.sum(-1, keepdims=True) def aggregate_log_R(self, region, time=None): lambda_regions = self.aggregate_lambda(region, time=time) def weighted_log_R(ltla, time): return self.get_log_R(ltla, time) * self.get_lambda(ltla, time) agg = self.aggregate(region, weighted_log_R, time=time) return agg / lambda_regions def aggregate_log_R_lineage(self, region, time=None): lambda_regions = self.aggregate_lambda_lineage(region, time=time) def weighted_log_R(ltla, time): return self.get_log_R_lineage( ltla, time) * self.get_lambda_lineage(ltla, time) agg = self.aggregate(region, weighted_log_R, correct_zero_obs=self.EPS, time=time) return agg / lambda_regions def aggregate_growth_rate_lineage(self, region, time=None): lambda_regions = self.aggregate_lambda_lineage(region, time=time) def weighted_growth_rate(ltla, time): return self.get_growth_rate_lineage( ltla, time) * self.get_lambda_lineage(ltla, time) agg = self.aggregate(region, weighted_growth_rate, correct_zero_obs=self.EPS, time=time) return agg / lambda_regions
def test_posterior_dist(raw_posterior): """Posterior.dist returns the slices irrespective of whether np.arrays, lists or integers are provided.""" p = Posterior(raw_posterior) assert p.dist("a").shape == (5, 5, 5, 5) assert p.dist("a", None).shape == (5, 5, 5, 5) assert np.all(p.dist("a", 1, 1, 1).squeeze() == np.array([31, 156, 281, 406, 531])) assert p.dist("a", 1).shape == (5, 1, 5, 5) assert p.dist("b", 1, None, 2).shape == (5, 1, 5, 1) assert p.dist("b", 1, 1, None).shape == (5, 1, 1, 5) assert p.dist("b", [1, 2], None, 2).shape == (5, 2, 5, 1) assert p.dist("b", [1, 2], [1, 2], [1, 2, 3]).shape == (5, 2, 2, 3) assert p.dist( "b", np.array([1, 2]), np.array([1, 2]), np.array([1, 2, 3]) ).shape == (5, 2, 2, 3) assert p.dist("a", 2, np.array([1, 2]), np.array([1, 2, 3])).shape == (5, 1, 2, 3) assert p.dist("b", np.array([1, 2]), 2, np.array([1, 2, 3])).shape == (5, 2, 1, 3) assert p.dist("b", np.array([1, 2]), 2, [1, 2, 3]).shape == (5, 2, 1, 3)