示例#1
0
    def connect_model(self, model):
        """Link the Database to the Model instance.

        In case a new database is created from scratch, ``connect_model``
        creates Trace objects for all tallyable pymc objects defined in
        `model`.

        If the database is being loaded from an existing file,
        ``connect_model`` reattaches each stored trace to the matching
        getter function of `model`.

        :Parameters:
        model : pymc.Model instance
          An instance holding the pymc objects defining a statistical
          model (stochastics, deterministics, data, ...)

        :Raises:
        AttributeError
          If `model` is not a ``pymc.Model`` instance.
        RuntimeError
          If the database was loaded from disk and some stored trace names
          have no tallyable counterpart in `model`.
        """
        # Only genuine pymc.Model instances are accepted.
        if isinstance(model, pymc.Model):
            self.model = model
        else:
            raise AttributeError('Not a Model instance.')

        # Restore the state of the Model from an existing Database.
        # The `load` method will have already created the Trace objects.
        if hasattr(self, '_state_'):
            # Every name recorded in any chain must receive a getfunc.
            names = set()
            for morenames in self.trace_names:
                names.update(morenames)
            for name, fun in six.iteritems(model._funs_to_tally):
                if name in self._traces:
                    self._traces[name]._getfunc = fun
                    # discard (not remove): a trace may exist without being
                    # listed in trace_names; that must not raise KeyError.
                    names.discard(name)
            if names:
                raise RuntimeError(
                    "Some objects from the database have not been assigned a getfunc: %s" %
                    ', '.join(sorted(names)))

        # Create a fresh new state. This is now taken care of in initialize.
        else:
            for name, fun in six.iteritems(model._funs_to_tally):
                # Values NumPy can only hold with dtype=object get a
                # TraceObject; everything else gets a plain Trace.
                if np.array(fun()).dtype == np.dtype('object'):
                    self._traces[name] = TraceObject(name,
                                                     getfunc=fun,
                                                     db=self)
                else:
                    self._traces[name] = Trace(name, getfunc=fun, db=self)
示例#2
0
    def __init__(self, iterable):
        """Build the container from a mapping of objects.

        Fills the underlying dict via ``dict.__init__``, runs the
        ContainerBase and ``file_items`` setup, then partitions the entries
        into "valued" ones (Variable or ContainerBase instances) and all
        others, keeping parallel key/object sequences for each group.

        :Parameters:
        iterable : mapping
          Presumably a dict of name -> object (it is passed to
          ``dict.__init__``); TODO confirm other accepted inputs.
        """
        dict.__init__(self, iterable)
        ContainerBase.__init__(self, iterable)
        # NOTE(review): this copy is replaced by `{}` further down; it is
        # presumably needed by file_items() below — confirm before removing.
        self._value = copy(iterable)
        file_items(self, iterable)

        # Parallel sequences: keys/objects of entries that are Variables or
        # containers (val_*), and of all remaining entries (nonval_*).
        self.val_keys = []
        self.val_obj = []
        self.nonval_keys = []
        self.nonval_obj = []
        self._value = {}
        for key, obj in six.iteritems(self):
            if isinstance(obj, Variable) or isinstance(obj, ContainerBase):
                self.val_keys.append(key)
                self.val_obj.append(obj)
            else:
                self.nonval_keys.append(key)
                self.nonval_obj.append(obj)
        # A trailing None sentinel is appended to both object lists so that
        # converting them with array(..., dtype=object) cannot be confused
        # when the list would otherwise hold a single array.
        # The counts below come from the key lists and exclude the sentinel.
        self.val_obj.append(None)
        self.nonval_obj.append(None)

        self.n_val = len(self.val_keys)
        self.val_keys = array(self.val_keys, dtype=object)
        # val_obj is deliberately left as a plain list (conversion disabled):
        # self.val_obj = array(self.val_obj, dtype=object)
        self.n_nonval = len(self.nonval_keys)
        self.nonval_keys = array(self.nonval_keys, dtype=object)
        self.nonval_obj = array(self.nonval_obj, dtype=object)
        # Helper object built from this container (DCValue is defined
        # elsewhere).
        self.DCValue = DCValue(self)
示例#3
0
def load(filename):
    """Load a pickled database and return a Database instance.

    If `filename` cannot be opened because it does not exist, a file with
    the same base name is looked up in the directory returned by
    ``get_dbpath()``.

    Warning: unpickling can execute arbitrary code; only load database
    files from trusted sources.
    """
    try:
        stream = open(filename, 'rb')
    except FileNotFoundError:
        db_name = path.basename(filename)
        stream = open(path.join(get_dbpath(), db_name), 'rb')
    # The context manager closes the file even if unpickling fails.
    with stream:
        container = std_pickle.load(stream)
    db = Database(stream.name)
    chains = 0
    funs = set()
    for k, v in six.iteritems(container):
        if k == '_state_':
            db._state_ = v
        else:
            db._traces[k] = Trace(name=k, value=v, db=db)
            setattr(db, k, db._traces[k])
            # len(v) is taken as the number of chains stored for this trace.
            chains = max(chains, len(v))
            funs.add(k)

    db.chains = chains
    # Each chain gets its own list so mutating one does not affect others.
    db.trace_names = [list(funs) for _ in range(chains)]

    return db
