Beispiel #1
0
def load_trace(directory: str, model=None) -> MultiTrace:
    """Load a multitrace that was previously written to disk.

    Every sub-directory of ``directory`` is loaded as one chain through
    ``SerializeNDArray``.  The model used for the trace must be passed in,
    or the call must be made inside a model context.

    Parameters
    ----------
    directory : str
        Path to a pymc3 serialized trace
    model : pm.Model (optional)
        Model used to create the trace.  Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If ``directory`` contains no sub-directories to load.
    """
    pattern = os.path.join(directory, '*')
    # Plain files are skipped; only directories are handed to the loader.
    straces = [
        SerializeNDArray(entry).load(model)
        for entry in glob.glob(pattern)
        if os.path.isdir(entry)
    ]
    if straces:
        return base.MultiTrace(straces)
    raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." %
                              directory)
Beispiel #2
0
def load_multitrace(dirname, model=None):
    """
    Build a MultiTrace from the TextChain files found in a directory.

    One ``TextChain`` is created for every file in `dirname` that matches
    ``chain-*.csv``.  Its chain number is the integer that follows the last
    ``-`` of the path once the extension has been stripped.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    model : Model
        If None, the model is taken from the `with` context.

    Returns
    -------
    A :class:`pymc3.backend.base.MultiTrace` instance
    """

    logger.info('Loading multitrace from %s' % dirname)

    def _as_strace(path):
        # Parse the chain number before building the trace object.
        stem = os.path.splitext(path)[0]
        chain_number = int(stem.rsplit('-', 1)[1])
        trace = TextChain(dirname, model=model)
        trace.chain = chain_number
        trace.filename = path
        return trace

    pattern = os.path.join(dirname, 'chain-*.csv')
    return base.MultiTrace([_as_strace(path) for path in glob(pattern)])
Beispiel #3
0
def load_trace(directory: str, model=None) -> MultiTrace:
    """Load a multitrace that has been written to file (deprecated).

    The model used for the trace must be passed in, or the command
    must be run in a model context.  A ``DeprecationWarning`` is emitted on
    every call.

    Parameters
    ----------
    directory: str
        Path to a pymc3 serialized trace
    model: pm.Model (optional)
        Model used to create the trace.  Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If ``directory`` contains no sub-directories to load.
    """
    warnings.warn(
        # The trailing space keeps the two sentences from running together.
        "The `load_trace` function will soon be removed. "
        "Instead, use `arviz.from_netcdf` to load traces.",
        DeprecationWarning,
        # Attribute the warning to the caller rather than to this function.
        stacklevel=2,
    )
    straces = []
    for subdir in glob.glob(os.path.join(directory, "*")):
        # Only directories are handed to the loader; plain files are skipped.
        if os.path.isdir(subdir):
            straces.append(SerializeNDArray(subdir).load(model))
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." %
                                  directory)
    return base.MultiTrace(straces)
Beispiel #4
0
    def setUpClass(cls):
        """Record two short chains and expose them as ``cls.mtrace``.

        Two backend instances are created inside the model context, set up
        as chains 0 and 1 with ``cls.draws`` (5) draws each, and fed the
        values kept in ``cls.expected``.  For every variable of the test
        point, chain 0 holds ``arange`` values reshaped to
        ``(draws,) + shape`` and chain 1 holds the same values times 100.
        """
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)
        with cls.model:
            first = cls.backend(cls.name)
            second = cls.backend(cls.name)
        chains = (first, second)

        cls.draws = 5
        for chain_idx, strace in enumerate(chains):
            strace.setup(cls.draws, chain=chain_idx)

        varnames = list(cls.test_point.keys())

        cls.expected = {0: {}, 1: {}}
        for name, value in cls.test_point.items():
            mcmc_shape = (cls.draws,) + value.shape
            flat = np.arange(cls.draws * np.prod(value.shape),
                             dtype=value.dtype)
            cls.expected[0][name] = flat.reshape(mcmc_shape)
            cls.expected[1][name] = flat.reshape(mcmc_shape) * 100

        # Feed the draws one at a time, chain 0 before chain 1.
        for draw in range(cls.draws):
            for chain_idx, strace in enumerate(chains):
                point = {name: cls.expected[chain_idx][name][draw, ...]
                         for name in varnames}
                strace.record(point=point)
        for strace in chains:
            strace.close()
        cls.mtrace = base.MultiTrace(list(chains))
Beispiel #5
0
def load(con_str, model=None):
    """Load ODBC database.

    Connects to the database, lists its variable tables and the chains
    stored in the first of them, and creates one ``ODBC`` trace per chain.
    All traces share the single database connection opened here.

    Parameters
    ----------
    con_str : str
        ODBC Connection string including database
    model : Model
        If None, the model is taken from the `with` context.

    Returns
    -------
    A MultiTrace instance

    Raises
    ------
    ValueError
        If no variable tables could be listed for the database.
    """
    db = _ODBCDB(con_str)
    db.connect()
    name = _get_db_name(con_str)
    varnames = _get_table_list(db.cursor)
    if not varnames:
        raise ValueError(
            'Can not get variable list for database `{}`'.format(name))
    # The chain list is read from the first variable table only.
    chains = _get_chain_list(db.cursor, varnames[0])

    straces = []
    for chain in chains:
        strace = ODBC(con_str, model=model)
        strace.chain = chain
        strace._var_cols = {varname: ttab.create_flat_names('v', shape)
                            for varname, shape in strace.var_shapes.items()}
        # Mark the trace as ready without running its own setup.
        strace._is_setup = True
        strace.db = db  # Share the db with all traces.
        straces.append(strace)
    return base.MultiTrace(straces)
