def __init__(self, data_key, filename, type=None, allophone_labeling=None):
    """
    Open a Sprint cache archive and decide how its entries are to be decoded.

    :param str data_key: e.g. "data" or "classes"
    :param str filename: to Sprint cache archive
    :param str|None type: "feat" or "align". When empty/None it is derived from ``data_key``:
      "data" -> "feat", "classes" -> "align", any other key -> "align" if ``allophone_labeling``
      is given, otherwise "feat".
    :param dict[str]|None allophone_labeling: kwargs for :class:`AllophoneLabeling`;
      mandatory for the "align" type
    """
    self.data_key = data_key
    from SprintCache import open_file_archive
    self.sprint_cache = open_file_archive(filename)

    # ``type`` shadows the builtin but is part of the keyword interface, so work on a local alias.
    kind = type
    if not kind:
        if data_key == "data":
            kind = "feat"
        elif data_key == "classes":
            kind = "align"
        else:
            # Some sensible defaults.
            kind = "align" if allophone_labeling else "feat"
    assert kind in ("feat", "align")
    self.type = kind

    self.allophone_labeling = None
    if not allophone_labeling:
        assert kind != "align", "need allophone_labeling for 'align' type"
    else:
        from SprintCache import AllophoneLabeling
        self.allophone_labeling = AllophoneLabeling(**allophone_labeling)
        # Hand the allophone file of the labeling over to the opened archive.
        self.sprint_cache.setAllophones(self.allophone_labeling.allophone_file)

    # Entries ending in ".attribs" are excluded; presumably attribute side files, not content (TODO confirm).
    all_entries = self.sprint_cache.file_list()
    self.content_keys = [entry for entry in all_entries if not entry.endswith(".attribs")]

    if kind == "feat":
        self.num_labels = self._get_feature_dim()
        self.num_dims = 2
        self.dtype = "float32"
    elif kind == "align":
        self.num_labels = self.allophone_labeling.num_labels
        # Pick the narrowest signed integer dtype whose bound num_labels stays below.
        for bound, int_dtype in ((2 ** 7, "int8"), (2 ** 15, "int16")):
            if self.num_labels < bound:
                self.dtype = int_dtype
                break
        else:
            assert self.num_labels < 2 ** 31
            self.dtype = "int32"
        self.num_dims = 1
        # A non-empty state-tying table switches this reader to the "align_raw" type.
        if self.allophone_labeling.state_tying_by_allo_state_idx:
            self.type = "align_raw"
    else:
        assert False
def __init__(self, data_key, filename, data_type=None, allophone_labeling=None):
    """
    Open a Sprint cache archive and decide how its entries are to be decoded.

    :param str data_key: e.g. "data" or "classes"
    :param str filename: to Sprint cache archive
    :param str|None data_type: "feat" or "align". When empty/None it is derived from ``data_key``:
      "data" -> "feat", "classes" -> "align", any other key -> "align" if ``allophone_labeling``
      is given, otherwise "feat".
    :param dict[str]|None allophone_labeling: kwargs for :class:`AllophoneLabeling`;
      mandatory for the "align" type
    """
    self.data_key = data_key
    from SprintCache import open_file_archive
    self.sprint_cache = open_file_archive(filename)

    if not data_type:
        if data_key == "data":
            data_type = "feat"
        elif data_key == "classes":
            data_type = "align"
        else:
            # Some sensible defaults.
            data_type = "align" if allophone_labeling else "feat"
    assert data_type in ("feat", "align")
    self.type = data_type

    self.allophone_labeling = None
    if not allophone_labeling:
        assert data_type != "align", "need allophone_labeling for 'align' type"
    else:
        from SprintCache import AllophoneLabeling
        self.allophone_labeling = AllophoneLabeling(**allophone_labeling)
        # Hand the allophone file of the labeling over to the opened archive.
        self.sprint_cache.set_allophones(self.allophone_labeling.allophone_file)

    # Entries ending in ".attribs" are excluded; presumably attribute side files, not content (TODO confirm).
    all_entries = self.sprint_cache.file_list()
    self.content_keys = [entry for entry in all_entries if not entry.endswith(".attribs")]

    if data_type == "feat":
        self.num_labels = self._get_feature_dim()
        self.num_dims = 2
        self.dtype = "float32"
    elif data_type == "align":
        self.num_labels = self.allophone_labeling.num_labels
        # Pick the narrowest signed integer dtype whose bound num_labels stays below.
        for bound, int_dtype in ((2 ** 7, "int8"), (2 ** 15, "int16")):
            if self.num_labels < bound:
                self.dtype = int_dtype
                break
        else:
            assert self.num_labels < 2 ** 31
            self.dtype = "int32"
        self.num_dims = 1
        # A non-empty state-tying table switches this reader to the "align_raw" type.
        if self.allophone_labeling.state_tying_by_allo_state_idx:
            self.type = "align_raw"
    else:
        assert False