示例#4
0
    def _initialize(self, funs_to_tally, length=None):
        """Prepare every tallied object for a new chain.

        A Trace is built with ``self.__Trace__`` for any name that does not
        have one yet; every Trace is then initialized for chain number
        ``self.chains``. Finally the tallied names are recorded and the
        chain counter is advanced.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int
          Expected chain length; some backends use it to preallocate memory.
        """
        traces = self._traces
        for trace_name, getter in six.iteritems(funs_to_tally):
            if trace_name not in traces:
                traces[trace_name] = self.__Trace__(name=trace_name,
                                                    getfunc=getter,
                                                    db=self)
            traces[trace_name]._initialize(self.chains, length)

        tallied = list(funs_to_tally.keys())
        self.trace_names.append(tallied)
        self.chains = self.chains + 1
示例#5
0
    def connect_model(self, model):
        """Link the Database to the Model instance.

        In case a new database is created from scratch, ``connect_model``
        creates Trace objects for all tallyable pymc objects defined in
        `model`.

        If the database is being loaded from an existing file,
        ``connect_model`` reattaches each stored trace to the matching
        getter function of `model`.

        :Parameters:
        model : pymc.Model instance
          An instance holding the pymc objects defining a statistical
          model (stochastics, deterministics, data, ...)

        :Raises:
        AttributeError
          If `model` is not a ``pymc.Model`` instance.
        RuntimeError
          If the database was loaded from disk and some stored trace names
          have no tallyable counterpart in `model`.
        """
        # Only genuine pymc.Model instances are accepted.
        if isinstance(model, pymc.Model):
            self.model = model
        else:
            raise AttributeError('Not a Model instance.')

        # Restore the state of the Model from an existing Database.
        # The `load` method will have already created the Trace objects.
        if hasattr(self, '_state_'):
            # Every name recorded in any chain must receive a getfunc.
            names = set()
            for morenames in self.trace_names:
                names.update(morenames)
            for name, fun in six.iteritems(model._funs_to_tally):
                if name in self._traces:
                    self._traces[name]._getfunc = fun
                    # discard (not remove): a trace may exist without being
                    # listed in trace_names; that must not raise KeyError.
                    names.discard(name)
            if names:
                raise RuntimeError(
                    "Some objects from the database have not been assigned a getfunc: %s"
                    % ', '.join(sorted(names)))

        # Create a fresh new state. This is now taken care of in initialize.
        else:
            for name, fun in six.iteritems(model._funs_to_tally):
                # Values NumPy can only hold with dtype=object get a
                # TraceObject; everything else gets a plain Trace.
                if np.array(fun()).dtype == np.dtype('object'):
                    self._traces[name] = TraceObject(name,
                                                     getfunc=fun,
                                                     db=self)
                else:
                    self._traces[name] = Trace(name, getfunc=fun, db=self)
示例#6
0
    def connect_model(self, model):
        """Link the Database to the Model instance.

        In case a new database is created from scratch, ``connect_model``
        creates a Trace (via ``self.__Trace__``) for every tallyable pymc
        object defined in `model` that does not have one yet.

        If the database is being loaded from an existing file,
        ``connect_model`` reattaches each stored trace to the matching
        getter function of `model`. Stored traces without a counterpart in
        `model` are tolerated silently.

        :Parameters:
        model : pymc.Model instance
          An instance holding the pymc objects defining a statistical
          model (stochastics, deterministics, data, ...)

        :Raises:
        AttributeError
          If `model` is not a ``pymc.Model`` instance.
        """
        # Only genuine pymc.Model instances are accepted; this check could
        # also be removed altogether.
        if isinstance(model, pymc.Model):
            self.model = model
        else:
            raise AttributeError('Not a Model instance.')

        # Restore the state of the Model from an existing Database.
        # The `load` method will have already created the Trace objects;
        # they only need their getfunc reattached.
        if hasattr(self, '_state_'):
            for name, fun in six.iteritems(model._funs_to_tally):
                if name in self._traces:
                    self._traces[name]._getfunc = fun

        # Create a fresh new state.
        # We will be able to remove this when we deprecate traces on objects.
        else:
            for name, fun in six.iteritems(model._funs_to_tally):
                if name not in self._traces:
                    self._traces[name] = self.__Trace__(name=name,
                                                        getfunc=fun,
                                                        db=self)
示例#7
0
def check_gradients(stochastic):
    """Assert that analytic logp gradients match numeric estimates.

    The variable set of `stochastic` (itself plus its Variable parents) is
    differentiated analytically, and each gradient is compared to a numeric
    estimate to 3 decimal places.
    """
    variables = find_variable_set(stochastic)
    analytic = utils.logp_gradient_of_set(variables, variables)

    for param, analytic_grad in six.iteritems(analytic):
        numeric_grad = get_numeric_gradient(variables, param)
        message = ("analytic gradient for " + str(stochastic)
                   + " with respect to parameter " + str(param)
                   + " is not correct.")
        assert_array_almost_equal(numeric_grad, analytic_grad, 3, message)
