def __init__(self, network, devices, data, batches, cache, compression="none"):
  super(HDFForwardTaskThread, self).__init__('extract', network, devices, data, batches, eval_batch_size=1)
  self.tags = []
  self.cache = cache
  self.network = network
  self.num_seqs = 0
  if network.get_layer('output'):
    target = network.get_layer('output').attrs['target']
  else:
    target = 'classes'
  cache.attrs['numTimesteps'] = 0
  cache.attrs['inputPattSize'] = data.num_inputs
  cache.attrs['numDims'] = 1
  cache.attrs['numLabels'] = data.num_outputs[target]
  if compression == "none":
    compression = None  # h5py expects None for no compression, not the string "none"
  self.compression = compression
  if target in data.labels:
    hdf5_strings(cache, 'labels', data.labels[target])
  try:
    cache.attrs['numSeqs'] = data.num_seqs
  except Exception:
    # num_seqs might not be known in advance (e.g. for streaming datasets);
    # make seqLengths resizable and fix up numSeqs later in finalize().
    cache.attrs['numSeqs'] = 1
    self.seq_lengths = cache.create_dataset(
      "seqLengths", (cache.attrs['numSeqs'],), dtype='i', maxshape=(None,), compression=compression)
  else:
    self.seq_lengths = cache.create_dataset(
      "seqLengths", (cache.attrs['numSeqs'],), dtype='i', compression=compression)
  self.seq_dims = cache.create_dataset(
    "seqDims", (cache.attrs['numSeqs'], 1), dtype='i', compression=compression)
  try:
    self.targets = {
      k: cache.create_dataset("targets/data/" + k, (data.get_num_timesteps(),), dtype='i', compression=compression)
      for k in data.get_target_list()}
  except Exception:
    self.targets = None
  self.times = []
def finalize(self):
  # Store the collected seq tags and (optional) start/end times,
  # and fix up numSeqs to the actual number of processed sequences.
  hdf5_strings(self.cache, 'seqTags', self.tags)
  if self.times:
    times = self.cache.create_dataset("times", (len(self.times), 2), dtype='f')
    times[...] = self.times
  self.cache.attrs['numSeqs'] = self.num_seqs
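# Both methods above call the hdf5_strings helper from Util. The following is only
# a minimal sketch of its contract (store a list of strings as a fixed-size-string
# HDF5 dataset); the real Util.hdf5_strings may differ, e.g. by falling back to
# variable-length strings when fixed-size storage fails.
import h5py


def hdf5_strings(handle, name, data):
  """
  :param h5py.File|h5py.Group handle:
  :param str name: dataset name, e.g. 'labels' or 'seqTags'
  :param list[str] data:
  """
  encoded = [d.encode("utf8") for d in data]
  max_len = max([len(d) for d in encoded] + [1])  # at least 1, "S0" would be invalid
  dset = handle.create_dataset(name, (len(encoded),), dtype="S%i" % max_len)
  if encoded:
    dset[...] = encoded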
def __init__(self, network, devices, data, batches, cache, merge={}):
  super(HDFForwardTaskThread, self).__init__('extract', network, devices, data, batches, eval_batch_size=1)
  self.tags = []
  self.merge = merge
  self.cache = cache
  self.network = network
  self.num_seqs = 0
  target = network.get_layer('output').attrs['target']
  cache.attrs['numTimesteps'] = 0
  cache.attrs['inputPattSize'] = data.num_inputs
  cache.attrs['numDims'] = 1
  cache.attrs['numLabels'] = data.num_outputs[target]
  if target in data.labels:
    hdf5_strings(cache, 'labels', data.labels[target])
  try:
    cache.attrs['numSeqs'] = data.num_seqs
  except Exception:
    # num_seqs not known in advance; make seqLengths resizable.
    cache.attrs['numSeqs'] = 1
    self.seq_lengths = cache.create_dataset(
      "seqLengths", (cache.attrs['numSeqs'],), dtype='i', maxshape=(None,))
  else:
    self.seq_lengths = cache.create_dataset("seqLengths", (cache.attrs['numSeqs'],), dtype='i')
  self.seq_dims = cache.create_dataset("seqDims", (cache.attrs['numSeqs'], 1), dtype='i')
  try:
    self.targets = {
      k: cache.create_dataset("targets/data/" + k, (data.get_num_timesteps(),), dtype='i')
      for k in data.get_target_list()}
  except Exception:
    self.targets = None
  self.times = []
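# Hypothetical quick check of a cache file produced by this task. The filename
# "forward.hdf" is illustrative; the attrs and dataset names follow exactly what
# the constructors above write.
def _inspect_forward_cache(filename="forward.hdf"):
  import h5py
  with h5py.File(filename, "r") as cache:
    print(cache.attrs['numSeqs'], cache.attrs['numTimesteps'], cache.attrs['inputPattSize'])
    print(cache['seqLengths'][:])  # one length per sequence
    print(cache['seqDims'][:])  # shape (numSeqs, 1)
    if 'targets/data' in cache:
      for key in cache['targets/data']:
        print(key, cache['targets/data'][key].shape)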
def __init__(self, filename, dim, labels=None, ndim=None):
  """
  :param str filename:
  :param int|None dim:
  :param int ndim: counted without batch
  :param list[str]|None labels:
  """
  if ndim is None:
    if dim is None:
      ndim = 1
    else:
      ndim = 2
  import h5py
  from Util import hdf5_strings
  self.dim = dim
  self.ndim = ndim
  self.labels = labels
  if labels:
    assert len(labels) == dim
  self._file = h5py.File(filename, "w")

  self._file.attrs['numTimesteps'] = 0  # we will increment this on-the-fly
  self._other_num_time_steps = 0
  self._file.attrs['inputPattSize'] = dim or 1
  self._file.attrs['numDims'] = 1  # ignored?
  self._file.attrs['numLabels'] = dim or 1
  self._file.attrs['numSeqs'] = 0  # we will increment this on-the-fly
  if labels:
    hdf5_strings(self._file, 'labels', labels)
  else:
    self._file.create_dataset('labels', (0,), dtype="S5")  # dtype string length does not matter here
  self._datasets = {}  # type: dict[str, h5py.Dataset]
  self._tags = []  # type: list[str]
  # One row per seq: (input_len, target_len). maxshape=(None, 2) lets us append rows.
  self._seq_lengths = self._file.create_dataset("seqLengths", (0, 2), dtype='i', maxshape=(None, 2))
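# The maxshape=(None, 2) on seqLengths is what later allows appending one row per
# sequence. Standalone demo of that h5py pattern (filename and lengths are made up):
def _demo_resizable_seq_lengths(filename="demo.hdf"):
  import h5py
  with h5py.File(filename, "w") as f:
    seq_lengths = f.create_dataset("seqLengths", (0, 2), dtype='i', maxshape=(None, 2))
    for input_len, target_len in [(100, 17), (80, 12)]:
      seq_lengths.resize(seq_lengths.shape[0] + 1, axis=0)  # grow by one row
      seq_lengths[-1] = (input_len, target_len)
    assert seq_lengths.shape == (2, 2)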
def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr=False):
  """
  :param str filename: must not exist yet; we first write to a temp file
  :param int|None dim:
  :param int ndim: counted without batch
  :param list[str]|None labels:
  :param dict[str,(int,int,str)]|None extra_type: key -> (dim,ndim,dtype)
  :param bool swmr: see http://docs.h5py.org/en/stable/swmr.html
  """
  import h5py
  from Util import hdf5_strings, unicode
  import tempfile
  import os
  if ndim is None:
    if dim is None:
      ndim = 1
    else:
      ndim = 2
  self.dim = dim
  self.ndim = ndim
  self.labels = labels
  if labels:
    assert len(labels) == dim
  self.filename = filename
  # By default, we should not override existing data.
  # If we want that at some later point, we can introduce an option for it.
  assert not os.path.exists(self.filename)
  tmp_fd, self.tmp_filename = tempfile.mkstemp(suffix=".hdf")
  os.close(tmp_fd)
  self._file = h5py.File(self.tmp_filename, "w", libver='latest' if swmr else None)

  self._file.attrs['numTimesteps'] = 0  # we will increment this on-the-fly
  self._file.attrs['inputPattSize'] = dim or 1
  self._file.attrs['numDims'] = 1  # ignored?
  self._file.attrs['numLabels'] = dim or 1
  self._file.attrs['numSeqs'] = 0  # we will increment this on-the-fly
  if labels:
    hdf5_strings(self._file, 'labels', labels)
  else:
    self._file.create_dataset('labels', (0,), dtype="S5")  # dtype string length does not matter

  self._datasets = {}  # type: typing.Dict[str, h5py.Dataset]  # key -> data
  # seq_length idx represents (seq_idx, data_key_idx),
  # where data_key_idx == 0 is for the main input data,
  # and otherwise data_key_idx == 1 + sorted(self._prepared_extra).index(data_key).
  # data_key_idx must allow for 2 entries by default, as HDFDataset assumes 'classes' by default.
  self._seq_lengths = self._file.create_dataset("seqLengths", (0, 2), dtype='i', maxshape=(None, None))
  # Note about strings in HDF: http://docs.h5py.org/en/stable/strings.html
  # Earlier we used S%i, i.e. fixed-sized strings, with the calculated max string length.
  # noinspection PyUnresolvedReferences
  dt = h5py.special_dtype(vlen=unicode)
  self._seq_tags = self._file.create_dataset('seqTags', (0,), dtype=dt, maxshape=(None,))

  self._extra_num_time_steps = {}  # type: typing.Dict[str,int]  # key -> num-steps
  self._prepared_extra = set()
  if extra_type:
    self._prepare_extra(extra_type)

  if swmr:
    assert not self._file.swmr_mode  # this also checks whether the attribute exists (right version)
    self._file.swmr_mode = True
    # See comments in test_SimpleHDFWriter_swmr...
    raise NotImplementedError("SimpleHDFWriter SWMR is not really finished...")
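# Hedged usage sketch. insert_batch() and close() are not part of this excerpt;
# the calls below follow the RETURNN SimpleHDFWriter API but should be treated as
# assumptions, as should the exact seq_len format.
def _demo_simple_hdf_writer():
  import numpy
  writer = SimpleHDFWriter(
    filename="out.hdf", dim=40, ndim=2,
    extra_type={"classes": (1, 1, "int32")})  # key -> (dim, ndim, dtype)
  features = numpy.zeros((1, 100, 40), dtype="float32")  # (batch, time, dim)
  classes = numpy.zeros((1, 100), dtype="int32")
  writer.insert_batch(inputs=features, seq_len=[100], seq_tag=["seq-0"],
                      extra={"classes": classes})
  writer.close()  # presumably finalizes the temp file into the target filename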