def setUp(self):
    """Build the synthetic and CSV-backed traces shared by the tests.

    Sets:
        self.a: dict mapping 'a'..'e' to the list ``[0, 1, ..., 9]``.
        self.t: a ``Trace`` built from ``self.a`` as keyword arguments.
        self.mt: a ``Trace`` built from three positional copies of the
            same ``self.a`` mapping.
        self.real_mt: a ``Trace`` loaded via ``Trace.from_csv`` with
            ``multi=True``.
        self.real_singles: four ``Trace`` objects, one per numbered CSV
            file (suffixes 0 through 3).
    """
    # chr(97) is 'a', so codes 97..101 give the names 'a' through 'e'.
    synthetic = {}
    for code in range(97, 102):
        synthetic[chr(code)] = list(range(10))
    self.a = synthetic

    self.t = Trace(**self.a)
    self.mt = Trace(self.a, self.a, self.a)

    self.real_mt = Trace.from_csv(FULL_PATH + r'/data/south_mvcm_5000',
                                  multi=True)

    single_template = r'/data/south_mvcm_5000_{}.csv'
    singles = []
    for chain_number in range(4):
        csv_path = FULL_PATH + single_template.format(chain_number)
        singles.append(Trace.from_csv(csv_path))
    self.real_singles = singles
def test_validate_names(self):
    """Build a four-chain ``Trace`` from ``self.a`` and a shallow copy of it.

    A ``KeyError`` raised during construction is tolerated; any other
    exception propagates and fails the test.
    """
    duplicate = self.a.copy()
    chains = (self.a, duplicate, self.a, self.a)
    try:
        Trace(*chains)
    except KeyError:
        # NOTE(review): `duplicate` has exactly the same keys as self.a and
        # nothing is asserted, so this test passes whether or not KeyError
        # is raised. Confirm the intended check (presumably mismatched
        # parameter names across chains) and tighten it.
        pass
def test_ordering(self):
    """Compare each chain of ``self.real_mt`` with its single-chain trace.

    Chains are paired positionally with ``self.real_singles``; each chain
    is wrapped in a new ``Trace`` and checked with ``_assert_allclose``.
    """
    paired = zip(self.real_mt.chains, self.real_singles)
    for chain, expected in paired:
        rebuilt = Trace(chain)
        rebuilt._assert_allclose(expected)
def __init__(self,
             # data parameters
             Y, X, coordinates,
             n_samples=1000, n_jobs=1,
             priors=None, configs=None, starting_values=None,
             extra_traced_params=None,
             dmetric='euclidean',
             correlation_function=nexp,
             verbose=False,
             center=True,
             rescale_dists=True):
    """Set up sampler state from the data and optionally start sampling.

    Parameters
    ----------
    Y :
        Response data; stored unchanged on ``self.state.Y``.
    X :
        Covariates. Passed through ``verify_center`` (only if ``center``)
        and ``verify_covariates``; the result must have a two-element
        ``shape`` ``(N, p)``. That verified array is kept as ``state.Xs``
        and its ``explode``-d form as ``state.X``.
    coordinates :
        Locations used to compute the pairwise-distance matrix
        ``state.pwds``.
    n_samples : int
        If greater than zero, ``self.sample(n_samples, n_jobs=n_jobs)`` is
        called at the end of construction.
    n_jobs : int
        Forwarded to ``self.sample``.
    priors, configs, starting_values : dict or None
        Keyword arguments forwarded to ``_setup_priors``,
        ``_setup_configs`` and ``_setup_starting_values`` respectively.
        ``None`` is treated as an empty dict.
    extra_traced_params : iterable of str or None
        Additional names appended to the default traced parameters
        ``['Betas', 'Mus', 'T', 'Phi', 'Tau2']``.
    dmetric : str or callable
        A string is passed as ``metric`` to ``d.pdist`` (``d`` is
        presumably ``scipy.spatial.distance`` -- confirm against the
        file's imports); a callable is invoked as ``dmetric(coordinates)``
        and must return the full pairwise-distance matrix.
    correlation_function :
        Stored on ``state.correlation_function``.
    verbose :
        Stored on both ``self.verbose`` and ``self._verbose``.
    center : bool
        Whether to run ``verify_center`` on ``X`` first.
    rescale_dists : bool
        If true, distances are divided by their maximum, the original
        maximum is kept in ``state._old_max`` and ``state.max_dist`` is
        set to ``1.``.

    Raises
    ------
    TypeError
        If ``dmetric`` is neither a string nor a callable.

    Notes
    -----
    A ``numpy.linalg.LinAlgError`` or ``ValueError`` raised while
    sampling is converted to a warning so that the partially sampled
    model is still returned for debugging.
    """
    if center:
        X = verify_center(X)
    X = verify_covariates(X)
    N, p = X.shape
    Xs = X  # keep the verified design before it is exploded
    X = explode(X)

    self.state = Hashmap(X=X, Y=Y, coordinates=coordinates)
    self.traced_params = ['Betas', 'Mus', 'T', 'Phi', 'Tau2']
    if extra_traced_params is not None:
        self.traced_params.extend(extra_traced_params)
    # One empty list per traced parameter.
    self.trace = Trace(**{param: [] for param in self.traced_params})

    st = self.state
    self.state.correlation_function = correlation_function
    self.verbose = verbose
    st.Y = Y
    st.X = X
    st.Xs = Xs
    st.N = N
    st.p = p

    st._dmetric = dmetric
    if isinstance(st._dmetric, str):
        st.pwds = d.squareform(d.pdist(st.coordinates, metric=st._dmetric))
    elif callable(st._dmetric):
        st.pwds = st._dmetric(st.coordinates)
    else:
        # Without this guard `pwds` is never set and the failure surfaces
        # below as an unrelated error on `st.pwds.max()`.
        raise TypeError(
            'dmetric must be a metric name (str) or a callable applied to '
            'the coordinates, got {}'.format(type(dmetric).__name__))

    st.max_dist = st.pwds.max()
    if rescale_dists:
        # NOTE(review): if every pairwise distance is zero this divides by
        # zero -- confirm whether that input needs handling.
        st.pwds = st.pwds / st.max_dist
        st._old_max = st.max_dist
        st.max_dist = 1.

    if configs is None:
        configs = dict()
    if priors is None:
        priors = dict()
    if starting_values is None:
        starting_values = dict()

    # Order preserved from the original: priors, starting values, configs.
    self._setup_priors(**priors)
    self._setup_starting_values(**starting_values)
    self._setup_configs(**configs)

    self._verbose = verbose
    self.cycles = 0

    if n_samples > 0:
        try:
            self.sample(n_samples, n_jobs=n_jobs)
        except (np.linalg.LinAlgError, ValueError) as e:
            # Report the exception type actually caught; the text is
            # unchanged when it is a LinAlgError.
            Warn('Encountered the following {}. '
                 'Model will return for debugging. \n {}'
                 .format(type(e).__name__, e))