示例#8
0
def check_gradients(stochastic):
    """Assert that analytic logp gradients match numeric estimates.

    The variable set of `stochastic` (itself plus its Variable parents) is
    differentiated analytically, and each gradient is compared to a numeric
    estimate to 3 decimal places.
    """
    variables = find_variable_set(stochastic)
    analytic = utils.logp_gradient_of_set(variables, variables)

    for param, analytic_grad in six.iteritems(analytic):
        numeric_grad = get_numeric_gradient(variables, param)
        message = ("analytic gradient for " + str(stochastic)
                   + " with respect to parameter " + str(param)
                   + " is not correct.")
        assert_array_almost_equal(numeric_grad, analytic_grad, 3, message)
示例#9
0
def check_jacobians(deterministic):
    """Compare analytic and numeric partial gradients of `deterministic`.

    For each parent that is a Variable, a random upstream gradient
    (``random.normal(.5, .1)`` with the shape of ``deterministic.value``)
    is propagated both analytically and numerically; the two results must
    agree to 4 decimal places.
    """
    for parameter, pvalue in six.iteritems(deterministic.parents):
        if not isinstance(pvalue, Variable):
            continue

        upstream = random.normal(.5, .1, size=shape(deterministic.value))
        analytic = get_analytic_partial_gradient(
            deterministic, parameter, pvalue, upstream)
        numeric = get_numeric_partial_gradient(deterministic, pvalue, upstream)

        assert_array_almost_equal(
            analytic, numeric, 4,
            "analytic partial gradient for " + str(deterministic)
            + " with respect to parameter " + str(parameter)
            + " is not correct.")
示例#10
0
    def _initialize(self, funs_to_tally, length=None):
        """Create the netCDF group ``Chain#<n>`` for a new chain.

        Ensures every tallied name has an initialized Trace. Creates the
        group with an unlimited ``nsamples`` dimension and one variable per
        tallied object. Records the chain's variable names, sets
        ``tally_index`` to the current number of samples, and advances the
        chain counter.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int, optional
          Not used by this backend; accepted for interface compatibility.
        """
        for name, fun in six.iteritems(funs_to_tally):
            if name not in self._traces:
                self._traces[name] = self.__Trace__(name=name, getfunc=fun, db=self)
            # if db is loaded from disk, it might not have its tallied step method
            self._traces[name]._initialize()

        group = self.ncfile.createGroup('Chain#%d' % self.chains)

        # Size 0 creates an unlimited dimension, so samples can be appended.
        if 'nsamples' not in group.dimensions:
            group.createDimension('nsamples', 0)

        # create variables for pymc variables
        for name, fun in six.iteritems(funs_to_tally):
            if name in group.variables:
                continue
            # Evaluate the getter once instead of once per use.
            value = np.asarray(fun())
            if value.shape != ():
                # Non-scalar: add a dimension named after the variable.
                # NOTE(review): only shape[0] is used, so values with more
                # than one dimension are not fully described.
                group.createDimension(name, value.shape[0])
                group.createVariable(name, value.dtype.str, ('nsamples', name))
            else:
                # Scalars only need the nsamples dimension.
                group.createVariable(name, value.dtype.str, ('nsamples',))

        if len(self.trace_names) < len(self.ncfile.groups):
            # The group for this chain was just created, so it always exists.
            self.trace_names.append(list(group.variables))
        self.tally_index = len(group.dimensions['nsamples'])
        self.chains += 1
示例#11
0
    def _model_trace_description(self):
        """Describe the PyTables table used to store the model's traces.

        :Returns:
        table_description : dict
          Maps each tallyable name to a ``tables.Col`` built from the dtype
          and shape of the object's current value.
        object_atoms : dict
          Always empty here.
        """
        def column_for(getter):
            # One column whose dtype/shape mirror the current value.
            current = asarray(getter())
            return tables.Col.from_dtype(dtype((current.dtype, current.shape)))

        description = {
            trace_name: column_for(getter)
            for trace_name, getter in six.iteritems(self.model._funs_to_tally)
        }
        return description, {}
示例#12
0
    def _initialize(self, funs_to_tally, length=None):
        """Create the netCDF group ``Chain#<n>`` for a new chain.

        Ensures every tallied name has an initialized Trace. Creates the
        group with an unlimited ``nsamples`` dimension and one variable per
        tallied object. Records the chain's variable names, sets
        ``tally_index`` to the current number of samples, and advances the
        chain counter.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int, optional
          Not used by this backend; accepted for interface compatibility.
        """
        for name, fun in six.iteritems(funs_to_tally):
            if name not in self._traces:
                self._traces[name] = self.__Trace__(name=name, getfunc=fun, db=self)
            # if db is loaded from disk, it might not have its tallied step method
            self._traces[name]._initialize()

        group = self.ncfile.createGroup('Chain#%d' % self.chains)

        # Size 0 creates an unlimited dimension, so samples can be appended.
        if 'nsamples' not in group.dimensions:
            group.createDimension('nsamples', 0)

        # create variables for pymc variables
        for name, fun in six.iteritems(funs_to_tally):
            if name in group.variables:
                continue
            # Evaluate the getter once instead of once per use.
            value = np.asarray(fun())
            if value.shape != ():
                # Non-scalar: add a dimension named after the variable.
                # NOTE(review): only shape[0] is used, so values with more
                # than one dimension are not fully described.
                group.createDimension(name, value.shape[0])
                group.createVariable(name, value.dtype.str, ('nsamples', name))
            else:
                # Scalars only need the nsamples dimension.
                group.createVariable(name, value.dtype.str, ('nsamples',))

        if len(self.trace_names) < len(self.ncfile.groups):
            # The group for this chain was just created, so it always exists.
            self.trace_names.append(list(group.variables))
        self.tally_index = len(group.dimensions['nsamples'])
        self.chains += 1