Beispiel #6
0
    def test_merge_traces_diff_lengths(self):
        """Merging multitraces of different lengths must raise ValueError.

        A chain with ``2 * self.draws`` draws is recorded and merged with a
        multitrace wrapping the ``self.strace0`` fixture.
        """

        def record_chain(n_draws):
            # Record ``n_draws`` copies of the test point as chain 1.
            with self.model:
                strace = self.backend(self.name)
                strace.setup(n_draws, 1)
                for _ in range(n_draws):
                    strace.record(self.test_point)
                strace.close()
            return strace

        # NOTE(review): the chain recorded here is never used below; the
        # fixture ``self.strace0`` is wrapped instead -- confirm intent.
        record_chain(self.draws)
        mtrace0 = base.MultiTrace([self.strace0])

        longer = record_chain(2 * self.draws)
        mtrace1 = base.MultiTrace([longer])

        with pytest.raises(ValueError):
            base.merge_traces([mtrace0, mtrace1])
Beispiel #7
0
def load_multitrace(dirname, varnames=None, chains=None):
    """
    Build a MultiTrace from the TextChain files of a directory.

    When `chains` is None the directory is scanned for ``chain-*.csv``
    files and the chain numbers are parsed from the file names; a chain
    numbered -1 found this way is left out.  When `chains` is given, the
    file ``chain-<n>.csv`` must exist for every requested chain.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    varnames : list
        of strings with variable names
    chains : list optional
        chain numbers to load; all chains found in `dirname` if None

    Returns
    -------
    A :class:`pymc3.backend.base.MultiTrace` instance

    Raises
    ------
    IOError
        If the file of an explicitly requested chain does not exist.
    NotImplementedError
        If `varnames` describe a trans-d trace.
    """

    if istransd(varnames)[0]:
        logger.info('Loading trans-d trace from %s' % dirname)
        raise NotImplementedError('Loading trans-d trace is not implemented!')

    logger.info('Loading multitrace from %s' % dirname)
    if chains is not None:
        files = [
            os.path.join(dirname, 'chain-%i.csv' % chain)
            for chain in chains
        ]
        for path in files:
            if not os.path.exists(path):
                raise IOError('File %s does not exist! Please run:'
                              ' "beat summarize <project_dir>"!' % path)
    else:
        files = glob(os.path.join(dirname, 'chain-*.csv'))
        chains = []
        for path in files:
            stem = os.path.splitext(os.path.basename(path))[0]
            chains.append(int(stem.replace('chain-', '')))

        # Drop the first entry numbered -1 together with its file.
        final_chain = -1
        if final_chain in chains:
            position = chains.index(final_chain)
            files.pop(position)
            chains.pop(position)

    def _as_strace(chain_number, path):
        trace = TextChain(dirname)
        trace.chain = chain_number
        trace.filename = path
        return trace

    return base.MultiTrace(
        [_as_strace(chain, path) for chain, path in zip(chains, files)])
Beispiel #8
0
def load_multitrace(dirname, varnames=None, chains=None):
    """
    Load TextChain database.

    When `chains` is None the directory is scanned for ``chain-*.csv``
    files and the chain number of each file is parsed from its name.
    Chains are returned in ascending order, each paired with the file it
    was parsed from.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    varnames : list
        of strings with variable names; only used to detect trans-d traces
    chains : list optional
        chain numbers to load; all chains found in `dirname` if None

    Returns
    -------
    A :class:`pymc3.backend.base.MultiTrace` instance

    Raises
    ------
    IOError
        If the file of an explicitly requested chain does not exist.
    NotImplementedError
        If `varnames` describe a trans-d trace.
    """

    if not istransd(varnames)[0]:
        logger.info('Loading multitrace from %s' % dirname)
        if chains is None:
            # Map each chain number to its file so the two cannot get out
            # of step; de-duplicating through a bare set would reorder the
            # chain numbers independently of the file list.
            chain_files = {}
            for f in glob(os.path.join(dirname, 'chain-*.csv')):
                chain = int(os.path.splitext(f)[0].rsplit('-', 1)[1])
                # The first file seen for a chain number wins.
                chain_files.setdefault(chain, f)
            chains = sorted(chain_files)
            files = [chain_files[chain] for chain in chains]
        else:
            files = [
                os.path.join(dirname, 'chain-%i.csv' % chain)
                for chain in chains
            ]
            for f in files:
                if not os.path.exists(f):
                    raise IOError('File %s does not exist! Please run:'
                                  ' "beat summarize <project_dir>"!' % f)

        straces = []
        for chain, f in zip(chains, files):
            strace = TextChain(dirname)
            strace.chain = chain
            strace.filename = f
            straces.append(strace)
        return base.MultiTrace(straces)
    else:
        logger.info('Loading trans-d trace from %s' % dirname)
        raise NotImplementedError('Loading trans-d trace is not implemented!')
