def load_trace(directory: str, model=None) -> MultiTrace:
    """Load a multitrace that has been written to file.

    The model used for the trace must be passed in, or the command must be
    run in a model context.

    Parameters
    ----------
    directory : str
        Path to a pymc3 serialized trace
    model : pm.Model (optional)
        Model used to create the trace. Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If ``directory`` contains no sub-directory to load a chain from.
    """
    # Every sub-directory directly below ``directory`` is treated as one chain.
    chain_dirs = (
        entry
        for entry in glob.glob(os.path.join(directory, '*'))
        if os.path.isdir(entry)
    )
    loaded = [SerializeNDArray(chain_dir).load(model) for chain_dir in chain_dirs]
    if len(loaded) == 0:
        raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
    return base.MultiTrace(loaded)
def load_multitrace(dirname, model=None):
    """
    Load TextChain database.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    model : Model
        If None, the model is taken from the `with` context.

    Returns
    -------
    A :class:`pymc3.backend.base.MultiTrace` instance
    """
    logger.info('Loading multitrace from %s' % dirname)

    def _chain_from_file(path):
        # The chain number is whatever follows the last '-' once the
        # extension has been stripped, e.g. ".../chain-3.csv" -> 3.
        stem = os.path.splitext(path)[0]
        number = int(stem.rsplit('-', 1)[1])
        text_chain = TextChain(dirname, model=model)
        text_chain.chain = number
        text_chain.filename = path
        return text_chain

    pattern = os.path.join(dirname, 'chain-*.csv')
    return base.MultiTrace([_chain_from_file(path) for path in glob(pattern)])
def load_trace(directory: str, model=None) -> MultiTrace:
    """Load a multitrace that has been written to file.

    The model used for the trace must be passed in, or the command must be
    run in a model context.

    Calling this function emits a ``DeprecationWarning``.

    Parameters
    ----------
    directory: str
        Path to a pymc3 serialized trace
    model: pm.Model (optional)
        Model used to create the trace. Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If ``directory`` contains no sub-directory to load a chain from.
    """
    warnings.warn(
        # The trailing space keeps the two sentences from running together.
        "The `load_trace` function will soon be removed. "
        "Instead, use `arviz.from_netcdf` to load traces.",
        DeprecationWarning,
        # Attribute the warning to the caller rather than to this module.
        stacklevel=2,
    )
    straces = []
    # Every sub-directory directly below ``directory`` is treated as one chain.
    for subdir in glob.glob(os.path.join(directory, "*")):
        if os.path.isdir(subdir):
            straces.append(SerializeNDArray(subdir).load(model))
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
    return base.MultiTrace(straces)
def setUpClass(cls):
    """Build a two-chain MultiTrace fixture with known recorded values.

    Sets ``cls.test_point``, ``cls.model``, ``cls.draws``, ``cls.expected``
    (the values recorded for chains 0 and 1, keyed by variable name) and
    ``cls.mtrace`` (the MultiTrace wrapping both closed chains).
    """
    cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)
    with cls.model:
        strace0 = cls.backend(cls.name)
        strace1 = cls.backend(cls.name)

    cls.draws = 5
    strace0.setup(cls.draws, chain=0)
    strace1.setup(cls.draws, chain=1)

    # Chain 0 records a running count reshaped to (draws, *var_shape);
    # chain 1 records the same values scaled by 100 so the chains differ.
    cls.expected = {0: {}, 1: {}}
    for name, value in cls.test_point.items():
        counter = np.arange(cls.draws * np.prod(value.shape), dtype=value.dtype)
        per_draw_shape = (cls.draws,) + value.shape
        cls.expected[0][name] = counter.reshape(per_draw_shape)
        cls.expected[1][name] = counter.reshape(per_draw_shape) * 100

    names = list(cls.test_point.keys())
    for draw in range(cls.draws):
        first = {name: cls.expected[0][name][draw, ...] for name in names}
        second = {name: cls.expected[1][name][draw, ...] for name in names}
        strace0.record(point=first)
        strace1.record(point=second)

    strace0.close()
    strace1.close()
    cls.mtrace = base.MultiTrace([strace0, strace1])
def load(con_str, model=None):
    """Load ODBC database.

    Parameters
    ----------
    con_str : str
        ODBC Connection string including database
    model : Model
        If None, the model is taken from the `with` context.

    Returns
    -------
    A MultiTrace instance

    Raises
    ------
    ValueError
        If no variable tables are found in the database.
    """
    db = _ODBCDB(con_str)
    db.connect()
    name = _get_db_name(con_str)
    varnames = _get_table_list(db.cursor)
    if len(varnames) == 0:
        raise ValueError(
            'Can not get variable list for database `{}`'.format(name))
    # The chain list is read from the first variable table.
    chains = _get_chain_list(db.cursor, varnames[0])

    straces = []
    for chain in chains:
        strace = ODBC(con_str, model=model)
        strace.chain = chain
        strace._var_cols = {
            varname: ttab.create_flat_names('v', shape)
            for varname, shape in strace.var_shapes.items()
        }
        # Mark the trace as set up by hand; setup() is not called here.
        strace._is_setup = True
        strace.db = db  # Share the db with all traces.
        straces.append(strace)
    return base.MultiTrace(straces)