示例#13
0
    def _model_trace_description(self):
        """Describe the PyTables table used to store the model's traces.

        :Returns:
        table_description : dict
          Maps each tallyable name to a ``tables.Col`` built from the dtype
          and shape of the object's current value.
        object_atoms : dict
          Always empty here.
        """
        def column_for(getter):
            # One column whose dtype/shape mirror the current value.
            current = asarray(getter())
            return tables.Col.from_dtype(dtype((current.dtype, current.shape)))

        description = {
            trace_name: column_for(getter)
            for trace_name, getter in six.iteritems(self.model._funs_to_tally)
        }
        return description, {}
示例#14
0
def load(dirname):
    """Create a Database instance from the data stored in the directory.

    Each chain folder returned by ``db.get_chains()`` holds one text file
    per tallied object. The second line of a file carries the sample shape
    after a 16-character prefix. The comma-separated values are read with
    ``np.loadtxt`` and reshaped to that shape. The sampler state is read
    from ``state.txt`` when present, otherwise it defaults to ``{}``.

    :Raises:
    AttributeError
      If `dirname` does not exist.
    """
    import ast

    if not os.path.exists(dirname):
        raise AttributeError('No txt database named %s' % dirname)

    db = Database(dirname, dbmode='a')
    chain_folders = [os.path.join(dirname, c) for c in db.get_chains()]
    db.chains = len(chain_folders)

    # data[name][chain] -> array of samples for that chain
    data = {}
    for chain, folder in enumerate(chain_folders):
        files = os.listdir(folder)
        db.trace_names.append(funname(files))
        for fname in files:
            name = funname(fname)
            filepath = os.path.join(folder, fname)
            # Read the shape information from the second header line.
            with open(filepath) as f:
                f.readline()
                # literal_eval accepts only Python literals such as
                # "(1000, 3)"; unlike eval it cannot run arbitrary code.
                shape = ast.literal_eval(f.readline()[16:].strip())
            data.setdefault(name, {})[chain] = np.loadtxt(
                filepath, delimiter=',').reshape(shape)

    # Create the Traces.
    for name, values in six.iteritems(data):
        db._traces[name] = Trace(name=name, value=values, db=db)
        setattr(db, name, db._traces[name])

    # Load the state.
    statefile = os.path.join(dirname, 'state.txt')
    if os.path.exists(statefile):
        with open(statefile, 'r') as f:
            # SECURITY: eval executes arbitrary code, so state.txt must come
            # from a trusted source. eval is kept because the state may
            # presumably contain reprs that ast.literal_eval rejects.
            db._state_ = eval(f.read())
    else:
        db._state_ = {}

    return db
示例#15
0
def load(dirname):
    """Create a Database instance from the data stored in the directory.

    Each chain folder returned by ``db.get_chains()`` holds one text file
    per tallied object. The second line of a file carries the sample shape
    after a 16-character prefix. The comma-separated values are read with
    ``np.loadtxt`` and reshaped to that shape. The sampler state is read
    from ``state.txt`` when present, otherwise it defaults to ``{}``.

    :Raises:
    AttributeError
      If `dirname` does not exist.
    """
    import ast

    if not os.path.exists(dirname):
        raise AttributeError('No txt database named %s' % dirname)

    db = Database(dirname, dbmode='a')
    chain_folders = [os.path.join(dirname, c) for c in db.get_chains()]
    db.chains = len(chain_folders)

    # data[name][chain] -> array of samples for that chain
    data = {}
    for chain, folder in enumerate(chain_folders):
        files = os.listdir(folder)
        db.trace_names.append(funname(files))
        for fname in files:
            name = funname(fname)
            filepath = os.path.join(folder, fname)
            # Read the shape information from the second header line.
            with open(filepath) as f:
                f.readline()
                # literal_eval accepts only Python literals such as
                # "(1000, 3)"; unlike eval it cannot run arbitrary code.
                shape = ast.literal_eval(f.readline()[16:].strip())
            data.setdefault(name, {})[chain] = np.loadtxt(
                filepath, delimiter=',').reshape(shape)

    # Create the Traces.
    for name, values in six.iteritems(data):
        db._traces[name] = Trace(name=name, value=values, db=db)
        setattr(db, name, db._traces[name])

    # Load the state.
    statefile = os.path.join(dirname, 'state.txt')
    if os.path.exists(statefile):
        with open(statefile, 'r') as f:
            # SECURITY: eval executes arbitrary code, so state.txt must come
            # from a trusted source. eval is kept because the state may
            # presumably contain reprs that ast.literal_eval rejects.
            db._state_ = eval(f.read())
    else:
        db._state_ = {}

    return db