Beispiel #9
0
    def setup_class(cls):
        """Record two short chains (optionally with sampler stats).

        Two backend instances are created inside the model context, set up
        as chains 0 and 1 with ``cls.draws`` (5) draws each, filled from
        ``cls.expected`` and wrapped in ``cls.mtrace``.  Optional class
        attributes are honoured:

        * ``write_partial_chain`` -- when it is ``True`` the first
          unobserved RV of the model is left out of the recorded variables.
        * ``sampler_vars`` -- a list of ``{stat name: dtype}`` dicts; when
          set, expected sampler statistics are generated and recorded too.
        """
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)

        if getattr(cls, "write_partial_chain", False) is True:
            cls.chain_vars = cls.model.unobserved_RVs[1:]
        else:
            cls.chain_vars = cls.model.unobserved_RVs

        with cls.model:
            strace0 = cls.backend(cls.name, vars=cls.chain_vars)
            strace1 = cls.backend(cls.name, vars=cls.chain_vars)

        if not hasattr(cls, "sampler_vars"):
            cls.sampler_vars = None

        cls.draws = 5
        if cls.sampler_vars is not None:
            strace0.setup(cls.draws, chain=0, sampler_vars=cls.sampler_vars)
            strace1.setup(cls.draws, chain=1, sampler_vars=cls.sampler_vars)
        else:
            strace0.setup(cls.draws, chain=0)
            strace1.setup(cls.draws, chain=1)

        varnames = list(cls.test_point.keys())
        shapes = {varname: value.shape for varname, value in cls.test_point.items()}
        dtypes = {varname: value.dtype for varname, value in cls.test_point.items()}

        # Chain 0 holds arange values, chain 1 the same values times 100.
        cls.expected = {0: {}, 1: {}}
        for varname in varnames:
            mcmc_shape = (cls.draws,) + shapes[varname]
            values = np.arange(cls.draws * np.prod(shapes[varname]), dtype=dtypes[varname])
            cls.expected[0][varname] = values.reshape(mcmc_shape)
            cls.expected[1][varname] = values.reshape(mcmc_shape) * 100

        if cls.sampler_vars is not None:
            bool_dtype = np.dtype(bool)
            cls.expected_stats = {0: [], 1: []}
            for stat_spec in cls.sampler_vars:
                # The same dict is registered for both chains.
                stats = {}
                cls.expected_stats[0].append(stats)
                cls.expected_stats[1].append(stats)
                for key, dtype in stat_spec.items():
                    # Compare normalised dtypes: the ``np.bool`` alias is not
                    # available on every NumPy release.
                    if np.dtype(dtype) == bool_dtype:
                        stats[key] = np.zeros(cls.draws, dtype=dtype)
                    else:
                        stats[key] = np.arange(cls.draws, dtype=dtype)

        for idx in range(cls.draws):
            point0 = {varname: cls.expected[0][varname][idx, ...] for varname in varnames}
            point1 = {varname: cls.expected[1][varname][idx, ...] for varname in varnames}
            if cls.sampler_vars is not None:
                stats0 = [
                    {key: val[idx] for key, val in stats.items()} for stats in cls.expected_stats[0]
                ]
                stats1 = [
                    {key: val[idx] for key, val in stats.items()} for stats in cls.expected_stats[1]
                ]
                strace0.record(point=point0, sampler_stats=stats0)
                strace1.record(point=point1, sampler_stats=stats1)
            else:
                strace0.record(point=point0)
                strace1.record(point=point1)
        strace0.close()
        strace1.close()
        cls.mtrace = base.MultiTrace([strace0, strace1])

        # Summaries of the configured sampler stats (empty when none are set).
        cls.stat_dtypes = {}
        cls.stats_counts = collections.Counter()
        for stats in cls.sampler_vars or []:
            cls.stat_dtypes.update(stats)
            cls.stats_counts.update(stats.keys())
    def test_merge_traces_nonunique(self):
        """Merging the two single-chain multitraces must raise ValueError."""
        # Wrap each fixture chain in its own MultiTrace before merging.
        wrapped = [
            base.MultiTrace([strace])
            for strace in (self.strace0, self.strace1)
        ]

        with pytest.raises(ValueError):
            base.merge_traces(wrapped)
 def test_multitrace_nonunique(self):
     """Building one MultiTrace from both fixture chains must raise ValueError."""
     # NOTE(review): presumably the two fixtures share a chain number --
     # confirm against the fixture setup.
     clashing = [self.strace0, self.strace1]
     with pytest.raises(ValueError):
         base.MultiTrace(clashing)
Beispiel #12
0
    def test_merge_traces_nonunique(self):
        """Merging the two single-chain multitraces must raise ValueError."""
        first = base.MultiTrace([self.strace0])
        second = base.MultiTrace([self.strace1])

        with self.assertRaises(ValueError):
            base.merge_traces([first, second])