def test_merge_traces_diff_lengths(self):
    """Merging traces of ``draws`` and ``2 * draws`` samples raises ValueError."""
    with self.model:
        strace0 = self.backend(self.name)
        strace0.setup(self.draws, 1)
        for _ in range(self.draws):
            strace0.record(self.test_point)
        strace0.close()
    # Wrap the trace built above; the original wrapped ``self.strace0`` and
    # left this freshly recorded trace unused.
    mtrace0 = base.MultiTrace([strace0])

    with self.model:
        strace1 = self.backend(self.name)
        strace1.setup(2 * self.draws, 1)
        for _ in range(2 * self.draws):
            strace1.record(self.test_point)
        strace1.close()
    mtrace1 = base.MultiTrace([strace1])

    # NOTE(review): both traces are set up with chain number 1, so the
    # ValueError may come from a chain-uniqueness check rather than from a
    # length check -- confirm against base.merge_traces.
    with pytest.raises(ValueError):
        base.merge_traces([mtrace0, mtrace1])
def load_multitrace(dirname, varnames=None, chains=None):
    """
    Load TextChain database.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    varnames : list
        of strings with variable names; passed to ``istransd`` to decide
        whether a trans-d trace was requested
    chains : list, optional
        Chain numbers to load. If None, every ``chain-*.csv`` file found in
        ``dirname`` is loaded, except chain ``-1``.

    Returns
    -------
    A :class:`pymc3.backend.base.MultiTrace` instance

    Raises
    ------
    IOError
        If ``chains`` is given and one of the chain files does not exist.
    NotImplementedError
        If ``istransd(varnames)`` reports a trans-d trace.
    """
    if istransd(varnames)[0]:
        logger.info('Loading trans-d trace from %s' % dirname)
        raise NotImplementedError('Loading trans-d trace is not implemented!')

    logger.info('Loading multitrace from %s' % dirname)
    if chains is not None:
        files = []
        for chain in chains:
            files.append(os.path.join(dirname, 'chain-%i.csv' % chain))

        for path in files:
            if not os.path.exists(path):
                raise IOError('File %s does not exist! Please run:'
                              ' "beat summarize <project_dir>"!' % path)
    else:
        files = glob(os.path.join(dirname, 'chain-*.csv'))
        chains = []
        for path in files:
            stem, _ext = os.path.splitext(os.path.basename(path))
            chains.append(int(stem.replace('chain-', '')))

        # Chain -1 is skipped when discovering chains from disk.
        final_chain = -1
        if final_chain in chains:
            position = chains.index(final_chain)
            del files[position]
            del chains[position]

    straces = []
    for chain, path in zip(chains, files):
        text_chain = TextChain(dirname)
        text_chain.chain = chain
        text_chain.filename = path
        straces.append(text_chain)
    return base.MultiTrace(straces)
def load_multitrace(dirname, varnames=None, chains=None):
    """
    Load TextChain database.

    Parameters
    ----------
    dirname : str
        Name of directory with files (one per chain)
    varnames : list
        of strings with variable names; passed to ``istransd`` to decide
        whether a trans-d trace was requested
    chains : list, optional
        Chain numbers to load. If None, every ``chain-*.csv`` file found in
        ``dirname`` is loaded.

    Returns
    -------
    A :class:`pymc3.backend.base.MultiTrace` instance

    Raises
    ------
    IOError
        If ``chains`` is given and one of the chain files does not exist.
    NotImplementedError
        If ``istransd(varnames)`` reports a trans-d trace.
    """
    if not istransd(varnames)[0]:
        logger.info('Loading multitrace from %s' % dirname)
        if chains is None:
            # Key each discovered file by its own chain number so that the
            # chain/file pairing survives de-duplication and sorting.
            # (Building the chain list through a set reordered it
            # independently of ``files``.)
            prefix_len = len('chain-')
            chain_files = {}
            for f in glob(os.path.join(dirname, 'chain-*.csv')):
                stem = os.path.splitext(os.path.basename(f))[0]
                # Strip the 'chain-' prefix instead of splitting on the last
                # '-', so a negative chain number keeps its sign.
                chain_files.setdefault(int(stem[prefix_len:]), f)
            chains = sorted(chain_files)
            files = [chain_files[chain] for chain in chains]
        else:
            files = [
                os.path.join(dirname, 'chain-%i.csv' % chain)
                for chain in chains
            ]
            for f in files:
                if not os.path.exists(f):
                    raise IOError('File %s does not exist! Please run:'
                                  ' "beat summarize <project_dir>"!' % f)

        straces = []
        for chain, f in zip(chains, files):
            strace = TextChain(dirname)
            strace.chain = chain
            strace.filename = f
            straces.append(strace)
        return base.MultiTrace(straces)
    else:
        logger.info('Loading trans-d trace from %s' % dirname)
        raise NotImplementedError('Loading trans-d trace is not implemented!')
