Example #1
    def _removeStaleKeyVal(self, index):

        self.repack = True

        keys = self.fileKeys.read()
        vals = self.fileVals.read()

        self.file.removeNode("/keys")
        self.file.removeNode("/vals")

        self.file.flush()

        self.file.createVLArray("/",
                                "keys",
                                tables.VLStringAtom(),
                                filters=self.filter)
        self.file.createVLArray("/",
                                "vals",
                                tables.ObjectAtom(),
                                filters=self.filter)
        self.fileKeys = self.file.root.keys
        self.fileVals = self.file.root.vals
        for (curKey, curVal) in zip(keys[0:index], vals[0:index]):
            self.fileKeys.append(curKey)
            self.fileVals.append(curVal)
        for (curKey, curVal) in zip(keys[index + 1:], vals[index + 1:]):
            self.fileKeys.append(curKey)
            self.fileVals.append(curVal)
Example #2
    def formatFile(self):

        if '/keys' not in self.file:
            self.file.createVLArray("/",
                                    "keys",
                                    tables.VLStringAtom(),
                                    filters=self.filter)

        if '/vals' not in self.file:
            self.file.createVLArray("/",
                                    "vals",
                                    tables.ObjectAtom(),
                                    filters=self.filter)

        try:
            self.fileKeys = self.file.getNode("/", "keys", classname="VLArray")
            self.fileVals = self.file.getNode("/", "vals", classname="VLArray")
        except tables.NoSuchNodeError:
            raise AttributeError("Incorrect File Structure: %s" % \
                                     self.filename)

        if not isinstance(self.fileKeys.atom, tables.VLStringAtom) or \
                not isinstance(self.fileVals.atom, tables.ObjectAtom):
            raise AttributeError("Incorrect File Structure: %s" % \
                                     self.filename)
        self.constraints()
Example #3
    def setUp(self):

        h5file = tables.openFile("test.h5", "w")

        h5file.createVLArray("/",
                             "keys",
                             tables.VLStringAtom(),
                             filters=tables.Filters(complevel=1))
        keys = h5file.root.keys

        h5file.createVLArray("/",
                             "vals",
                             tables.ObjectAtom(),
                             filters=tables.Filters(complevel=1))
        vals = h5file.root.vals

        keys.append('1')
        vals.append(MyObj('1'))

        keys.append('2')
        vals.append(MyObj('2'))

        keys.append('list')
        vals.append([1, 2, 3])

        h5file.close()

        self.shelf = PytableShelf('test.h5')
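
A matching tearDown is not shown; a minimal sketch under the assumption that PytableShelf exposes a close() method and that the temporary file should be removed between tests (both are assumptions, not the project's documented API):

    def tearDown(self):
        self.shelf.close()        # assumed PytableShelf API
        if os.path.exists('test.h5'):
            os.remove('test.h5')  # requires `import os` in the test module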
Example #4
def _save_pickled(handler, group, level, name=None):
    warnings.warn(('(deepdish.io.save) Pickling {}: This may cause '
                   'incompatibilities (for instance between Python 2 and '
                   '3) and should ideally be avoided').format(level),
                  DeprecationWarning)
    node = handler.create_vlarray(group, name, tables.ObjectAtom())
    node.append(level)
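
For context, an ObjectAtom VLArray pickles whatever is appended and unpickles it again on read, so the stored object comes straight back from row 0. A minimal, self-contained round-trip sketch (file and node names are illustrative, not part of deepdish):

import tables

with tables.open_file('pickled_demo.h5', 'w') as h5:
    node = h5.create_vlarray(h5.root, 'payload', tables.ObjectAtom())
    node.append({'a': 1, 'b': [2, 3]})   # any picklable Python object
    restored = node[0]                   # unpickled automatically on read
    assert restored == {'a': 1, 'b': [2, 3]}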
Example #5
 def __setitem__(self, name, value):
     if not isinstance(name, str):
         raise TypeError('H5Vault requires str keys')
     name = name.replace('.', '_')
     if name not in self._vaultnode._v_children:
         vault_bin = self._h5f.create_vlarray(self._vaultnode, name,
                                              tb.ObjectAtom())
     else:
         vault_bin = self._vaultnode._v_children[name]
     vault_bin.append(cloudpickle.dumps(value))
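
The read path is not shown; a hypothetical __getitem__ counterpart (the body below is an assumption about this vault class, not its actual code) would fetch the most recently appended blob and unpickle it with cloudpickle:

 def __getitem__(self, name):
     if not isinstance(name, str):
         raise TypeError('H5Vault requires str keys')
     name = name.replace('.', '_')
     vault_bin = self._vaultnode._v_children[name]   # KeyError if the key is absent
     blob = vault_bin[vault_bin.nrows - 1]           # most recently appended bytes
     return cloudpickle.loads(blob)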
Example #6
def create_data_file(out_file, n_samples):
    hdf5_file = tables.open_file(out_file, mode='w')
    filters = tables.Filters(complevel=5, complib='blosc')
    data_storage = hdf5_file.create_vlarray(hdf5_file.root,
                                            'data',
                                            tables.ObjectAtom(),
                                            filters=filters,
                                            expectedrows=n_samples)
    truth_storage = hdf5_file.create_vlarray(hdf5_file.root,
                                             'truth',
                                             tables.ObjectAtom(),
                                             filters=filters,
                                             expectedrows=n_samples)
    mask_storage = hdf5_file.create_vlarray(hdf5_file.root,
                                            'mask',
                                            tables.ObjectAtom(),
                                            filters=filters,
                                            expectedrows=n_samples)
    return hdf5_file, data_storage, truth_storage, mask_storage
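
A brief usage sketch, assuming each sample is appended as one numpy array per VLArray row (the shapes and file name are illustrative only):

import numpy as np

hdf5_file, data_storage, truth_storage, mask_storage = create_data_file('demo.h5', n_samples=2)
for _ in range(2):
    data_storage.append(np.zeros((1, 64, 64, 64), dtype=np.float32))
    truth_storage.append(np.zeros((64, 64, 64), dtype=np.uint8))
    mask_storage.append(np.ones((64, 64, 64), dtype=np.uint8))
hdf5_file.close()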
Example #7
 def savestate(self, state, chain=-1):
     """Store a dictionnary containing the state of the Model and its
     StepMethods."""
     cur_chain = self._chains[chain]
     if hasattr(cur_chain, '_state_'):
         cur_chain._state_[0] = state
     else:
         s = self._h5file.createVLArray(cur_chain,
                                        '_state_',
                                        tables.ObjectAtom(),
                                        title='The saved state of the sampler',
                                        filters=self.filter)
         s.append(state)
     self._h5file.flush()
Example #8
def sort_db_file(fp_filename):
    """ Sort fp file.

    Sorts an existing fp file. 
    
    :param fp_filename: FPs filename.
    :return: None.
    """
    # rename not sorted filename
    tmp_filename = fp_filename + '_tmp'
    os.rename(fp_filename, tmp_filename)
    filters = tb.Filters(complib='blosc', complevel=5)

    # copy sorted fps and config to a new file
    with tb.open_file(tmp_filename, mode='r') as fp_file:
        with tb.open_file(fp_filename, mode='w') as sorted_fp_file:
            fp_func = fp_file.root.config[0]
            fp_func_params = fp_file.root.config[1]
            fp_length = get_fp_length(fp_func, fp_func_params)

            # create a sorted copy of the fps table
            dst_fps = fp_file.root.fps.copy(sorted_fp_file.root,
                                            'fps',
                                            filters=filters,
                                            copyuserattrs=True,
                                            overwrite=True,
                                            stats={
                                                'groups': 0,
                                                'leaves': 0,
                                                'links': 0,
                                                'bytes': 0,
                                                'hardlinks': 0
                                            },
                                            start=None,
                                            stop=None,
                                            step=None,
                                            chunkshape='keep',
                                            sortby='popcnt',
                                            check_CSI=True,
                                            propindexes=True)

            # set config table; used fp function, parameters and rdkit version
            param_table = sorted_fp_file.create_vlarray(sorted_fp_file.root,
                                                        'config',
                                                        atom=tb.ObjectAtom())
            param_table.append(fp_func)
            param_table.append(fp_func_params)
            param_table.append(rdkit.__version__)

            # update count ranges
            count_ranges = calc_count_ranges(dst_fps, fp_length)
            param_table.append(count_ranges)

    # remove not sorted file
    os.remove(tmp_filename)
Example #9
File: data.py  Project: shaikr/UNET_3D
def create_data_file(out_file, add_pred, n_samples):
    hdf5_file = tables.open_file(out_file, mode='w')
    filters = tables.Filters(complevel=5, complib='blosc')
    data_storage = hdf5_file.create_vlarray(hdf5_file.root,
                                            'data',
                                            tables.ObjectAtom(),
                                            filters=filters,
                                            expectedrows=n_samples)
    truth_storage = hdf5_file.create_vlarray(hdf5_file.root,
                                             'truth',
                                             tables.ObjectAtom(),
                                             filters=filters,
                                             expectedrows=n_samples)
    if add_pred is not None:
        pred_storage = hdf5_file.create_vlarray(hdf5_file.root,
                                                'pred',
                                                tables.ObjectAtom(),
                                                filters=filters,
                                                expectedrows=n_samples)
        return hdf5_file, data_storage, truth_storage, pred_storage
    else:
        return hdf5_file, data_storage, truth_storage, None
Example #10
    def testPartialFile(self):

        h5file = tables.openFile("incorrectFile.h5", mode="w")
        h5file.createVLArray("/",
                             "vals",
                             tables.ObjectAtom(),
                             filters=tables.Filters(complevel=1))
        arr = h5file.root.vals
        arr.append([5, 6, "7"])
        arr.append("1,2,3")
        h5file.close()

        self.expectException("incorrectFile.h5")
Example #11
File: pytables.py  Project: adalke/FPSim2
def sort_db_file(filename: str) -> None:
    """Sorts the FPs db file."""
    # rename not sorted filename
    tmp_filename = filename + "_tmp"
    os.rename(filename, tmp_filename)
    filters = tb.Filters(complib="blosc", complevel=5)

    # copy sorted fps and config to a new file
    with tb.open_file(tmp_filename, mode="r") as fp_file:
        with tb.open_file(filename, mode="w") as sorted_fp_file:
            fp_type = fp_file.root.config[0]
            fp_params = fp_file.root.config[1]
            fp_length = get_fp_length(fp_type, fp_params)

            # create a sorted copy of the fps table
            dst_fps = fp_file.root.fps.copy(
                sorted_fp_file.root,
                "fps",
                filters=filters,
                copyuserattrs=True,
                overwrite=True,
                stats={
                    "groups": 0,
                    "leaves": 0,
                    "links": 0,
                    "bytes": 0,
                    "hardlinks": 0,
                },
                start=None,
                stop=None,
                step=None,
                chunkshape="keep",
                sortby="popcnt",
                check_CSI=True,
                propindexes=True,
            )

            # set config table; used fp function, parameters and rdkit version
            param_table = sorted_fp_file.create_vlarray(
                sorted_fp_file.root, "config", atom=tb.ObjectAtom()
            )
            param_table.append(fp_type)
            param_table.append(fp_params)
            param_table.append(rdkit.__version__)

            # update count ranges
            popcnt_bins = calc_popcnt_bins_pytables(dst_fps, fp_length)
            param_table.append(popcnt_bins)

    # remove not sorted file
    os.remove(tmp_filename)
Example #12
    def testWrongTypes1(self):

        h5file = tables.openFile("wrongTypes1.h5", mode="w")

        h5file.createTable("/", "keys", {"key": tables.StringCol(itemsize=16)})
        row = h5file.root.keys.row
        row['key'] = 'key1'
        row.append()
        h5file.root.keys.flush()

        h5file.createVLArray("/",
                             "vals",
                             tables.ObjectAtom(),
                             filters=tables.Filters(complevel=1))
        vals = h5file.root.vals
        vals.append([1, 2, 3])

        h5file.close()

        self.expectException("wrongTypes1.h5")
Example #13
    def testWrongDims(self):

        h5file = tables.openFile("wrongDims.h5", mode="w")

        h5file.createVLArray("/",
                             "keys",
                             tables.VLStringAtom(),
                             filters=tables.Filters(complevel=1))
        keys = h5file.root.keys
        keys.append("key1")
        keys.append("key2")

        h5file.createVLArray("/",
                             "vals",
                             tables.ObjectAtom(),
                             filters=tables.Filters(complevel=1))
        vals = h5file.root.vals
        vals.append([1, 2, 3])

        h5file.close()

        self.expectException("wrongDims.h5")
Example #14
def _save_level(handler, group, level, name=None):
    if isinstance(level, dict):
        # First create a new group
        new_group = handler.createGroup(group, name, "dict:{}".format(len(level)))
        for k, v in level.items():
            _save_level(handler, new_group, v, name=k) 
    elif isinstance(level, list):
        # Lists can contain other dictionaries and numpy arrays, so we don't want to
        # serialize them. Instead, we will store each entry as i0, i1, etc.
        new_group = handler.createGroup(group, name, "list:{}".format(len(level)))

        for i, entry in enumerate(level):
            level_name = 'i{}'.format(i)
            _save_level(handler, new_group, entry, name=level_name)

    elif isinstance(level, np.ndarray):
        atom = tables.Atom.from_dtype(level.dtype)
        node = handler.createCArray(group, name, atom=atom, shape=level.shape, chunkshape=level.shape, filters=COMPRESSION) 
        node[:] = level
    elif isinstance(level, ATTR_TYPES):
        setattr(group._v_attrs, name, level)
    else:
        node = handler.createVLArray(group, name, tables.ObjectAtom())
        node.append(level)
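
The fallback branch above pickles anything unrecognised into a one-row ObjectAtom VLArray. A hypothetical loader counterpart (not the project's actual code) only has to read that row back:

def _load_pickled_node(node):
    # ObjectAtom rows unpickle on access, so the stored object is simply row 0.
    return node[0]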
Example #15
def import_ch10(irig_filename, hdf5_filename, status_callback=None):
    # Make IRIG 106 library classes
    pkt_io = Packet.IO()
    time_utils = Time.Time(pkt_io)
    decode1553 = MsgDecode1553.Decode1553F1(pkt_io)

    # Initialize variables
    packet_count = 0
    packet_count_1553 = 0
    msg_count_1553 = 0

    # Open the IRIG file
    ret_status = pkt_io.open(irig_filename, Packet.FileMode.READ)
    if ret_status != Status.OK:
        print "Error opening data file %s" % (irig_filename)
        sys.exit(1)

    # If using status callback then get file size
    if status_callback is not None:
        file_size = os.stat(irig_filename).st_size

#        ret_status = time_utils.SyncTime(False, 10)
#        if ret_status != Py106.Status.OK:
#            print ("Sync Status = %s" % Py106.Status.Message(ret_status))
#            sys.exit(1)

    # Set the default 1553 message table layout version
    layout_version = 2

    # Open the PyTable tables
    ch10_h5_file = tables.openFile(hdf5_filename,
                                   mode="w",
                                   title="Ch 10 Data File")

    # Create the 1553 message table
    if layout_version == 1:
        ch10_bus_data = ch10_h5_file.createVLArray("/",
                                                   "Bus_Data",
                                                   tables.ObjectAtom(),
                                                   title="1553 Bus Data")
    elif layout_version == 2:
        ch10_bus_data = ch10_h5_file.createVLArray("/",
                                                   "Bus_Data",
                                                   tables.UInt16Atom(),
                                                   title="1553 Bus Data")

    ch10_bus_data.attrs.layout_version = layout_version

    # Create the 1553 message index table
    ch10_bus_data_index = ch10_h5_file.createTable("/", "Bus_Data_Index",
                                                   IndexMsg1553,
                                                   "1553 Bus Data Index")

    # Iterate over all the IRIG packets
    for PktHdr in pkt_io.packet_headers():
        packet_count += 1

        # Update the callback function if it exists
        if status_callback is not None:
            (status, offset) = pkt_io.get_pos()
            progress = float(offset) / float(file_size)
            status_callback(progress)

        if PktHdr.DataType == Packet.DataType.IRIG_TIME:
            pkt_io.read_data()
            time_utils.SetRelTime()

        if PktHdr.DataType == Packet.DataType.MIL1553_FMT_1:

            packet_count_1553 += 1
            pkt_io.read_data()
            for Msg in decode1553.msgs():
                msg_count_1553 += 1

                # Extract the important 1553 info
                WC = decode1553.word_cnt(Msg.pCmdWord1.contents.Value)

                # Put the 1553 message data into our storage class
                msg_1553 = Msg1553()
                msg_1553.msg_time = time_utils.RelInt2IrigTime(
                    Msg.p1553Hdr.contents.Field.PktTime)
                msg_1553.chan_id = numpy.uint16(pkt_io.Header.ChID)
                msg_1553.header_flags = Msg.p1553Hdr.contents.Field.Flags
                msg_1553.cmd_word_1 = numpy.uint16(
                    Msg.pCmdWord1.contents.Value)
                msg_1553.stat_word_1 = numpy.uint16(
                    Msg.pStatWord1.contents.Value)
                if (Msg.p1553Hdr.contents.Field.Flags.RT2RT == 0):
                    msg_1553.cmd_word_2 = numpy.uint16(0)
                    msg_1553.stat_word_2 = numpy.uint16(0)
                else:
                    msg_1553.cmd_word_2 = numpy.uint16(
                        Msg.pCmdWord2.contents.Value)
                    msg_1553.stat_word_2 = numpy.uint16(
                        Msg.pStatWord2.contents.Value)
                msg_1553.data = numpy.array(Msg.pData.contents[0:WC])
                msg_1553.layout_version = ch10_bus_data.attrs.layout_version
                DataMsg = msg_1553.encode_tuple()

                ch10_bus_data.append(DataMsg)

                # Store the 1553 command word index
                row_offset = ch10_bus_data.nrows - 1
                time_tuple_utc = msg_1553.msg_time.time.timetuple()
                timestamp_utc = calendar.timegm(time_tuple_utc)
                timestamp_utc += msg_1553.msg_time.time.microsecond / 1000000.0

                new_row = ch10_bus_data_index.row
                new_row['offset'] = row_offset
                new_row['time'] = timestamp_utc
                new_row['channel_id'] = pkt_io.Header.ChID
                new_row['rt'] = Msg.pCmdWord1.contents.Field.RTAddr
                new_row['tr'] = Msg.pCmdWord1.contents.Field.TR
                new_row['subaddr'] = Msg.pCmdWord1.contents.Field.SubAddr
                new_row.append()

            # Done with 1553 messages in packet

    ch10_h5_file.flush()
    pkt_io.close()

    return ch10_h5_file
Example #16
def create_db_file(io_source,
                   out_fname,
                   fp_func,
                   fp_func_params={},
                   mol_id_prop='mol_id',
                   gen_ids=False,
                   sort_by_popcnt=True):
    """ Create FPSim2 fingerprints file from .smi, .sdf files, python lists or SQLA queries.
    
    :param io_source: .smi or .sdf filename.
    :param out_fname: FPs output filename.
    :param fp_func: Name of fingerprint function to use to generate the fingerprints.
    :param fp_func_params: Parameters for the fingerprint function.
    :param mol_id_prop: Name of the .sdf property to read the molecule id.
    :param gen_ids: Flag to auto-generate ids for the molecules.
    :param sort_by_popcnt: Whether to sort the output file by popcount.
    :return: None.
    """
    # if params dict is empty use defaults
    if not fp_func_params:
        fp_func_params = FP_FUNC_DEFAULTS[fp_func]

    # get mol supplier
    supplier = get_mol_suplier(io_source)

    fp_length = get_fp_length(fp_func, fp_func_params)

    # set compression
    filters = tb.Filters(complib='blosc', complevel=5)

    # set the output file and fps table
    with tb.open_file(out_fname, mode='w') as fp_file:

        class Particle(tb.IsDescription):
            pass

        # hacky...
        columns = {}
        pos = 1
        columns['fp_id'] = tb.Int64Col(pos=pos)
        for i in range(1, int(fp_length / 64) + 1):
            pos += 1
            columns['f' + str(i)] = tb.UInt64Col(pos=pos)
        columns['popcnt'] = tb.Int64Col(pos=pos + 1)
        Particle.columns = columns

        fps_table = fp_file.create_table(fp_file.root,
                                         'fps',
                                         Particle,
                                         'Table storing fps',
                                         filters=filters)

        # set config table; used fp function, parameters and rdkit version
        param_table = fp_file.create_vlarray(fp_file.root,
                                             'config',
                                             atom=tb.ObjectAtom())
        param_table.append(fp_func)
        param_table.append(fp_func_params)
        param_table.append(rdkit.__version__)

        fps = []
        for mol_id, rdmol in supplier(io_source,
                                      gen_ids,
                                      mol_id_prop=mol_id_prop):
            efp = rdmol_to_efp(rdmol, fp_func, fp_func_params)
            popcnt = py_popcount(np.array(efp, dtype=np.uint64))
            efp.insert(0, mol_id)
            efp.append(popcnt)
            fps.append(tuple(efp))
            if len(fps) == BATCH_WRITE_SIZE:
                fps_table.append(fps)
                fps = []
        # append last batch < 10k
        fps_table.append(fps)

        # create index so table can be sorted
        fps_table.cols.popcnt.create_index(kind='full')

    if sort_by_popcnt:
        sort_db_file(out_fname)
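
An illustrative call, assuming a SMILES input file and Morgan fingerprints with explicit parameters (the file names, fingerprint name, and parameter keys are examples only):

create_db_file('molecules.smi',
               'molecules_fps.h5',
               'Morgan',
               fp_func_params={'radius': 2, 'nBits': 2048})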
Example #17
def create_many_realizations(burn,
                             n,
                             trace,
                             meta,
                             grid_lims,
                             start_year,
                             nmonths,
                             outfile_name,
                             memmax,
                             relp=1e-3,
                             mask_name=None,
                             n_in_trace=None,
                             thinning=10,
                             paramfileINDEX=0,
                             NinThinnedBlock=0,
                             merged_urb=False,
                             TESTRANGE=False,
                             TESTSQUARE=False):
    """
    Creates N realizations from the predictive distribution over the specified space-time mesh.
    """

    # Establish grids
    xllc_here = (xllc + cellsize * (grid_lims['leftCol'] - 1)) * deg_to_rad
    xurc_here = (xllc + cellsize * (grid_lims['rightCol'] - 1)) * deg_to_rad
    yllc_here = (yllc + cellsize *
                 (nrows - grid_lims['bottomRow'])) * deg_to_rad
    yurc_here = (yllc + cellsize * (nrows - grid_lims['topRow'])) * deg_to_rad
    grids = [(xllc_here, xurc_here,
              grid_lims['rightCol'] - grid_lims['leftCol'] + 1),
             (yllc_here, yurc_here,
              grid_lims['bottomRow'] - grid_lims['topRow'] + 1),
             (start_year - 2009., start_year - 2009 + (nmonths - 1) / 12.,
              nmonths)]
    axes = [np.linspace(*grids[i]) for i in xrange(3)]
    grid_shape = (grids[0][2], grids[1][2], grids[2][2])

    if mask_name is not None:
        mask = get_covariate_submesh(mask_name, grid_lims)
    else:
        mask = np.ones(grid_shape[:2])

    if not mask.shape == grid_shape[:2]:
        raise ValueError, 'You screwed up the shapes.'

    # Check that all data are in bounds
    data_locs = meta.logp_mesh[:]
    in_mesh = np.ones(data_locs.shape[0], dtype=bool)
    for i in xrange(len(data_locs)):
        l = data_locs[i]
        for j in xrange(3):
            if l[j] <= grids[j][0] or l[j] >= grids[j][1]:
                in_mesh[i] = False

    print '****', np.sum(in_mesh)
    # from pylab import plot,clf,show
    #
    # print np.asarray(grids)
    #
    # plot(data_locs[:,0], data_locs[:,1], 'k.')
    # dl = data_locs[np.where((data_locs[:,-1]<=grids[-1][1])*(data_locs[:,-1]>=grids[-1][0]))]
    # plot(dl[:,0],dl[:,1],'r.')
    #
    # plot(grids[0][:2],grids[1][:2],'b.',markersize=16)
    # plot(grids[0][:2],grids[1][:2],'b.',markersize=16)
    # # show()
    # from IPython.Debugger import Pdb
    # Pdb(color_scheme='Linux').set_trace()

    # Find the mesh indices closest to the data locations
    data_mesh_indices = np.empty(data_locs.shape, dtype=np.int)

    for i in xrange(len(data_locs)):
        for j in xrange(3):
            data_mesh_indices[i,
                              j] = np.argmin(np.abs(data_locs[i, j] - axes[j]))

    if n_in_trace is None:
        n_in_trace = len(trace.group0.C)
    spacing = (n_in_trace - burn) / n
    indices = np.arange(burn, n_in_trace, spacing)
    N = len(indices)

    outfile = tb.openFile(outfile_name, 'w')
    outfile.createArray('/', 'lon_axis', axes[0], title='Longitude in radians')
    outfile.createArray('/', 'lat_axis', axes[1], title='Latitude in radians')
    outfile.createArray('/',
                        't_axis',
                        axes[2],
                        title='Time in years since 2009')
    outfile.createCArray('/',
                         'realizations',
                         tb.Float32Atom(),
                         shape=(N, grid_shape[1], grid_shape[0],
                                grid_shape[2]),
                         filters=tb.Filters(complevel=1),
                         chunkshape=(1, grid_shape[0], grid_shape[1], 1))

    #Store information to help access the parameter samples that generated the realizations.
    outfile.createArray(
        '/',
        'indices',
        indices,
        title=
        'Indices in the trace file that correspond to the realizations here.')
    new_table = outfile.createTable(
        '/',
        'PyMCsamples',
        description=trace.PyMCsamples.description,
        expectedrows=len(indices),
        title=
        'Trace of numeric-valued variables in model, thinned to be relevant to realizations'
    )
    new_table.append(trace.PyMCsamples[slice(burn, n_in_trace, spacing)])
    outfile.createGroup(
        '/',
        'group0',
        title=
        'Trace of object-valued variables in model, thinned to be relevant to realizations'
    )
    for node in trace.group0._f_iterNodes():
        new_node = outfile.createVLArray('/group0', node.name, tb.ObjectAtom())
        [new_node.append(node[index]) for index in indices]
    outfile.root._v_attrs.orig_filename = trace._v_file.filename

    # Total number of pixels in month.
    npix = grid_shape[0] * grid_shape[1] / thinning**2
    # Maximum number of pixels in tile.
    npixmax = memmax / 4. / data_locs.shape[0]
    # Minimum number of tiles needed.
    ntiles = npix / npixmax
    # Blocks.
    n_blocks_x = n_blocks_y = np.ceil(np.sqrt(ntiles))

    print 'I can afford %i by %i' % (n_blocks_x, n_blocks_y)

    # Scatter this part to many processes
    for i in xrange(len(indices)):
        print 'Realization %i of %i' % (i, N)

        # Pull mean information out of trace
        this_M = trace.group0.M[indices[i]]
        mean_ondata = this_M(data_locs)
        covariate_mesh = np.zeros(grid_shape[:2])
        for key in meta.covariate_names[0]:
            try:
                this_coef = trace.PyMCsamples.col(key + '_coef')[indices[i]]
            except KeyError:
                print 'Warning, no column named %s' % key + '_coef'
                continue

            if merged_urb and key == 'urb':
                print 'Merging urb'
                mean_ondata += (meta.urb[:] +
                                meta.periurb[:])[meta.ui[:]] * this_coef
                this_pred_covariate = (get_covariate_submesh(
                    'urb5km-e_y-x+', grid_lims) + get_covariate_submesh(
                        'periurb5km-e_y-x+', grid_lims)) * this_coef
            else:
                mean_ondata += getattr(meta, key)[:][meta.ui[:]] * this_coef
                this_pred_covariate = get_covariate_submesh(
                    key + '5km-e_y-x+', grid_lims) * this_coef
            covariate_mesh += this_pred_covariate

        # Pull covariance information out of trace
        this_C = trace.group0.C[indices[i]]
        this_C = pm.gp.NearlyFullRankCovariance(this_C.eval_fun,
                                                relative_precision=relp,
                                                **this_C.params)

        #data_vals = trace.PyMCsamples[i]['f'][in_mesh]
        #create_realization(outfile.root.realizations, i, this_C, mean_ondata, this_M, covariate_mesh, data_vals, data_locs, grids, axes, data_mesh_indices, n_blocks_x, n_blocks_y, relp, mask, thinning, indices)

        data_vals = trace.PyMCsamples[indices[i]]['f'][:]
        create_realization(outfile.root, i, this_C, trace.group0.C[indices[i]],
                           mean_ondata, this_M, covariate_mesh, data_vals,
                           data_locs, grids, axes, data_mesh_indices,
                           np.where(in_mesh)[0],
                           np.where(True - in_mesh)[0], n_blocks_x, n_blocks_y,
                           relp, mask, thinning, indices, paramfileINDEX,
                           NinThinnedBlock, TESTRANGE, TESTSQUARE)
        outfile.flush()
    outfile.close()
Example #18
def tdms_to_hdf5(tdms_file,
                 h5_file,
                 load_data=True,
                 chan_map='',
                 memmap=True,
                 compression_level=0):
    """
    Converts TDMS data output from the LabView DAQ software used in the Viventi Lab to
    record from multiplexing neural implants. This method will most likely not interpret
    other TDMS files: see npTDMS for general file handling.

    Parameters
    ----------
    tdms_file : path (string)
    h5_file : path (string)
    chan_map : path (string)
        Optional table specifying a channel permutation. The first p rows
        of the outgoing H5 file will be the contents of these channels in
        sequence. The next (N-p) rows will be any channels not specified,
        in the order they are found.
    memmap : bool
    compression_level : int
        Optionally compress the outgoing H5 rows with zlib compression.
        This can reduce the time cost caused by disk access.

    """

    map_dir = tempfile.gettempdir() if memmap else None

    with tables.open_file(h5_file, mode='w') as h5_file:

        tdms_file = nptdms.TdmsFile(tdms_file, memmap_dir=map_dir)
        # assume for now there is only a single group -- see more files later

        # Catch an API change (older version first)
        try:
            t_group = tdms_file.groups()[0]
            group = tdms_file.object(t_group)
            chans = tdms_file.group_channels(t_group)
        except AttributeError:
            group = tdms_file.groups()[0]
            # Headstage channels and BNC channels are presently lumped into the "data" HDF5 array.. it might make sense
            # to separate them
            chans = group.channels()

        n_col = len(chans)
        n_row = len(chans[0])

        # The H5 file will be constructed as follows:
        #  * create a Group for the info section
        #  * create a CArray with zlib(3) compression for the data channels
        #  * create separate Arrays for special values
        #    (SampRate[SamplingRate], numRow[nrRows], numCol[nrColumns],
        #     OSR[OverSampling], numChan[nrColumns+nrBNCs])
        special_conversion = dict(SamplingRate='sampRate',
                                  nrRows='numRow',
                                  nrColumns='numCol',
                                  OverSampling='OSR')
        h5_info = h5_file.create_group(h5_file.root, 'info')
        for (key, val) in group.properties.items():
            if isinstance(val, str):
                # pytables doesn't support strings as arrays
                arr = h5_file.create_vlarray(h5_info,
                                             key,
                                             atom=tables.ObjectAtom())
                arr.append(val)
            elif isinstance(val, np.datetime64):
                h5_file.create_array(h5_info,
                                     key,
                                     obj=val.astype('f8'),
                                     atom=tables.Time64Atom())
            else:
                h5_file.create_array(h5_info, key, obj=val)
                if key in special_conversion:
                    print('caught', key)
                    # Put this array at the top level with new name
                    h5_file.create_array('/', special_conversion[key], obj=val)

        # do extra conversions
        try:
            num_chan = group.properties['nrColumns'] + group.properties[
                'nrBNCs']
            h5_file.create_array(h5_file.root, 'numChan', num_chan)
        except KeyError:
            pass
        try:
            mux_ratio = group.properties['OverSampling'] * group.properties[
                'nrRows']
            Fs = float(group.properties['SamplingRate']) / mux_ratio
            h5_file.create_array(h5_file.root, 'Fs', Fs)
        except KeyError:
            print('Could not determine sampling rate')

        h5_file.flush()

        if not load_data:
            return h5_file

        # now get down to the data
        atom = tables.Float64Atom()
        if compression_level > 0:
            filters = tables.Filters(complevel=compression_level,
                                     complib='zlib')
        else:
            filters = None

        d_array = h5_file.create_earray(h5_file.root,
                                        'data',
                                        atom=atom,
                                        shape=(0, n_row),
                                        filters=filters,
                                        expectedrows=n_col)

        # create a reverse lookup to index channels by number
        col_mapping = dict([(ch.properties['NI_ArrayColumn'], ch)
                            for ch in chans])
        # If a channel permutation is requested, lay down channels
        # in that order. Otherwise go in sequential order.
        if chan_map:
            chan_map = np.loadtxt(chan_map).astype('i')
            if chan_map.ndim > 1:
                print(chan_map.shape)
                # the actual channel permutation is in the 1st column
                # the array matrix coordinates are in the next columns
                chan_ij = chan_map[:, 1:3]
                chan_map = chan_map[:, 0]
            else:
                chan_ij = None
            # do any channels not specified at the end
            if len(chan_map) < n_col:
                left_out = set(range(n_col)).difference(chan_map.tolist())
                left_out = sorted(left_out)
                chan_map = np.r_[chan_map, left_out]
        else:
            chan_map = list(range(n_col))
            chan_ij = None

        for n in chan_map:
            # get TDMS column
            ch = col_mapping[n]
            # make a temp array here.. if all data in memory, then this is
            # slightly wasteful, but if it is mmap'd then this is more flexible
            d = ch.data[:]
            d_array.append(d[None, :])
            print('copied channel', ch.path, d_array.shape)

        if chan_ij is not None:
            h5_file.create_array(h5_file.root, 'channel_ij', obj=chan_ij)

    return h5_file
Example #19
def prepare_afe_primary_file(tdms_file, write_path=None):

    tdms_path, tdms_name = os.path.split(tdms_file)
    tdms_name = os.path.splitext(tdms_name)[0]
    if write_path is None:
        write_path = tdms_path
    if not os.path.exists(write_path):
        mkdir_p(write_path)
    new_primary_file = os.path.join(write_path, tdms_name + '.h5')

    with tables.open_file(new_primary_file, 'w') as hdf:
        tdms = nptdms.TdmsFile(tdms_file)
        # Translate metadata
        t_group = tdms.groups()[0]
        group = tdms.object(t_group)
        rename_arrays = dict(SamplingRate='sampRate',
                             nrRows='numRow',
                             nrColumns='numCol',
                             OverSampling='OSR')
        h5_info = hdf.create_group('/', 'info')
        for (key, val) in group.properties.items():
            if isinstance(val, str):
                # pytables doesn't support strings as arrays
                arr = hdf.create_vlarray(h5_info,
                                         key,
                                         atom=tables.ObjectAtom())
                arr.append(val)
            else:
                hdf.create_array(h5_info, key, obj=val)
                if key in rename_arrays:
                    hdf.create_array('/', rename_arrays[key], obj=val)

        # Parse channels
        chans = tdms.group_channels(t_group)
        elec_chans = [c for c in chans if 'CH_' in c.channel]
        bnc_chans = [c for c in chans if 'BNC_' in c.channel]
        # For now, files only have 1 row and all channels relate to that row
        num_per_electrode_row = len(elec_chans)

        # do extra conversions
        fs = float(group.properties['SamplingRate']) / (
            num_per_electrode_row * group.properties['OverSampling'])
        hdf.create_array(hdf.root, 'Fs', fs)

        # ensure that channels are ordered correctly
        elec_mapping = dict([(ch.properties['NI_ArrayColumn'], ch)
                             for ch in elec_chans])
        # This value is currently always 1 -- but leave this factor in for future flexibility
        num_electrode_rows = len(elec_chans) // num_per_electrode_row
        if num_per_electrode_row * num_electrode_rows < len(elec_chans):
            print('There were excess TDMS channels: {}'.format(
                len(elec_chans)))
        channels_per_row = 32
        sampling_offset = 6
        hdf_array = hdf.create_carray('/',
                                      'data',
                                      atom=tables.Float64Atom(),
                                      shape=(channels_per_row *
                                             num_electrode_rows,
                                             len(elec_chans[0].data)))
        for elec_group in range(num_electrode_rows):
            for c in range(channels_per_row):
                chan_a = elec_mapping[2 * c + sampling_offset]
                chan_b = elec_mapping[2 * c + 1 + sampling_offset]
                hdf_array[elec_group * channels_per_row +
                          c, :] = 0.5 * (chan_a.data + chan_b.data)
            sampling_offset += num_per_electrode_row

        bnc_array = hdf.create_carray('/',
                                      'bnc',
                                      atom=tables.Float64Atom(),
                                      shape=(len(bnc_chans),
                                             len(bnc_chans[0].data)))
        bnc_mapping = dict([(ch.properties['NI_ArrayColumn'] - len(elec_chans),
                             ch) for ch in bnc_chans])
        for n in range(len(bnc_mapping)):
            bnc_array[n, :] = bnc_mapping[n].data
    return
Example #20
def atom_(dt):
    # Map a dtype-like string to a PyTables atom; 'obj' and 'bool' are special-cased.
    if dt == 'obj': return tables.ObjectAtom()
    if dt == 'bool': return tables.BoolAtom()
    return tables.Atom.from_dtype(numpy.dtype(dt))
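
Illustrative calls; everything other than the two special-cased strings is routed through tables.Atom.from_dtype:

atom_('obj')      # -> tables.ObjectAtom()
atom_('bool')     # -> tables.BoolAtom()
atom_('float32')  # -> tables.Float32Atom(), via numpy.dtype('float32')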
Example #21
def make_hdf5(hdf5_save_name,
              label_list,
              root_directory='.',
              nfft=1024,
              nhop=512,
              fs=22050,
              seglen=30):

    if os.path.exists(hdf5_save_name):
        warnings.warn(
            'hdf5 file {} already exists, new file will not be created'.format(
                hdf5_save_name))
        return

    file_dict = collect_audio(root_directory, label_list)

    # hdf5 setup
    hdf5_file = tables.open_file(hdf5_save_name, mode="w")
    data_node = hdf5_file.create_group(hdf5_file.root, "Data", "Data")
    data_atom = tables.Float32Atom(
    ) if theano.config.floatX == 'float32' else tables.Float64Atom()
    data_atom_complex = tables.ComplexAtom(
        8) if theano.config.floatX == 'float32' else tables.ComplexAtom(16)

    # data nodes
    hdf5_file.create_earray(data_node,
                            'X',
                            atom=data_atom_complex,
                            shape=(0, nfft / 2 + 1),
                            title="features")
    hdf5_file.create_earray(data_node,
                            'y',
                            atom=data_atom,
                            shape=(0, len(label_list)),
                            title="targets")

    targets = range(len(label_list))
    window = np.hanning(nfft)

    file_index = {}
    offset = 0
    for target, key in zip(targets, label_list):
        print 'Processing %s' % key

        for f in file_dict[key.lower()]:

            if f.endswith('.wav'):
                read_fun = audiolab.wavread
            elif f.endswith('.au'):
                read_fun = audiolab.auread
            elif f.endswith('.mp3'):
                read_fun = read_mp3

            # read audio
            audio_data, fstmp, _ = read_fun(os.path.join(root_directory, f))

            # make mono
            if len(audio_data.shape) != 1:
                audio_data = np.sum(audio_data, axis=1) / 2.

            # work with only first seglen seconds
            audio_data = audio_data[:fstmp * seglen]

            # resample audio data
            if fstmp != fs:
                audio_data = samplerate.resample(audio_data, fs / float(fstmp),
                                                 'sinc_best')

            # compute dft
            nframes = (len(audio_data) - nfft) // nhop
            fft_data = np.zeros((nframes, nfft))

            for i in xrange(nframes):
                sup = i * nhop + np.arange(nfft)
                fft_data[i, :] = audio_data[sup] * window

            fft_data = np.fft.fft(fft_data)

            # write dft frames to hdf5 file
            data_node.X.append(fft_data[:, :nfft / 2 + 1])  # keeping phase too

            # write target values to hdf5 file
            one_hot = np.zeros((nframes, len(label_list)))
            one_hot[:, target] = 1
            data_node.y.append(one_hot)

            # keep file-level info
            file_index[f] = (offset, nframes, key.lower(), target)
            offset += nframes

            hdf5_file.flush()

    # write file_index and dft parameters to hdf5 file
    param_node = hdf5_file.create_group(hdf5_file.root, "Param", "Param")
    param_atom = tables.ObjectAtom()

    # save dataset metadata
    hdf5_file.create_vlarray(param_node,
                             'file_index',
                             atom=param_atom,
                             title='file_index')
    param_node.file_index.append(file_index)

    hdf5_file.create_vlarray(param_node,
                             'file_dict',
                             atom=param_atom,
                             title='file_dict')
    param_node.file_dict.append(file_dict)

    hdf5_file.create_vlarray(param_node, 'fft', atom=param_atom, title='fft')
    param_node.fft.append({'nfft': nfft, 'nhop': nhop, 'window': window})

    hdf5_file.create_vlarray(param_node,
                             'label_list',
                             atom=param_atom,
                             title='label_list')
    param_node.label_list.append(label_list)

    hdf5_file.create_vlarray(param_node,
                             'targets',
                             atom=param_atom,
                             title='targets')
    param_node.targets.append(targets)

    hdf5_file.close()
    print ''  # newline
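
Reading the stored metadata back is symmetric; a short sketch (the file name stands in for whatever was passed as hdf5_save_name):

h5 = tables.open_file('dataset.h5', mode='r')
fft_params = h5.root.Param.fft[0]       # dict with 'nfft', 'nhop', 'window'
labels = h5.root.Param.label_list[0]    # the original label_list object
n_frames, n_bins = h5.root.Data.X.shape
h5.close()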
Example #22
# (measure current DICE, VOD, etc.)
# Step 5: iterate steps 3-4 until full fix

step_1_network_experiments_paths = ['/datadrive/configs/...', '/datadrive/configs/...', '/datadrive/configs/...']
step_2_network_experiments_paths = ['/datadrive/configs/...']
output_prediction_dir = r""
if not os.path.exists(output_prediction_dir):
    os.mkdir(output_prediction_dir)
#subject_ids = ['40']
subject_index = []
overlap_factor = 0.9

hdf5_file = r""
data_file = tables.open_file(hdf5_file, "a")
filters = tables.Filters(complevel=5, complib='blosc')
pred_storage = data_file.create_vlarray(data_file.root, 'pred', tables.ObjectAtom(), filters=filters,
                                        expectedrows=len(subject_index)) # TODO: needs to be same length as other arrays

step_1_networks = [load_old_model(get_last_model_path(os.path.join(exp_folder, "fetal_net_model")))
                   for exp_folder in step_1_network_experiments_paths]

step_1_configs = []
n_step_1_models = len(step_1_networks)
for i in range(n_step_1_models):
    with open(os.path.join(step_1_network_experiments_paths[i], 'config.json')) as f:
        config = json.load(f)
        step_1_configs.append(config)

step_2_network = [load_old_model(get_last_model_path(os.path.join(exp_folder, "fetal_net_model")))
                   for exp_folder in step_2_network_experiments_paths]
with open(os.path.join(step_2_network_experiments_paths[0], 'config.json')) as f:
Example #23
# Python flavor
vlarray = fileh.create_vlarray(root, 'vlarray3b',
                               tables.StringAtom(itemsize=3),
                               "Ragged array of strings")
vlarray.flavor = "python"
vlarray.append(["123", "456", "3"])
vlarray.append(["456", "3"])

# Binary strings
vlarray = fileh.create_vlarray(root, 'vlarray4', tables.UInt8Atom(),
                               "pickled bytes")
data = pickle.dumps((["123", "456"], "3"))
vlarray.append(np.ndarray(buffer=data, dtype=np.uint8, shape=len(data)))

# The next is a way of doing the same than before
vlarray = fileh.create_vlarray(root, 'vlarray5', tables.ObjectAtom(),
                               "pickled object")
vlarray.append([["123", "456"], "3"])

# Boolean arrays are supported as well
vlarray = fileh.create_vlarray(root, 'vlarray6', tables.BoolAtom(),
                               "Boolean atoms")
# The next lines are equivalent...
vlarray.append([1, 0])
vlarray.append([1, 0, 3, 0])  # This will be converted to a boolean
# This gives a TypeError
# vlarray.append([1,0,1])

# Variable length strings
vlarray = fileh.create_vlarray(root, 'vlarray7', tables.VLStringAtom(),
                               "Variable Length String")
Example #24
File: pytables.py  Project: adalke/FPSim2
def create_db_file(
    mols_source: Union[str, IterableType],
    filename: str,
    fp_type: str,
    fp_params: dict = {},
    mol_id_prop: str = "mol_id",
    gen_ids: bool = False,
    sort_by_popcnt: bool = True,
) -> None:
    """Creates FPSim2 FPs db file from .smi, .sdf files or from an iterable.

    Parameters
    ----------
    mols_source : str
        .smi/.sdf filename or iterable.

    filename : str
        Fingerprint database filename.

    fp_type : str
        Fingerprint type used to create the fingerprints.

    fp_params : dict
        Parameters used to create the fingerprints.

    mol_id_prop : str
        Name of the .sdf property to read the molecule id.

    gen_ids : bool
        Autogenerate FP ids.

    sort_by_popcnt : bool
        Whether the FPs should be sorted or not.

    Returns
    -------
    None
    """
    # if params dict is empty use defaults
    if not fp_params:
        fp_params = FP_FUNC_DEFAULTS[fp_type]
    supplier = get_mol_suplier(mols_source)
    fp_length = get_fp_length(fp_type, fp_params)
    # set compression
    filters = tb.Filters(complib="blosc", complevel=5)

    # set the output file and fps table
    with tb.open_file(filename, mode="w") as fp_file:
        particle = create_schema(fp_length)
        fps_table = fp_file.create_table(
            fp_file.root, "fps", particle, "Table storing fps", filters=filters
        )

        # set config table; used fp function, parameters and rdkit version
        param_table = fp_file.create_vlarray(
            fp_file.root, "config", atom=tb.ObjectAtom()
        )
        param_table.append(fp_type)
        param_table.append(fp_params)
        param_table.append(rdkit.__version__)

        fps = []
        for mol_id, rdmol in supplier(mols_source, gen_ids, mol_id_prop=mol_id_prop):
            efp = rdmol_to_efp(rdmol, fp_type, fp_params)
            popcnt = py_popcount(np.array(efp, dtype=np.uint64))
            efp.insert(0, mol_id)
            efp.append(popcnt)
            fps.append(tuple(efp))
            if len(fps) == BATCH_WRITE_SIZE:
                fps_table.append(fps)
                fps = []
        # append last batch < 10k
        fps_table.append(fps)

        # create index so table can be sorted
        fps_table.cols.popcnt.create_index(kind="full")

    if sort_by_popcnt:
        sort_db_file(filename)
Example #25
 def create(cls, filename, pt_path, pt_name, filters=None):
     fileh = get_pyt_handle(filename)
     array = fileh.createVLArray(pt_path, pt_name, tables.ObjectAtom(), filters=filters)
     return cls(array)
Example #26
#c2 = c
#print c2.datashape

tlen = 0
for i in range(N):
    #print "i:", i, repr(c2[i]), type(c2[i])
    tlen += len(c2[i][()])
print "time taken for reading in Blaze: %.3f" % (time() - t0)
print "tlen", tlen

# Create a VLArray:
t0 = time()
f = tables.openFile('vlarray.h5', mode='w')
vlarray = f.createVLArray(f.root,
                          'vlarray',
                          tables.ObjectAtom(),
                          "array of objects",
                          filters=tables.Filters(5))

for i in xrange(N):
    vlarray.append(u"s" * N * i)
f.close()
print "time taken for writing in HDF5: %.3f" % (time() - t0)

# Read the VLArray:
t0 = time()
f = tables.openFile('vlarray.h5', mode='r')
vlarray = f.root.vlarray

tlen = 0
for obj in vlarray:
Example #27
def _save_pickled(handler, group, level, name=None):
    node = handler.create_vlarray(group, name, tables.ObjectAtom())
    node.append(level)
Example #28
    def _initialize(self, funs_to_tally, length):
        """
        Create a group named ``Chain#`` to store all data for this chain.
        The group contains one pyTables Table, and at least one subgroup
        called ``group#``. This subgroup holds ObjectAtoms, which can hold
        pymc objects whose value is not a numerical array.

        There is too much stuff in here. ObjectAtoms should get initialized
        """
        i = self.chains
        self._chains.append(
            self._h5file.create_group("/", 'chain%d' % i, 'Chain #%d' % i))
        current_object_group = self._h5file.create_group(
            self._chains[-1], 'group0', 'Group storing objects.')
        group_counter = 0
        object_counter = 0

        # Create the Table in the chain# group, and ObjectAtoms in
        # chain#/group#.
        table_descr = {}
        for name, fun in six.iteritems(funs_to_tally):

            arr = asarray(fun())

            if arr.dtype is np.dtype('object'):

                self._traces[name]._vlarrays.append(
                    self._h5file.create_vlarray(current_object_group,
                                                name,
                                                tables.ObjectAtom(),
                                                title=name + ' samples.',
                                                filters=self.filter))

                object_counter += 1
                if object_counter % 4096 == 0:
                    group_counter += 1
                    current_object_group = self._h5file.create_group(
                        self._chains[-1], 'group%d' % group_counter,
                        'Group storing objects.')

            else:
                table_descr[name] = tables.Col.from_dtype(
                    dtype((arr.dtype, arr.shape)))

        table = self._h5file.create_table(self._chains[-1],
                                          'PyMCsamples',
                                          table_descr,
                                          title='PyMC samples',
                                          filters=self.filter,
                                          expectedrows=length)

        self._tables.append(table)
        self._rows.append(self._tables[-1].row)

        # Store data objects
        for object in self.model.observed_stochastics:
            if object.keep_trace is True:
                setattr(table.attrs, object.__name__, object.value)

        # Make sure the variables have a corresponding Trace instance.
        for name, fun in six.iteritems(funs_to_tally):
            if name not in self._traces:
                if np.array(fun()).dtype is np.dtype('object'):
                    self._traces[name] = TraceObject(name,
                                                     getfunc=fun,
                                                     db=self)
                else:
                    self._traces[name] = Trace(name, getfunc=fun, db=self)

            self._traces[name]._initialize(self.chains, length)

        self.trace_names.append(list(funs_to_tally.keys()))
        self.chains += 1
Example #29
def _save_level(handler, group, level, name=None):
    if isinstance(level, dict):
        # First create a new group
        new_group = handler.create_group(group, name,
                                         "dict:{}".format(len(level)))
        for k, v in level.items():
            if isinstance(k, six.string_types):
                _save_level(handler, new_group, v, name=k)
            else:
                # Key is not string, so it gets a bit more complicated.
                # If the key is not a string, we will store it as a tuple instead,
                # inside a new group
                hsh = hash(k)
                if hsh < 0:
                    hname = 'm{}'.format(-hsh)
                else:
                    hname = '{}'.format(hsh)
                new_group2 = handler.create_group(new_group,
                                                  '__pair_{}'.format(hname),
                                                  "keyvalue_pair")
                new_name = '__pair_{}'.format(hname)
                _save_level(handler, new_group2, k, name='key')
                _save_level(handler, new_group2, v, name='value')

                #new_name = '__keyvalue_pair_{}'.format(hash(name))
                #setattr(group._v_attrs, new_name, (name, level))
    elif isinstance(level, list):
        # Lists can contain other dictionaries and numpy arrays, so we don't
        # want to serialize them. Instead, we will store each entry as i0, i1,
        # etc.
        new_group = handler.create_group(group, name,
                                         "list:{}".format(len(level)))

        for i, entry in enumerate(level):
            level_name = 'i{}'.format(i)
            _save_level(handler, new_group, entry, name=level_name)

    elif isinstance(level, tuple):
        # Tuples can contain other dictionaries and numpy arrays, so we don't
        # want to serialize them. Instead, we will store each entry as i0, i1,
        # etc.
        new_group = handler.create_group(group, name,
                                         "tuple:{}".format(len(level)))

        for i, entry in enumerate(level):
            level_name = 'i{}'.format(i)
            _save_level(handler, new_group, entry, name=level_name)

    elif isinstance(level, np.ndarray):
        atom = tables.Atom.from_dtype(level.dtype)
        node = handler.create_carray(group,
                                     name,
                                     atom=atom,
                                     shape=level.shape,
                                     chunkshape=level.shape,
                                     filters=COMPRESSION)
        node[:] = level

    elif isinstance(level, ATTR_TYPES):
        setattr(group._v_attrs, name, level)

    elif level is None:
        # Store a None as an empty group
        new_group = handler.create_group(group, name, "nonetype:")

    else:
        ag.warning(
            '(amitgroup.io.save) Pickling', level, ': '
            'This may cause incompatibilities (for instance between '
            'Python 2 and 3) and should ideally be avoided')
        node = handler.create_vlarray(group, name, tables.ObjectAtom())
        node.append(level)
Example #30
def save_bunch(f,
               path,
               b,
               mode='a',
               overwrite_paths=False,
               compress_arrays=0,
               skip_pickles=False):
    """
    Save a Bunch type to an HDF5 group in a new or existing table.

    Arrays, strings, lists, and various scalar types are saved as
    naturally supported array types. Sub-Bunches are written
    recursively in sub-paths. The remaining Bunch elements are
    pickled, preserving their object classification.

    MappedSource and BufferBase types are not saved, but can be reloaded
    if the corresponding FileLoader is included in the Bunch. This is presently
    limited to one FileLoader per HDF5 Group (or path level).

    Parameters
    ----------
    f: path or open tables file
    path: str
        Path in the HDF5 tree (e.g. /branch/node)
    b: Bunch
        Bunch to store at the path
    mode: str
        File access mode (caution: 'w' overwrites the entire file)
    overwrite_paths: bool
        If True, then an existing path in the HDF5 may be over-written
    compress_arrays: int
        Compression level (>0) for arrays. Arrays uncompressed if 0.
    skip_pickles: bool
        Non-array types are "pickled" as strings in pytables, which may be unpickled by
        Python on loading. For maximum compatibility (e.g. Matlab), skip pickling.

    """

    # * create a new group
    # * save any array-like type natively (esp ndarrays)
    # * save everything else as the pickled ObjectAtom
    # * if there are any sub-bunches, then re-enter method with subgroup

    if not isinstance(f, tables.file.File):
        with closing(tables.open_file(f, mode)) as f:
            return save_bunch(f,
                              path,
                              b,
                              overwrite_paths=overwrite_paths,
                              compress_arrays=compress_arrays,
                              skip_pickles=skip_pickles)
    from ecogdata.devices.load.file2data import FileLoader
    # If we want to overwrite a node, check to see that it exists.
    # If we want an exception when trying to overwrite, that will
    # be caught on f.create_group()
    if overwrite_paths:
        try:
            n = f.get_node(path)
            n._f_remove(recursive=True, force=True)
        except NoSuchNodeError:
            pass
    p, node = os.path.split(path)
    if node:
        f.create_group(p, node, createparents=True)

    sub_bunches = list()
    items = iter(b.items())
    pickle_bunch = Bunch()
    mapped_data = list()
    loader_saved = False

    # 1) create arrays for suitable types
    for key, val in items:
        if isinstance(val, FileLoader):
            loader_saved = True
        if isinstance(val, np.ndarray) and len(val.shape):
            atom = tables.Atom.from_dtype(val.dtype)
            if compress_arrays:
                filters = tables.Filters(complevel=compress_arrays,
                                         complib='zlib')
            else:
                filters = None
            ca = f.create_carray(path,
                                 key,
                                 atom=atom,
                                 shape=val.shape,
                                 filters=filters)
            ca[:] = val
        elif type(val) in _h5_seq_types:
            try:
                f.create_array(path, key, val)
            except (TypeError, ValueError) as e:
                pickle_bunch[key] = val
        elif isinstance(val, _not_pickled):
            mapped_data.append(key)
        elif isinstance(val, Bunch):
            sub_bunches.append((key, val))
        else:
            pickle_bunch[key] = val

    # 2) pickle the remaining items (that are not bunches)
    if len(pickle_bunch):
        if skip_pickles:
            print('Warning: these keys are being skipped on path {}'.format(
                path))
            print(pickle_bunch)
        else:
            p_arr = f.create_vlarray(path,
                                     'b_pickle',
                                     atom=tables.ObjectAtom())
            p_arr.append(pickle_bunch)

    # 3) repeat these steps for any bunch elements that are also bunches
    for n, b in sub_bunches:
        #print 'saving', n, b
        subpath = path + '/' + n if path != '/' else path + n
        save_bunch(f,
                   subpath,
                   b,
                   compress_arrays=compress_arrays,
                   skip_pickles=skip_pickles)

    if mapped_data:
        print('Mapped data was skipped: ' + ', '.join(mapped_data))
        if loader_saved:
            print(
                'A data loader object was saved. Use "attempt_reload=True" with load_bunch to recover data.'
            )
    return
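
A hypothetical read-back of the pickled leftovers written above (the file name and node path are assumptions; the package's real loader presumably does more):

with tables.open_file('data.h5', 'r') as f:
    leftovers = f.get_node('/branch/node', 'b_pickle')[0]   # unpickled Bunch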