Esempio n. 1
0
def load(name, chains=None, model=None):
    """Read a text database from disk into a multi-chain trace.

    Parameters
    ----------
    name : str
        Path to root directory for text database
    chains : list
        Chains to read. If None, every chain reported by
        `_get_chain_dirs` is read.
    model : Model
        Forwarded to each `NDArray` trace. If None, the model is taken
        from the `with` context.

    Returns
    -------
    A `base.MultiTrace` wrapping one `NDArray` trace per loaded chain
    """
    dirs_by_chain = _get_chain_dirs(name)
    if chains is None:
        chains = list(dirs_by_chain.keys())

    def _read_chain(chain):
        # Each chain directory holds a 'shapes.json' index plus one
        # '<varname>.txt' file of flattened values per variable.
        chain_dir = dirs_by_chain[chain]
        with open(os.path.join(chain_dir, 'shapes.json'), 'r') as shapes_fh:
            shapes = json.load(shapes_fh)
        strace = NDArray(model=model)
        strace.samples = {
            varname: np.loadtxt(
                os.path.join(chain_dir, varname + '.txt')).reshape(shape)
            for varname, shape in shapes.items()
        }
        strace.chain = chain
        return strace

    return base.MultiTrace([_read_chain(chain) for chain in chains])
Esempio n. 2
0
def load(name, model=None):
    """Open an SQLite database and expose its chains as a multi-chain trace.

    Parameters
    ----------
    name : str
        Path to SQLite database file
    model : Model
        Forwarded to each `SQLite` trace. If None, the model is taken
        from the `with` context.

    Returns
    -------
    A MultiTrace instance with one `SQLite` trace per chain
    """
    shared_db = _SQLiteDB(name)
    shared_db.connect()
    table_names = _get_table_list(shared_db.cursor)
    # The chain list is read from the first table only.
    chain_ids = _get_chain_list(shared_db.cursor, table_names[0])

    def _attach(chain):
        strace = SQLite(name, model=model)
        strace.varnames = table_names
        strace.chain = chain
        strace._is_setup = True
        # One connection object is shared by every per-chain trace.
        strace.db = shared_db
        return strace

    return base.MultiTrace([_attach(chain) for chain in chain_ids])
Esempio n. 3
0
def load_trace(directory: str, model=None) -> MultiTrace:
    """Load a multitrace that has been written to file.

    The model used for the trace must be passed in, or the command
    must be run in a model context.

    Parameters
    ----------
    directory: str
        Path to a pymc serialized trace
    model: pm.Model (optional)
        Model used to create the trace.  Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If `directory` contains no subdirectories to load chains from.
    """
    # The two literals are joined by implicit concatenation, so the first
    # one must end with a space to keep the sentences apart.
    # stacklevel=2 attributes the warning to the caller of `load_trace`.
    warnings.warn(
        "The `load_trace` function will soon be removed. "
        "Instead, use `arviz.from_netcdf` to load traces.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Every subdirectory of `directory` is treated as one serialized chain.
    straces = [
        SerializeNDArray(subdir).load(model)
        for subdir in glob.glob(os.path.join(directory, "*"))
        if os.path.isdir(subdir)
    ]
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC saved chain directory." %
                                  directory)
    return base.MultiTrace(straces)
Esempio n. 4
0
    def test_merge_traces_diff_lengths(self):
        """Check that `base.merge_traces` raises ValueError for these traces.

        Two traces are recorded on chain 1, one with `self.draws` points
        and one with twice as many.
        """

        def _recorded(n_draws):
            # Record `n_draws` copies of the test point on chain 1.
            with self.model:
                strace = self.backend(self.name)
                strace.setup(n_draws, 1)
                for _ in range(n_draws):
                    strace.record(self.test_point)
                strace.close()
            return strace

        # NOTE(review): the trace recorded here is not used below; mtrace0
        # wraps the fixture `self.strace0` instead -- confirm this is intended.
        _recorded(self.draws)
        mtrace0 = base.MultiTrace([self.strace0])

        mtrace1 = base.MultiTrace([_recorded(2 * self.draws)])

        with pytest.raises(ValueError):
            base.merge_traces([mtrace0, mtrace1])