示例#16
0
    def _initialize(self, funs_to_tally, length):
        """Create a group named ``chain#`` to store all data for this chain.

        One extendable array (EArray) per tallied object is created in the
        new group. Each array's shape is a leading sample axis of initial
        length 0 followed by the shape of the object's current value. A
        fresh Trace is attached to each name and initialized.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int
          Expected chain length, forwarded to ``Trace._initialize``.

        :Raises:
        TypeError
          If an object's current value has dtype ``object``, which this
          backend does not support.
        """
        chain = self.nchains
        self._chains[chain] = self._h5file.createGroup(
            '/', 'chain%d' % chain, 'chain #%d' % chain)

        for name, fun in six.iteritems(funs_to_tally):

            arr = np.asarray(fun())

            # Explicit check rather than `assert`, which is stripped under -O.
            if arr.dtype == np.dtype('object'):
                raise TypeError(
                    "Object-valued variable %r cannot be stored in an EArray."
                    % (name,))

            earray = self._h5file.createEArray(
                self._chains[chain], name,
                tables.Atom.from_dtype(arr.dtype), (0,) + arr.shape,
                filters=self.filter)

            self._arrays[chain, name] = earray
            self._traces[name] = Trace(name, getfunc=fun, db=self)
            self._traces[name]._initialize(self.chains, length)

        # Store a list, not a live dict view, as the other backends do.
        self.trace_names.append(list(funs_to_tally.keys()))
示例#17
0
文件: hdf5ea.py 项目: wqren/pymc
    def _initialize(self, funs_to_tally, length):
        """Create a group named ``chain#`` to store all data for this chain.

        One extendable array (EArray) per tallied object is created in the
        new group. Each array's shape is a leading sample axis of initial
        length 0 followed by the shape of the object's current value. A
        fresh Trace is attached to each name and initialized.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int
          Expected chain length, forwarded to ``Trace._initialize``.

        :Raises:
        TypeError
          If an object's current value has dtype ``object``, which this
          backend does not support.
        """
        chain = self.nchains
        self._chains[chain] = self._h5file.createGroup('/', 'chain%d' % chain,
                                                       'chain #%d' % chain)

        for name, fun in six.iteritems(funs_to_tally):

            arr = np.asarray(fun())

            # Explicit check rather than `assert`, which is stripped under -O.
            if arr.dtype == np.dtype('object'):
                raise TypeError(
                    "Object-valued variable %r cannot be stored in an EArray."
                    % (name,))

            earray = self._h5file.createEArray(
                self._chains[chain],
                name,
                tables.Atom.from_dtype(arr.dtype), (0, ) + arr.shape,
                filters=self.filter)

            self._arrays[chain, name] = earray
            self._traces[name] = Trace(name, getfunc=fun, db=self)
            self._traces[name]._initialize(self.chains, length)

        # Store a list, not a live dict view, as the other backends do.
        self.trace_names.append(list(funs_to_tally.keys()))
示例#18
0
def load(filename):
    """Load a pickled database.

    Return a Database instance.

    Warning: unpickling can execute arbitrary code; only load database
    files from trusted sources.
    """
    # The context manager closes the file even if unpickling fails.
    with open(filename, 'rb') as stream:
        container = std_pickle.load(stream)
    db = Database(stream.name)
    chains = 0
    funs = set()
    for k, v in six.iteritems(container):
        if k == '_state_':
            db._state_ = v
        else:
            db._traces[k] = Trace(name=k, value=v, db=db)
            setattr(db, k, db._traces[k])
            # len(v) is taken as the number of chains stored for this trace.
            chains = max(chains, len(v))
            funs.add(k)

    db.chains = chains
    # Each chain gets its own list so mutating one does not affect others.
    db.trace_names = [list(funs) for _ in range(chains)]

    return db
示例#19
0
def load(filename):
    """Load a pickled database.

    Return a Database instance.

    Warning: unpickling can execute arbitrary code; only load database
    files from trusted sources.
    """
    # The context manager closes the file even if unpickling fails.
    with open(filename, 'rb') as stream:
        container = std_pickle.load(stream)
    db = Database(stream.name)
    chains = 0
    funs = set()
    for k, v in six.iteritems(container):
        if k == '_state_':
            db._state_ = v
        else:
            db._traces[k] = Trace(name=k, value=v, db=db)
            setattr(db, k, db._traces[k])
            # len(v) is taken as the number of chains stored for this trace.
            chains = max(chains, len(v))
            funs.add(k)

    db.chains = chains
    # Each chain gets its own list so mutating one does not affect others.
    db.trace_names = [list(funs) for _ in range(chains)]

    return db