class SprintCacheReader(object):
    """
    Reads feature or alignment entries directly from a Sprint cache archive
    and returns them as numpy arrays.
    """

    def __init__(self, data_key, filename, type=None, allophone_labeling=None):
        """
        Open a Sprint cache archive and decide how its entries are to be decoded.

        :param str data_key: e.g. "data" or "classes"
        :param str filename: to Sprint cache archive
        :param str|None type: "feat" or "align". When empty/None it is derived from ``data_key``:
          "data" -> "feat", "classes" -> "align", any other key -> "align" if ``allophone_labeling``
          is given, otherwise "feat".
        :param dict[str]|None allophone_labeling: kwargs for :class:`AllophoneLabeling`;
          mandatory for the "align" type
        """
        self.data_key = data_key
        from SprintCache import open_file_archive
        self.sprint_cache = open_file_archive(filename)

        # ``type`` shadows the builtin but is part of the keyword interface, so work on a local alias.
        kind = type
        if not kind:
            if data_key == "data":
                kind = "feat"
            elif data_key == "classes":
                kind = "align"
            else:
                # Some sensible defaults.
                kind = "align" if allophone_labeling else "feat"
        assert kind in ("feat", "align")
        self.type = kind

        self.allophone_labeling = None
        if not allophone_labeling:
            assert kind != "align", "need allophone_labeling for 'align' type"
        else:
            from SprintCache import AllophoneLabeling
            self.allophone_labeling = AllophoneLabeling(**allophone_labeling)
            # Hand the allophone file of the labeling over to the opened archive.
            self.sprint_cache.setAllophones(self.allophone_labeling.allophone_file)

        # Entries ending in ".attribs" are excluded; presumably attribute side files, not content (TODO confirm).
        all_entries = self.sprint_cache.file_list()
        self.content_keys = [entry for entry in all_entries if not entry.endswith(".attribs")]

        if kind == "feat":
            self.num_labels = self._get_feature_dim()
            self.num_dims = 2
            self.dtype = "float32"
        elif kind == "align":
            self.num_labels = self.allophone_labeling.num_labels
            # Pick the narrowest signed integer dtype whose bound num_labels stays below.
            for bound, int_dtype in ((2 ** 7, "int8"), (2 ** 15, "int16")):
                if self.num_labels < bound:
                    self.dtype = int_dtype
                    break
            else:
                assert self.num_labels < 2 ** 31
                self.dtype = "int32"
            self.num_dims = 1
            # A non-empty state-tying table switches read() to the "align_raw" decoding.
            if self.allophone_labeling.state_tying_by_allo_state_idx:
                self.type = "align_raw"
        else:
            assert False

    def _get_feature_dim(self):
        """
        Determine the feature dimension by reading the first content entry.

        :return: length of the first feature vector of the first content entry
        :rtype: int
        """
        assert self.type == "feat"
        assert self.content_keys
        first_key = self.content_keys[0]
        times, feats = self.sprint_cache.read(first_key, "feat")
        assert len(times) == len(feats) > 0
        first_vec = feats[0]
        assert isinstance(first_vec, numpy.ndarray)
        assert first_vec.ndim == 1
        return first_vec.shape[0]

    def read(self, name):
        """
        Read one entry from the archive and convert it to a numpy array.

        :param str name: content-filename for sprint cache
        :return: numpy array of shape (time, [num_labels]): shape (time,) of label indices
          for "align"/"align_raw", shape (time, num_labels) of features for "feat"
        :rtype: numpy.ndarray
        """
        kind = self.type
        raw = self.sprint_cache.read(name, typ=kind)
        if kind == "feat":
            times, feats = raw
            assert len(times) == len(feats) > 0
            feat_mat = numpy.array(feats, dtype=self.dtype)
            assert feat_mat.shape == (len(times), self.num_labels)
            return feat_mat
        if kind in ("align", "align_raw"):
            labeling = self.allophone_labeling
            # Each alignment item is unpacked as a 3-tuple; its first field is not used here.
            if kind == "align":
                labels = [labeling.get_label_idx(allo, state) for (_, allo, state) in raw]
            else:
                labels = [labeling.state_tying_by_allo_state_idx[allo] for (_, allo, _) in raw]
            label_seq = numpy.array(labels, dtype=self.dtype)
            assert label_seq.shape == (len(raw),)
            return label_seq
        assert False
