Esempio n. 1
0
    def setUp(self):
        """Build the fixtures shared by the Trace tests.

        ``self.a`` maps each letter 'a'-'e' to ``list(range(10))``.
        ``self.t`` is ``Trace(**self.a)`` (keyword form), while ``self.mt``
        passes ``self.a`` three times positionally. ``self.real_mt`` is read
        from the ``south_mvcm_5000`` data under ``FULL_PATH`` with
        ``multi=True``. ``self.real_singles`` holds the four files
        ``south_mvcm_5000_0.csv`` .. ``south_mvcm_5000_3.csv``, read one at a
        time.
        """
        self.a = {letter: list(range(10)) for letter in 'abcde'}

        self.t = Trace(**self.a)
        repeated = [self.a] * 3
        self.mt = Trace(*repeated)

        multi_source = FULL_PATH + r'/data/south_mvcm_5000'
        self.real_mt = Trace.from_csv(multi_source, multi=True)

        single_pattern = r'/data/south_mvcm_5000_{}.csv'
        self.real_singles = [Trace.from_csv(FULL_PATH + single_pattern.format(chain_no))
                             for chain_no in range(4)]
Esempio n. 2
0
 def test_validate_names(self):
     """Construct a four-chain Trace whose second chain is a copy of ``self.a``.

     A ``KeyError`` raised by ``Trace`` is tolerated.

     NOTE(review): the copy has exactly the same keys as ``self.a``, and the
     test passes whether or not ``KeyError`` is raised, so it can only fail on
     some other exception. If the intent is to check mismatched names, give the
     copy a different key set and use ``assertRaises`` -- confirm against
     ``Trace``'s validation rules first.
     """
     duplicate = dict(self.a)
     try:
         Trace(self.a, duplicate, self.a, self.a)
     except KeyError:
         pass
Esempio n. 3
0
 def test_ordering(self):
     """Rebuild each chain of ``self.real_mt`` and call ``_assert_allclose`` against the matching single-file trace.

     Chains are paired with ``self.real_singles`` by position. Note that
     ``zip`` stops at the shorter of the two sequences.
     """
     paired = zip(self.real_mt.chains, self.real_singles)
     for chain, single in paired:
         rebuilt = Trace(chain)
         rebuilt._assert_allclose(single)
Esempio n. 4
0
    def __init__(
            self,
            # data parameters
            Y,
            X,
            coordinates,
            n_samples=1000,
            n_jobs=1,
            priors=None,
            configs=None,
            starting_values=None,
            extra_traced_params=None,
            dmetric='euclidean',
            correlation_function=nexp,
            verbose=False,
            center=True,
            rescale_dists=True):
        """Set up model state, pairwise distances, priors/configs, and optionally sample.

        Parameters
        ----------
        Y : stored unchanged on ``self.state.Y``.
        X : covariates. Passed through ``verify_center`` when ``center`` is
            true, then through ``verify_covariates``. The checked array is
            kept as ``state.Xs``, and its ``explode``-d form as ``state.X``.
            ``state.N`` and ``state.p`` come from the checked array's shape.
        coordinates : stored on ``state.coordinates`` and used to build the
            pairwise-distance matrix ``state.pwds``.
        n_samples : number of draws taken immediately via ``self.sample``.
            No sampling happens when this is <= 0.
        n_jobs : forwarded to ``self.sample``.
        priors, configs, starting_values : optional dicts of keyword arguments
            forwarded to ``_setup_priors``, ``_setup_configs`` and
            ``_setup_starting_values``. ``None`` is treated as an empty dict.
        extra_traced_params : extra names appended to the default traced
            parameters ``['Betas', 'Mus', 'T', 'Phi', 'Tau2']``.
        dmetric : a metric name (``str``) passed to ``pdist``, or a callable
            that takes the coordinates and returns the distance matrix.
        correlation_function : stored on ``state.correlation_function``.
        verbose : stored on both ``self.verbose`` and ``self._verbose``.
        center : whether to apply ``verify_center`` to ``X`` first.
        rescale_dists : if true, divide ``state.pwds`` by its maximum. The
            original maximum is kept in ``state._old_max``, and
            ``state.max_dist`` becomes 1.

        Raises
        ------
        TypeError
            If ``dmetric`` is neither a string nor a callable.

        Notes
        -----
        A ``LinAlgError`` or ``ValueError`` raised while sampling is not
        propagated. It is reported through ``Warn`` so the partially built
        model can be inspected.
        """
        if center:
            X = verify_center(X)
        X = verify_covariates(X)

        N, p = X.shape
        # Keep the checked, un-exploded covariates before explode() transforms X.
        Xs = X

        X = explode(X)
        self.state = Hashmap(X=X, Y=Y, coordinates=coordinates)
        st = self.state
        self.traced_params = ['Betas', 'Mus', 'T', 'Phi', 'Tau2']
        if extra_traced_params is not None:
            self.traced_params.extend(extra_traced_params)
        # One empty list per traced parameter, filled during sampling.
        self.trace = Trace(**{param: [] for param in self.traced_params})
        st.correlation_function = correlation_function
        self.verbose = verbose

        st.Y = Y
        st.X = X
        st.Xs = Xs
        st.N = N
        st.p = p

        st._dmetric = dmetric
        if isinstance(st._dmetric, str):
            st.pwds = d.squareform(d.pdist(st.coordinates, metric=st._dmetric))
        elif callable(st._dmetric):
            st.pwds = st._dmetric(st.coordinates)
        else:
            # Without this, st.pwds is never set and the .max() call below
            # fails with an error that does not mention dmetric.
            raise TypeError(
                'dmetric must be a metric name (str) or a callable, '
                'got an object of type {}'.format(type(dmetric).__name__))

        st.max_dist = st.pwds.max()
        if rescale_dists:
            st.pwds = st.pwds / st.max_dist
            st._old_max = st.max_dist
            st.max_dist = 1.

        if configs is None:
            configs = dict()
        if priors is None:
            priors = dict()
        if starting_values is None:
            starting_values = dict()

        self._setup_priors(**priors)
        self._setup_starting_values(**starting_values)
        self._setup_configs(**configs)

        self._verbose = verbose
        self.cycles = 0

        if n_samples > 0:
            try:
                self.sample(n_samples, n_jobs=n_jobs)
            except (np.linalg.LinAlgError, ValueError) as e:
                # Report the actual exception type: this branch also handles ValueError.
                Warn('Encountered the following {}. '
                     'Model will return for debugging. \n {}'.format(
                         type(e).__name__, e))