示例#20
0
    def _initialize(self, funs_to_tally, length):
        """Create a group named ``chain#`` to store all data for this chain.

        The group contains one PyTables Table (``PyMCsamples``) holding every
        non-object tallied value. It also has at least one subgroup
        ``group#`` holding VLArrays (ObjectAtoms) for values of dtype
        ``object``. A new ``group#`` subgroup is started after every 4096
        object arrays.

        Every tallied name is given a Trace (a TraceObject for object-valued
        names) *before* any VLArray is attached to it. Each Trace is
        initialized for this chain, the names are appended to
        ``trace_names``, and the chain counter is advanced.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int
          Expected number of rows; passed to the table as ``expectedrows``
          and forwarded to ``Trace._initialize``.
        """
        object_dtype = np.dtype('object')

        # Evaluate each getter once; the value's dtype and shape decide both
        # the storage layout and the Trace class.
        values = {}
        for name, fun in six.iteritems(funs_to_tally):
            values[name] = asarray(fun())

        # Make sure the variables have a corresponding Trace instance. This
        # must come first: object-valued traces receive a VLArray below.
        for name, fun in six.iteritems(funs_to_tally):
            if name not in self._traces:
                if values[name].dtype == object_dtype:
                    self._traces[name] = TraceObject(name,
                                                     getfunc=fun,
                                                     db=self)
                else:
                    self._traces[name] = Trace(name, getfunc=fun, db=self)

        i = self.chains
        self._chains.append(
            self._h5file.create_group("/", 'chain%d' % i, 'Chain #%d' % i))
        current_object_group = self._h5file.create_group(
            self._chains[-1], 'group0', 'Group storing objects.')
        group_counter = 0
        object_counter = 0

        # Create the Table in the chain# group, and ObjectAtoms in
        # chain#/group#.
        table_descr = {}
        for name in funs_to_tally:

            arr = values[name]

            if arr.dtype == object_dtype:

                self._traces[name]._vlarrays.append(
                    self._h5file.create_vlarray(current_object_group,
                                                name,
                                                tables.ObjectAtom(),
                                                title=name + ' samples.',
                                                filters=self.filter))

                object_counter += 1
                # Start a new subgroup after every 4096 object arrays.
                if object_counter % 4096 == 0:
                    group_counter += 1
                    current_object_group = self._h5file.create_group(
                        self._chains[-1], 'group%d' % group_counter,
                        'Group storing objects.')

            else:
                table_descr[name] = tables.Col.from_dtype(
                    dtype((arr.dtype, arr.shape)))

        table = self._h5file.create_table(self._chains[-1],
                                          'PyMCsamples',
                                          table_descr,
                                          title='PyMC samples',
                                          filters=self.filter,
                                          expectedrows=length)

        self._tables.append(table)
        self._rows.append(table.row)

        # Store data objects: observed stochastics flagged with keep_trace
        # are saved as attributes of the table.
        for stochastic in self.model.observed_stochastics:
            if stochastic.keep_trace is True:
                setattr(table.attrs, stochastic.__name__, stochastic.value)

        for name in funs_to_tally:
            self._traces[name]._initialize(self.chains, length)

        self.trace_names.append(list(funs_to_tally.keys()))
        self.chains += 1
示例#21
0
def find_variable_set(stochastic):
    """Return `stochastic` followed by each of its parents that is a Variable.

    Parents keep the iteration order of ``stochastic.parents``.
    """
    members = [stochastic]
    members.extend(parent
                   for _, parent in six.iteritems(stochastic.parents)
                   if isinstance(parent, Variable))
    return members
示例#22
0
    def __init__(self,
                 dbname,
                 dbmode='a',
                 dbcomplevel=0,
                 dbcomplib='zlib',
                 **kwds):
        """Create an HDF5 database instance, where samples are stored in tables.

        When the file already contains ``chain*`` groups, their traces,
        object arrays and stored attributes are loaded into this instance.

        :Parameters:
        dbname : string
          Name of the hdf5 file.
        dbmode : {'a', 'w', 'r'}
          File mode: 'a': append, 'w': overwrite, 'r': read-only.
        dbcomplevel : integer (0-9)
          Compression level, 0: no compression.
        dbcomplib : string
          Compression library (zlib, bzip2, lzo)

        :Notes:
          * zlib has a good compression ratio, although somewhat slow, and
            reasonably fast decompression.
          * LZO is a fast compression library offering however a low compression
            ratio.
          * bzip2 has an excellent compression ratio but requires more CPU.
        """
        self.__name__ = 'hdf5'
        self.dbname = dbname
        self.__Trace__ = Trace
        self.mode = dbmode

        # A list of sequences of names of the objects to tally.
        self.trace_names = []
        self._traces = {}  # A dictionary of the Trace objects.

        db_exists = os.path.exists(self.dbname)
        self._h5file = tables.open_file(self.dbname, self.mode)

        default_filter = tables.Filters(complevel=dbcomplevel,
                                        complib=dbcomplib)
        # When reading or appending to an existing file, prefer the file's
        # own filters over the requested compression settings.
        if self.mode == 'r' or (self.mode == 'a' and db_exists):
            self.filter = getattr(self._h5file, 'filters', default_filter)
        else:
            self.filter = default_filter

        # The three attributes below should be dicts keyed by chain.
        self._tables = self._gettables()
        self._rows = [None] * len(self._tables)
        self._chains = [gr for gr in self._h5file.list_nodes("/")
                        if gr._v_name.startswith('chain')]
        self.chains = len(self._chains)

        # LOAD LOGIC
        if self.chains > 0:
            # Create traces from objects stored in the last chain's Table.
            for k in self._tables[-1].colnames:
                self._traces[k] = Trace(name=k, db=self)
                setattr(self, k, self._traces[k])

            # Collect the VLArrays of every chain, grouped by name.
            # NOTE(review): each per-name list follows the order of
            # self._chains; no reordering is applied here.
            objects = {}
            for chain in self._chains:
                for node in self._h5file.walk_nodes(chain, classname='VLArray'):
                    if node._v_name != '_state_':
                        objects.setdefault(node._v_name, []).append(node)

            for k, vlarrays in six.iteritems(objects):
                self._traces[k] = TraceObject(name=k, db=self, vlarrays=vlarrays)
                setattr(self, k, self._traces[k])

            # Restore table attributes.
            # This restores the sampler's state for the last chain.
            table = self._tables[-1]
            for k in table.attrs._v_attrnamesuser:
                setattr(self, k, getattr(table.attrs, k))

            # Restore group attributes.
            for k in self._chains[-1]._f_list_nodes():
                if k.__class__ not in [tables.Table, tables.Group]:
                    setattr(self, k.name, k)

            varnames = self._tables[-1].colnames + list(objects.keys())
            # Each chain gets its own list so mutating one does not affect
            # the others.
            self.trace_names = [list(varnames) for _ in range(self.chains)]