def setup_class(cls):
    """Build a two-chain MultiTrace fixture with known values and sampler stats.

    Reads the optional class attributes ``write_partial_chain`` and
    ``sampler_vars``. Sets ``cls.test_point``, ``cls.model``,
    ``cls.chain_vars``, ``cls.draws``, ``cls.expected``, ``cls.mtrace``,
    ``cls.stat_dtypes`` and ``cls.stats_counts``; ``cls.expected_stats`` is
    set only when ``cls.sampler_vars`` is not None.
    """
    cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)

    if hasattr(cls, "write_partial_chain") and cls.write_partial_chain is True:
        # Leave the first unobserved variable out of the recorded chain.
        cls.chain_vars = cls.model.unobserved_RVs[1:]
    else:
        cls.chain_vars = cls.model.unobserved_RVs

    with cls.model:
        strace0 = cls.backend(cls.name, vars=cls.chain_vars)
        strace1 = cls.backend(cls.name, vars=cls.chain_vars)

    if not hasattr(cls, "sampler_vars"):
        cls.sampler_vars = None

    cls.draws = 5
    if cls.sampler_vars is not None:
        strace0.setup(cls.draws, chain=0, sampler_vars=cls.sampler_vars)
        strace1.setup(cls.draws, chain=1, sampler_vars=cls.sampler_vars)
    else:
        strace0.setup(cls.draws, chain=0)
        strace1.setup(cls.draws, chain=1)

    varnames = list(cls.test_point.keys())
    shapes = {varname: value.shape for varname, value in cls.test_point.items()}
    dtypes = {varname: value.dtype for varname, value in cls.test_point.items()}

    # Chain 0 records a running count reshaped to (draws, *var_shape);
    # chain 1 records the same values scaled by 100 so the chains differ.
    cls.expected = {0: {}, 1: {}}
    for varname in varnames:
        mcmc_shape = (cls.draws,) + shapes[varname]
        values = np.arange(cls.draws * np.prod(shapes[varname]), dtype=dtypes[varname])
        cls.expected[0][varname] = values.reshape(mcmc_shape)
        cls.expected[1][varname] = values.reshape(mcmc_shape) * 100

    if cls.sampler_vars is not None:
        cls.expected_stats = {0: [], 1: []}
        for stat_spec in cls.sampler_vars:
            # The same dict object is registered for both chains.
            stats = {}
            cls.expected_stats[0].append(stats)
            cls.expected_stats[1].append(stats)
            for key, dtype in stat_spec.items():
                # Compare through np.dtype rather than the ``np.bool`` alias,
                # which is unavailable on NumPy 1.24-1.26.
                if np.dtype(dtype) == np.dtype(bool):
                    stats[key] = np.zeros(cls.draws, dtype=dtype)
                else:
                    stats[key] = np.arange(cls.draws, dtype=dtype)

    for idx in range(cls.draws):
        point0 = {varname: cls.expected[0][varname][idx, ...] for varname in varnames}
        point1 = {varname: cls.expected[1][varname][idx, ...] for varname in varnames}
        if cls.sampler_vars is not None:
            stats0 = [
                {key: val[idx] for key, val in stats.items()}
                for stats in cls.expected_stats[0]
            ]
            stats1 = [
                {key: val[idx] for key, val in stats.items()}
                for stats in cls.expected_stats[1]
            ]
            strace0.record(point=point0, sampler_stats=stats0)
            strace1.record(point=point1, sampler_stats=stats1)
        else:
            strace0.record(point=point0)
            strace1.record(point=point1)
    strace0.close()
    strace1.close()
    cls.mtrace = base.MultiTrace([strace0, strace1])

    # Merge the per-sampler stat declarations: name -> dtype, and how many
    # samplers declare each stat name.
    cls.stat_dtypes = {}
    cls.stats_counts = collections.Counter()
    for stats in cls.sampler_vars or []:
        cls.stat_dtypes.update(stats)
        cls.stats_counts.update(stats.keys())
def test_merge_traces_nonunique(self):
    """merge_traces raises ValueError for the fixture's two single-chain traces."""
    wrapped = [
        base.MultiTrace([single])
        for single in (self.strace0, self.strace1)
    ]
    with pytest.raises(ValueError):
        base.merge_traces(wrapped)
def test_multitrace_nonunique(self):
    """Building a MultiTrace from both fixture traces raises ValueError."""
    with pytest.raises(ValueError):
        clashing = (self.strace0, self.strace1)
        base.MultiTrace(list(clashing))
def test_merge_traces_nonunique(self):
    """merge_traces raises ValueError for the fixture's two single-chain traces."""
    first = base.MultiTrace([self.strace0])
    second = base.MultiTrace([self.strace1])
    with self.assertRaises(ValueError):
        base.merge_traces([first, second])