Esempio n. 5
0
def load(name,
         chains=None,
         model=None,
         host='localhost',
         port='50070',
         user_name=None):
    """
    Read a text database stored in HDFS into a multi-chain trace.

    Parameters
    ----------
    name : str
        Path to root directory in HDFS for text database without a leading '/'
    chains : list
        Chains to read. If None, every chain reported by
        `_get_chain_dirs` is read.
    model : Model
        Forwarded to each `NDArray` trace. If None, the model is taken
        from the 'with' context.
    host : str
        The IP address or hostname of the HDFS namenode. By default,
        it is 'localhost'
    port : str
        The port number for WebHDFS on the namenode. By default, it
        is '50070'
    user_name : str
        WebHDFS user_name used for authentication. By default, it is
        None

    Returns
    -------
    A `base.MultiTrace` wrapping one `NDArray` trace per loaded chain
    """
    client = PyWebHdfsClient(host=host, port=port, user_name=user_name)
    dirs_by_chain = _get_chain_dirs(name, client)
    if chains is None:
        chains = list(dirs_by_chain.keys())

    loaded = []
    for chain in chains:
        # Chain directories are joined onto the database root.
        chain_path = os.path.join(name, dirs_by_chain[chain])
        shapes_path = os.path.join(chain_path, 'shapes.json')
        shapes = json.load(StringIO.StringIO(client.read_file(shapes_path)))

        samples = {}
        for varname, shape in shapes.items():
            values_path = os.path.join(chain_path, varname + '.txt')
            values_buffer = StringIO.StringIO(
                str(client.read_file(values_path)))
            samples[varname] = np.loadtxt(values_buffer).reshape(shape)

        strace = NDArray(model=model)
        strace.samples = samples
        strace.chain = chain
        loaded.append(strace)
    return base.MultiTrace(loaded)
Esempio n. 6
0
    def setUpClass(cls):
        """Record two 5-draw chains with known values and store them on `cls`.

        Chain 0 receives consecutive integers reshaped per variable;
        chain 1 receives the same values multiplied by 100. The expected
        arrays are kept in `cls.expected` and the recorded traces are
        wrapped in `cls.mtrace`.
        """
        cls.test_point, cls.model, _ = models.non_normal(cls.shape)
        with cls.model:
            trace0 = cls.backend(cls.name)
            trace1 = cls.backend(cls.name)

        cls.draws = 5
        trace0.setup(cls.draws, chain=0)
        trace1.setup(cls.draws, chain=1)

        shapes = {}
        for varname, value in cls.test_point.items():
            shapes[varname] = value.shape
        varnames = list(shapes)

        cls.expected = {0: {}, 1: {}}
        for varname, var_shape in shapes.items():
            flat = np.arange(cls.draws * np.prod(var_shape))
            per_draw = flat.reshape((cls.draws, ) + var_shape)
            cls.expected[0][varname] = per_draw
            cls.expected[1][varname] = per_draw * 100

        def _point_at(chain, idx):
            # Slice draw `idx` out of every expected array for `chain`.
            return {
                varname: cls.expected[chain][varname][idx, ...]
                for varname in varnames
            }

        for idx in range(cls.draws):
            point0, point1 = _point_at(0, idx), _point_at(1, idx)
            trace0.record(point=point0)
            trace1.record(point=point1)
        trace0.close()
        trace1.close()
        cls.mtrace = base.MultiTrace([trace0, trace1])