示例#23
0
    def _initialize(self, funs_to_tally, length):
        """Create a group named ``chain#`` to store all data for this chain.

        The group contains one PyTables Table (``PyMCsamples``) holding every
        non-object tallied value. It also has at least one subgroup
        ``group#`` holding VLArrays (ObjectAtoms) for values of dtype
        ``object``. A new ``group#`` subgroup is started after every 4096
        object arrays.

        Every tallied name is given a Trace (a TraceObject for object-valued
        names) *before* any VLArray is attached to it. Each Trace is
        initialized for this chain, the names are appended to
        ``trace_names``, and the chain counter is advanced.

        :Parameters:
        funs_to_tally : dict
          Mapping of names to getter functions.
        length : int
          Expected number of rows; passed to the table as ``expectedrows``
          and forwarded to ``Trace._initialize``.
        """
        object_dtype = np.dtype('object')

        # Evaluate each getter once; the value's dtype and shape decide both
        # the storage layout and the Trace class.
        values = {}
        for name, fun in six.iteritems(funs_to_tally):
            values[name] = asarray(fun())

        # Make sure the variables have a corresponding Trace instance. This
        # must come first: object-valued traces receive a VLArray below.
        for name, fun in six.iteritems(funs_to_tally):
            if name not in self._traces:
                if values[name].dtype == object_dtype:
                    self._traces[name] = TraceObject(name,
                                                     getfunc=fun,
                                                     db=self)
                else:
                    self._traces[name] = Trace(name, getfunc=fun, db=self)

        i = self.chains
        self._chains.append(
            self._h5file.createGroup("/", 'chain%d' % i, 'Chain #%d' % i))
        current_object_group = self._h5file.createGroup(
            self._chains[-1], 'group0', 'Group storing objects.')
        group_counter = 0
        object_counter = 0

        # Create the Table in the chain# group, and ObjectAtoms in
        # chain#/group#.
        table_descr = {}
        for name in funs_to_tally:

            arr = values[name]

            if arr.dtype == object_dtype:

                self._traces[name]._vlarrays.append(self._h5file.createVLArray(
                    current_object_group,
                    name,
                    tables.ObjectAtom(),
                    title=name + ' samples.',
                    filters=self.filter))

                object_counter += 1
                # Start a new subgroup after every 4096 object arrays.
                if object_counter % 4096 == 0:
                    group_counter += 1
                    current_object_group = self._h5file.createGroup(
                        self._chains[-1],
                        'group%d' % group_counter, 'Group storing objects.')

            else:
                table_descr[name] = tables.Col.from_dtype(
                    dtype((arr.dtype, arr.shape)))

        table = self._h5file.createTable(self._chains[-1],
                                         'PyMCsamples',
                                         table_descr,
                                         title='PyMC samples',
                                         filters=self.filter,
                                         expectedrows=length)

        self._tables.append(table)
        self._rows.append(table.row)

        # Store data objects: observed stochastics flagged with keep_trace
        # are saved as attributes of the table.
        for stochastic in self.model.observed_stochastics:
            if stochastic.keep_trace is True:
                setattr(table.attrs, stochastic.__name__, stochastic.value)

        for name in funs_to_tally:
            self._traces[name]._initialize(self.chains, length)

        self.trace_names.append(list(funs_to_tally.keys()))
        self.chains += 1
示例#24
0
def filter_dict(obj):
    """Return the ``Node`` / ``ContainerBase`` entries of ``obj.__dict__``.

    Builds a new dict mapping each attribute name found in ``obj.__dict__``
    to its value, keeping only values that are instances of ``Node`` or
    ``ContainerBase``. All other entries are dropped, and ``obj`` itself
    is not modified.
    """
    return {
        name: value
        for name, value in obj.__dict__.items()
        if isinstance(value, (Node, ContainerBase))
    }