class SprintCacheReader(object):
    """
    Helper to read a Sprint cache archive directly, returning features or
    alignment label sequences as numpy arrays.
    """

    def __init__(self, data_key, filename, data_type=None, allophone_labeling=None):
        """
        Open a Sprint cache archive and decide how its entries are to be decoded.

        :param str data_key: e.g. "data" or "classes"
        :param str filename: to Sprint cache archive
        :param str|None data_type: "feat" or "align". When empty/None it is derived from ``data_key``:
          "data" -> "feat", "classes" -> "align", any other key -> "align" if ``allophone_labeling``
          is given, otherwise "feat".
        :param dict[str]|None allophone_labeling: kwargs for :class:`AllophoneLabeling`;
          mandatory for the "align" type
        """
        self.data_key = data_key
        from SprintCache import open_file_archive
        self.sprint_cache = open_file_archive(filename)

        if not data_type:
            if data_key == "data":
                data_type = "feat"
            elif data_key == "classes":
                data_type = "align"
            else:
                # Some sensible defaults.
                data_type = "align" if allophone_labeling else "feat"
        assert data_type in ("feat", "align")
        self.type = data_type

        self.allophone_labeling = None
        if not allophone_labeling:
            assert data_type != "align", "need allophone_labeling for 'align' type"
        else:
            from SprintCache import AllophoneLabeling
            self.allophone_labeling = AllophoneLabeling(**allophone_labeling)
            # Hand the allophone file of the labeling over to the opened archive.
            self.sprint_cache.set_allophones(self.allophone_labeling.allophone_file)

        # Entries ending in ".attribs" are excluded; presumably attribute side files, not content (TODO confirm).
        all_entries = self.sprint_cache.file_list()
        self.content_keys = [entry for entry in all_entries if not entry.endswith(".attribs")]

        if data_type == "feat":
            self.num_labels = self._get_feature_dim()
            self.num_dims = 2
            self.dtype = "float32"
        elif data_type == "align":
            self.num_labels = self.allophone_labeling.num_labels
            # Pick the narrowest signed integer dtype whose bound num_labels stays below.
            for bound, int_dtype in ((2 ** 7, "int8"), (2 ** 15, "int16")):
                if self.num_labels < bound:
                    self.dtype = int_dtype
                    break
            else:
                assert self.num_labels < 2 ** 31
                self.dtype = "int32"
            self.num_dims = 1
            # A non-empty state-tying table switches read() to the "align_raw" decoding.
            if self.allophone_labeling.state_tying_by_allo_state_idx:
                self.type = "align_raw"
        else:
            assert False

    def _get_feature_dim(self):
        """
        Determine the feature dimension by reading the first content entry.

        :return: length of the first feature vector of the first content entry
        :rtype: int
        """
        assert self.type == "feat"
        assert self.content_keys
        first_key = self.content_keys[0]
        times, feats = self.sprint_cache.read(first_key, "feat")
        assert len(times) == len(feats) > 0
        first_vec = feats[0]
        assert isinstance(first_vec, numpy.ndarray)
        assert first_vec.ndim == 1
        return first_vec.shape[0]

    def read(self, name):
        """
        Read one entry from the archive and convert it to a numpy array.

        :param str name: content-filename for sprint cache
        :return: numpy array of shape (time, [num_labels]): shape (time,) of label indices
          for "align"/"align_raw", shape (time, num_labels) of features for "feat"
        :rtype: numpy.ndarray
        """
        kind = self.type
        raw = self.sprint_cache.read(name, typ=kind)
        if kind == "feat":
            times, feats = raw
            assert len(times) == len(feats) > 0
            feat_mat = numpy.array(feats, dtype=self.dtype)
            assert feat_mat.shape == (len(times), self.num_labels)
            return feat_mat
        if kind in ("align", "align_raw"):
            labeling = self.allophone_labeling
            # Each alignment item is unpacked as a 3-tuple; its first field is not used here.
            if kind == "align":
                labels = [labeling.get_label_idx(allo, state) for (_, allo, state) in raw]
            else:
                labels = [labeling.state_tying_by_allo_state_idx[allo] for (_, allo, _) in raw]
            label_seq = numpy.array(labels, dtype=self.dtype)
            assert label_seq.shape == (len(raw),)
            return label_seq
        assert False