Esempio n. 7
0
    def setup_class(cls):
        """Record two 5-draw chains with known values and sampler stats.

        Chain 0 receives consecutive numbers (in each variable's dtype)
        reshaped per variable; chain 1 receives the same values times 100.
        When `cls.sampler_vars` is set, each draw is also recorded with
        sampler stats built from it. Results are stored on `cls` as
        `expected`, `expected_stats` (only with sampler vars), `mtrace`,
        `stat_dtypes` and `stats_counts`.
        """
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)

        # Drop the first unobserved RV only when the subclass explicitly
        # sets `write_partial_chain` to True.
        tracked_rvs = cls.model.unobserved_RVs
        if getattr(cls, "write_partial_chain", None) is True:
            tracked_rvs = tracked_rvs[1:]
        cls.chain_vars = [rv.tag.value_var for rv in tracked_rvs]

        with cls.model:
            strace0 = cls.backend(cls.name, vars=cls.chain_vars)
            strace1 = cls.backend(cls.name, vars=cls.chain_vars)

        if not hasattr(cls, "sampler_vars"):
            cls.sampler_vars = None
        has_stats = cls.sampler_vars is not None

        cls.draws = 5
        setup_extra = {"sampler_vars": cls.sampler_vars} if has_stats else {}
        strace0.setup(cls.draws, chain=0, **setup_extra)
        strace1.setup(cls.draws, chain=1, **setup_extra)

        varnames = list(cls.test_point.keys())
        shapes, dtypes = {}, {}
        for varname, value in cls.test_point.items():
            shapes[varname] = value.shape
            dtypes[varname] = value.dtype

        cls.expected = {0: {}, 1: {}}
        for varname in varnames:
            flat = np.arange(cls.draws * np.prod(shapes[varname]),
                             dtype=dtypes[varname])
            per_draw = flat.reshape((cls.draws, ) + shapes[varname])
            cls.expected[0][varname] = per_draw
            cls.expected[1][varname] = per_draw * 100

        if has_stats:
            cls.expected_stats = {0: [], 1: []}
            for stat_spec in cls.sampler_vars:
                # Boolean stats are all zeros; others count 0..draws-1.
                chain_stats = {
                    key: (np.zeros(cls.draws, dtype=dtype) if dtype == bool
                          else np.arange(cls.draws, dtype=dtype))
                    for key, dtype in stat_spec.items()
                }
                # Both chains reference the very same dict object.
                cls.expected_stats[0].append(chain_stats)
                cls.expected_stats[1].append(chain_stats)

        def _point_at(chain, idx):
            # Slice draw `idx` out of every expected array for `chain`.
            return {
                varname: cls.expected[chain][varname][idx, ...]
                for varname in varnames
            }

        def _stats_at(chain, idx):
            # Per-sampler dicts holding the stat values of draw `idx`.
            return [{key: val[idx]
                     for key, val in chain_stats.items()}
                    for chain_stats in cls.expected_stats[chain]]

        for idx in range(cls.draws):
            point0, point1 = _point_at(0, idx), _point_at(1, idx)
            if has_stats:
                stats0, stats1 = _stats_at(0, idx), _stats_at(1, idx)
                strace0.record(point=point0, sampler_stats=stats0)
                strace1.record(point=point1, sampler_stats=stats1)
            else:
                strace0.record(point=point0)
                strace1.record(point=point1)
        strace0.close()
        strace1.close()
        cls.mtrace = base.MultiTrace([strace0, strace1])

        # Merge all stat dtypes and count how many samplers report each stat.
        cls.stat_dtypes = {}
        cls.stats_counts = collections.Counter()
        for stat_spec in cls.sampler_vars or []:
            cls.stat_dtypes.update(stat_spec)
            cls.stats_counts.update(stat_spec.keys())
Esempio n. 8
0
    def test_merge_traces_nonunique(self):
        """Expect `base.merge_traces` to raise ValueError for the fixture traces."""
        # Build both MultiTraces up front so only the merge is under test.
        wrapped = [
            base.MultiTrace([self.trace0]),
            base.MultiTrace([self.trace1]),
        ]

        with self.assertRaises(ValueError):
            base.merge_traces(wrapped)
Esempio n. 9
0
    def test_merge_traces_nonunique(self):
        """Expect `base.merge_traces` to raise ValueError for the fixture traces."""
        # Wrap each fixture trace separately, outside the `raises` context,
        # so only the merge itself is expected to fail.
        wrapped = [
            base.MultiTrace([strace])
            for strace in (self.strace0, self.strace1)
        ]

        with pytest.raises(ValueError):
            base.merge_traces(wrapped)
Esempio n. 10
0
 def test_multitrace_nonunique(self):
     """Expect `base.MultiTrace` to raise ValueError for the fixture traces."""
     straces = [self.strace0, self.strace1]
     with pytest.raises(ValueError):
         base.MultiTrace(straces)