示例#25
0
    def __init__(self, dbname, dbmode='a',
                 dbcomplevel=0, dbcomplib='zlib', **kwds):
        """Create an HDF5 database instance, where samples are stored in tables.

        :Parameters:
        dbname : string
          Name of the hdf5 file.
        dbmode : {'a', 'w', 'r'}
          File mode: 'a': append, 'w': overwrite, 'r': read-only.
        dbcomplevel : integer (0-9)
          Compression level, 0: no compression.
        dbcomplib : string
          Compression library (zlib, bzip2, lzo)
        kwds : dict
          Only the deprecated aliases ``complevel`` and ``complib`` are
          read. When present, they override ``dbcomplevel``/``dbcomplib``
          and emit a DeprecationWarning. Other keys are ignored.

        :Notes:
          * zlib has a good compression ratio, although somewhat slow, and
            reasonably fast decompression.
          * LZO is a fast compression library offering however a low compression
            ratio.
          * bzip2 has an excellent compression ratio but requires more CPU.
          * If the opened file already contains top-level groups whose names
            start with 'chain', Trace objects are rebuilt from the last
            chain's table columns, and TraceObject objects from the VLArrays
            of all chains. The user attributes of the last chain's table,
            and the non-Table/non-Group nodes of the last chain group, are
            copied onto this instance as attributes.
        """
        self.__name__ = 'hdf5'
        self.dbname = dbname
        self.__Trace__ = Trace
        self.mode = dbmode

        # A list of sequences of names of the objects to tally.
        self.trace_names = []
        # A dictionary of the Trace objects, keyed by name.
        self._traces = {}

        # Deprecation of complevel and complib
        # Remove in V2.1
        # stacklevel=2 makes the warning point at the caller's code.
        if 'complevel' in kwds:
            warnings.warn(
                'complevel has been replaced with dbcomplevel.',
                DeprecationWarning, stacklevel=2)
            dbcomplevel = kwds['complevel']
        if 'complib' in kwds:
            warnings.warn(
                'complib has been replaced with dbcomplib.',
                DeprecationWarning, stacklevel=2)
            dbcomplib = kwds['complib']

        # Check for the file *before* opening it, since opening may create it.
        db_exists = os.path.exists(self.dbname)
        self._h5file = tables.openFile(self.dbname, self.mode)

        default_filter = tables.Filters(
            complevel=dbcomplevel,
            complib=dbcomplib)
        # When reading or appending to an existing file, reuse the file's own
        # `filters` attribute if it has one. Otherwise use the requested filter.
        if self.mode == 'r' or (self.mode == 'a' and db_exists):
            self.filter = getattr(self._h5file, 'filters', default_filter)
        else:
            self.filter = default_filter

        # These three should be dicts keyed by chain.
        self._tables = self._gettables()
        self._rows = [None] * len(self._tables)
        # NOTE(review): chain order is whatever listNodes returns. If that is
        # name order, 'chain10' would precede 'chain2' -- confirm with >10 chains.
        self._chains = [
            gr for gr in self._h5file.listNodes("/")
            if gr._v_name.startswith('chain')]
        self.chains = len(self._chains)

        # LOAD LOGIC
        if self.chains > 0:
            # Create traces from the columns stored in the last chain's Table.
            for k in self._tables[-1].colnames:
                self._traces[k] = Trace(name=k, db=self)
                setattr(self, k, self._traces[k])

            # Group the VLArray nodes (object-valued traces) by name. Each
            # name's list is appended in the order the chains in
            # self._chains are visited; no reversal is done here.
            objects = {}
            for chain in self._chains:
                for node in self._h5file.walkNodes(chain, classname='VLArray'):
                    if node._v_name != '_state_':
                        objects.setdefault(node._v_name, []).append(node)

            for k, vlarrays in six.iteritems(objects):
                self._traces[k] = TraceObject(name=k, db=self, vlarrays=vlarrays)
                setattr(self, k, self._traces[k])

            # Restore table attributes.
            # This restores the sampler's state for the last chain.
            table = self._tables[-1]
            for k in table.attrs._v_attrnamesuser:
                setattr(self, k, getattr(table.attrs, k))

            # Restore group attributes (nodes other than Tables and Groups).
            for k in self._chains[-1]._f_listNodes():
                if k.__class__ not in [tables.Table, tables.Group]:
                    setattr(self, k.name, k)

            varnames = self._tables[-1].colnames + list(objects)
            # Give each chain its own list so that mutating one chain's
            # names cannot silently change every other chain's entry.
            self.trace_names = [list(varnames) for _ in range(self.chains)]
示例#26
0
def find_variable_set(stochastic):
    """Return ``stochastic`` together with its ``Variable``-valued parents.

    :Parameters:
    stochastic : object
      An object with a ``parents`` mapping (presumably a pymc Stochastic).

    :Returns:
    list
      ``[stochastic]`` followed by every value of ``stochastic.parents``
      that is a ``Variable`` instance, in the mapping's iteration order.
      Parents that are not ``Variable`` instances are omitted.
    """
    # Named `variables` rather than `set` to avoid shadowing the builtin.
    variables = [stochastic]
    variables.extend(
        parent for parent in stochastic.parents.values()
        if isinstance(parent, Variable))
    return variables