def __init__(self, corpus_file,
             phone_info=None,
             orth_symbols_file=None,
             orth_replace_map_file=None,
             add_random_phone_seqs=0,
             partition_epoch=1,
             log_skipped_seqs=False,
             **kwargs):
  """
  :param str corpus_file: Bliss XML or line-based txt. optionally can be gzipped.
  :param dict|None phone_info: if you want to get phone seqs, dict with lexicon_file etc. see _PhoneSeqGenerator
  :param str|None orth_symbols_file: list of orthography symbols, if you want to get orth symbol seqs
  :param str|None orth_replace_map_file: JSON file with replacement dict for orth symbols
  :param int add_random_phone_seqs: will add random seqs with the same len as the real seq as additional data
  :param bool log_skipped_seqs: log skipped seqs
  """
  super(LmDataset, self).__init__(**kwargs)
  if orth_symbols_file:
    assert not phone_info
    orth_symbols = open(orth_symbols_file).read().splitlines()
    self.orth_symbols_map = {sym: i for (i, sym) in enumerate(orth_symbols)}
    self.orth_symbols = orth_symbols
    self.labels["data"] = orth_symbols
    self.seq_gen = None
  else:
    assert not orth_symbols_file
    assert isinstance(phone_info, dict)
    self.seq_gen = _PhoneSeqGenerator(**phone_info)
    self.orth_symbols = None
    self.labels["data"] = self.seq_gen.get_class_labels()
  if orth_replace_map_file:
    orth_replace_map = load_json(filename=orth_replace_map_file)
    assert isinstance(orth_replace_map, dict)
    self.orth_replace_map = {
      key: parse_orthography_into_symbols(v)
      for (key, v) in orth_replace_map.items()}
  else:
    self.orth_replace_map = {}
  if len(self.labels["data"]) <= 256:
    self.dtype = "int8"
  else:
    self.dtype = "int32"
  self.num_outputs = {"data": [len(self.labels["data"]), 1]}
  self.num_inputs = self.num_outputs["data"][0]
  self.seq_order = None
  self.log_skipped_seqs = log_skipped_seqs
  self.partition_epoch = partition_epoch
  self.add_random_phone_seqs = add_random_phone_seqs
  for i in range(add_random_phone_seqs):
    self.num_outputs["random%i" % i] = self.num_outputs["data"]
  if _is_bliss(corpus_file):
    iter_f = _iter_bliss
  else:
    iter_f = _iter_txt
  self.orths = []
  print("LmDataset, loading file", corpus_file, file=log.v4)
  iter_f(corpus_file, self.orths.append)
  # Only an estimate, because some sequences might get filtered out later.
  self._estimated_num_seqs = len(self.orths) // self.partition_epoch
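
# Illustrative usage sketch (not from the source): constructing this LmDataset
# variant for orth-symbol sequences. The file names below are hypothetical.
lm_data = LmDataset(
  corpus_file="train.corpus.txt.gz",     # line-based txt; gzip is handled
  orth_symbols_file="orth_symbols.txt",  # one symbol per line; line index = label index
  partition_epoch=1,
  log_skipped_seqs=True)
assert lm_data.num_inputs == len(lm_data.labels["data"])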
def __init__(self, seq_list_file, seq_lens_file,
             datasets,
             data_map, data_dims,
             data_dtypes=None,
             window=1, **kwargs):
  """
  :param str seq_list_file: filename. line-separated
  :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len
  :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs, including keyword 'class' and maybe 'files'
  :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
    Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
  :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
  :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
  """
  assert window == 1  # not implemented
  super(MetaDataset, self).__init__(**kwargs)
  assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets
  self.seq_list_original = open(seq_list_file).read().splitlines()
  self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original)}
  self._num_seqs = len(self.seq_list_original)
  self.data_map = data_map
  self.dataset_keys = set([m[0] for m in self.data_map.values()])
  ":type: set[str]"
  self.data_keys = set(self.data_map.keys())
  ":type: set[str]"
  assert "data" in self.data_keys
  self.target_list = sorted(self.data_keys - {"data"})  # set difference needs a set, not a list
  data_dims = convert_data_dims(data_dims)
  self.data_dims = data_dims
  assert "data" in data_dims
  for key in self.target_list:
    assert key in data_dims
  self.num_inputs = data_dims["data"][0]
  self.num_outputs = data_dims
  self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}
  if seq_lens_file:
    seq_lens = load_json(filename=seq_lens_file)
    assert isinstance(seq_lens, dict)  # dict[str,NumbersDict], seq-tag -> data-key -> len
    self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
  else:
    self._seq_lens = None
  if self._seq_lens:
    self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original])
  else:
    self._num_timesteps = None
  # Will only init the needed datasets.
  self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}
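
# Illustrative sketch of the MetaDataset arguments (not from the source): two
# hypothetical HDF sub-datasets combined so that "data" comes from one and the
# target "classes" from the other. All names, files and dims are assumptions.
meta = MetaDataset(
  seq_list_file="seq-list.txt",   # one seq tag per line, shared by both sub-datasets
  seq_lens_file="seq-lens.json",  # {seq-tag: {data-key: len}}, precomputed lens
  datasets={
    "features": {"class": "HDFDataset", "files": ["features.hdf"]},
    "alignments": {"class": "HDFDataset", "files": ["alignments.hdf"]}},
  data_map={
    "data": ("features", "data"),        # must contain "data"
    "classes": ("alignments", "data")},  # everything except "data" becomes a target
  data_dims={"data": (40, 2), "classes": (5000, 1)})  # (dim, ndim); ndim 1 => sparse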
def load_file(self, f):
  """
  Reads the configuration parameters from a file
  and adds them to the inner set of parameters.

  :param str|io.TextIOBase|io.StringIO f:
  """
  if isinstance(f, str):
    import os
    assert os.path.isfile(f), "config file not found: %r" % f
    self.files.append(f)
    filename = f
    content = open(filename).read()
  else:
    # assume stream-like
    filename = "<config string>"
    content = f.read()
  content = content.strip()
  if content.startswith("#!"):  # assume Python
    from Util import custom_exec
    # Operate inplace on ourselves.
    # Also, we want it available as the globals() dict, so that defined functions behave well
    # (they would lose the local context otherwise).
    user_ns = self.typed_dict
    # Always overwrite:
    user_ns.update({"config": self, "__file__": filename, "__name__": "__crnn_config__"})
    custom_exec(content, filename, user_ns, user_ns)
    return
  if content.startswith("{"):  # assume JSON
    from Util import load_json
    json_content = load_json(content=content)
    assert isinstance(json_content, dict)
    self.update(json_content)
    return
  # old line-based format
  for line in content.splitlines():
    if "#" in line:  # Strip away comment.
      line = line[:line.index("#")]
    line = line.strip()
    if not line:
      continue
    line = line.split(None, 1)
    assert len(line) == 2, "unable to parse config line: %r" % line
    self.add_line(key=line[0], value=line[1])
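
# Sketch of the three accepted config formats (illustrative, not from the
# source). Assumes the surrounding class is named Config and can be
# constructed without arguments; the config keys used here are hypothetical.
import io
config = Config()
# Python: dispatched on a leading "#!", executed into config.typed_dict.
config.load_file(io.StringIO("#!rnn.py\nnum_epochs = 100\n"))
# JSON: dispatched on a leading "{".
config.load_file(io.StringIO('{"learning_rate": 0.001}'))
# Old line-based format: "key value" pairs; "#" starts a comment.
config.load_file(io.StringIO("batch_size 5000  # comment is stripped\n"))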
def __init__(self, datasets, data_map, data_dims, seq_list_file, seq_lens_file=None,
             data_dtypes=None, window=1, **kwargs):
  """
  :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs, including keyword 'class' and maybe 'files'
  :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
    Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
  :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
  :param str seq_list_file: filename. pickle. dict[str,list[str]], dataset-key -> list of sequence tags.
    If the tag is the same for all datasets, a line-separated plain text file can be used.
  :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len.
    Use if getting the sequence length from loading the data is too costly.
  :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
  """
  assert window == 1  # not implemented
  super(MetaDataset, self).__init__(**kwargs)
  assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

  self.data_map = data_map
  self.dataset_keys = set([m[0] for m in self.data_map.values()])
  ":type: set[str]"
  self.data_keys = set(self.data_map.keys())
  ":type: set[str]"
  assert "data" in self.data_keys
  self.target_list = sorted(self.data_keys - {"data"})
  self.default_dataset_key = self.data_map["data"][0]

  if seq_list_file.endswith(".pkl"):
    import pickle
    seq_list = pickle.load(open(seq_list_file, 'rb'))
  else:
    seq_list = open(seq_list_file).read().splitlines()
  assert isinstance(seq_list, (list, dict))
  if isinstance(seq_list, list):
    seq_list = {key: seq_list for key in self.dataset_keys}
  self.seq_list_original = seq_list  # type: dict[str,list[str]]  # dataset-key -> seq list
  self._num_seqs = len(self.seq_list_original[self.default_dataset_key])
  for key in self.dataset_keys:
    assert len(self.seq_list_original[key]) == self._num_seqs
  self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}

  data_dims = convert_data_dims(data_dims)
  self.data_dims = data_dims
  assert "data" in data_dims
  for key in self.target_list:
    assert key in data_dims
  self.num_inputs = data_dims["data"][0]
  self.num_outputs = data_dims
  self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

  if seq_lens_file:
    seq_lens = load_json(filename=seq_lens_file)
    assert isinstance(seq_lens, dict)  # dict[str,NumbersDict], seq-tag -> data-key -> len
    self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
  else:
    self._seq_lens = None

  if self._seq_lens:
    self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original[self.default_dataset_key]])
  else:
    self._num_timesteps = None

  # Will only init the needed datasets.
  self.datasets = {
    key: init_dataset(datasets[key], extra_kwargs={"name": "%s_%s" % (self.name, key)})
    for key in self.dataset_keys}
  for data_key in self.data_keys:
    dataset_key, dataset_data_key = self.data_map[data_key]
    dataset = self.datasets[dataset_key]
    if dataset_data_key in dataset.labels:
      self.labels[data_key] = dataset.labels[dataset_data_key]
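
# Illustrative sketch (not from the source): writing the pickled seq-list file
# that this newer MetaDataset accepts when sequence tags differ per
# sub-dataset. Dataset keys and tags below are hypothetical.
import pickle
seq_list = {
  "features": ["seq-0001", "seq-0002"],    # tags as named in the "features" dataset
  "alignments": ["utt-0001", "utt-0002"]}  # same order, dataset-local tag names
with open("seq-list.pkl", "wb") as f:
  pickle.dump(seq_list, f)
# A plain text file with one tag per line still works when tags are shared:
# the list is expanded to {dataset-key: seq_list} for all dataset keys.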
def __init__(self,
             corpus_file,
             orth_symbols_file=None,
             orth_symbols_map_file=None,
             orth_replace_map_file=None,
             word_based=False,
             seq_end_symbol="[END]",
             unknown_symbol="[UNKNOWN]",
             parse_orth_opts=None,
             phone_info=None,
             add_random_phone_seqs=0,
             partition_epoch=1,
             auto_replace_unknown_symbol=False,
             log_auto_replace_unknown_symbols=10,
             log_skipped_seqs=10,
             error_on_invalid_seq=True,
             add_delayed_seq_data=False,
             delayed_seq_data_start_symbol="[START]",
             **kwargs):
  """
  :param str|()->str corpus_file: Bliss XML or line-based txt. optionally can be gzipped.
  :param dict|None phone_info: if you want to get phone seqs, dict with lexicon_file etc. see PhoneSeqGenerator
  :param str|()->str|None orth_symbols_file: list of orthography symbols, if you want to get orth symbol seqs
  :param str|()->str|None orth_symbols_map_file: list of orth symbols, each line: "symbol index"
  :param str|()->str|None orth_replace_map_file: JSON file with replacement dict for orth symbols
  :param bool word_based: whether to parse single words, or otherwise will be char-based
  :param str|None seq_end_symbol: what to add at the end, if given.
    will be set as postfix=[seq_end_symbol] or postfix=[] for parse_orth_opts.
  :param dict[str]|None parse_orth_opts: kwargs for parse_orthography()
  :param int add_random_phone_seqs: will add random seqs with the same len as the real seq as additional data
  :param bool|int log_auto_replace_unknown_symbols: write about auto-replacements with unknown symbol.
    if this is an int, it will only log the first N replacements, and then keep quiet.
  :param bool|int log_skipped_seqs: write about skipped seqs to logging, due to missing lexicon entry or so.
    if this is an int, it will only log the first N entries, and then keep quiet.
  :param bool error_on_invalid_seq: if there is a seq we would have to skip, error out
  :param bool add_delayed_seq_data: will add another data-key "delayed" which will have the sequence
    delayed_seq_data_start_symbol + original_sequence[:-1]
  :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data
  :param int partition_epoch: partition the data into this many parts per epoch. like epoch_split
  """
  super(LmDataset, self).__init__(**kwargs)

  if callable(corpus_file):
    corpus_file = corpus_file()
  if callable(orth_symbols_file):
    orth_symbols_file = orth_symbols_file()
  if callable(orth_symbols_map_file):
    orth_symbols_map_file = orth_symbols_map_file()
  if callable(orth_replace_map_file):
    orth_replace_map_file = orth_replace_map_file()

  print("LmDataset, loading file", corpus_file, file=log.v4)

  self.word_based = word_based
  self.seq_end_symbol = seq_end_symbol
  self.unknown_symbol = unknown_symbol
  self.parse_orth_opts = parse_orth_opts or {}
  self.parse_orth_opts.setdefault("word_based", self.word_based)
  self.parse_orth_opts.setdefault("postfix", [self.seq_end_symbol] if self.seq_end_symbol is not None else [])

  if orth_symbols_file:
    assert not phone_info
    assert not orth_symbols_map_file
    orth_symbols = open(orth_symbols_file).read().splitlines()
    self.orth_symbols_map = {sym: i for (i, sym) in enumerate(orth_symbols)}
    self.orth_symbols = orth_symbols
    self.labels["data"] = orth_symbols
    self.seq_gen = None
  elif orth_symbols_map_file:
    assert not phone_info
    orth_symbols_imap_list = [
      (int(b), a)
      for (a, b) in [l.split(None, 1) for l in open(orth_symbols_map_file).read().splitlines()]]
    orth_symbols_imap_list.sort()
    assert orth_symbols_imap_list[0][0] == 0
    assert orth_symbols_imap_list[-1][0] == len(orth_symbols_imap_list) - 1
    self.orth_symbols_map = {sym: i for (i, sym) in orth_symbols_imap_list}
    self.orth_symbols = [sym for (i, sym) in orth_symbols_imap_list]
    self.labels["data"] = self.orth_symbols
    self.seq_gen = None
  else:
    assert not orth_symbols_file
    assert isinstance(phone_info, dict)
    self.seq_gen = PhoneSeqGenerator(**phone_info)
    self.orth_symbols = None
    self.labels["data"] = self.seq_gen.get_class_labels()

  if orth_replace_map_file:
    orth_replace_map = load_json(filename=orth_replace_map_file)
    assert isinstance(orth_replace_map, dict)
    self.orth_replace_map = {
      key: parse_orthography_into_symbols(v, word_based=self.word_based)
      for (key, v) in orth_replace_map.items()}
    if self.orth_replace_map:
      if len(self.orth_replace_map) <= 5:
        print("  orth_replace_map: %r" % self.orth_replace_map, file=log.v5)
      else:
        print("  orth_replace_map: %i entries" % len(self.orth_replace_map), file=log.v5)
  else:
    self.orth_replace_map = {}

  num_labels = len(self.labels["data"])
  use_uint_types = False
  if BackendEngine.is_tensorflow_selected():
    use_uint_types = True
  if num_labels <= 2 ** 7:
    self.dtype = "int8"
  elif num_labels <= 2 ** 8 and use_uint_types:
    self.dtype = "uint8"
  elif num_labels <= 2 ** 31:
    self.dtype = "int32"
  elif num_labels <= 2 ** 32 and use_uint_types:
    self.dtype = "uint32"
  elif num_labels <= 2 ** 61:
    self.dtype = "int64"
  elif num_labels <= 2 ** 62 and use_uint_types:
    self.dtype = "uint64"
  else:
    raise Exception("cannot handle so many labels: %i" % num_labels)
  self.num_outputs = {"data": [len(self.labels["data"]), 1]}
  self.num_inputs = self.num_outputs["data"][0]
  self.seq_order = None
  self.auto_replace_unknown_symbol = auto_replace_unknown_symbol
  self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols
  self.log_skipped_seqs = log_skipped_seqs
  self.error_on_invalid_seq = error_on_invalid_seq
  self.partition_epoch = partition_epoch
  self.add_random_phone_seqs = add_random_phone_seqs
  for i in range(add_random_phone_seqs):
    self.num_outputs["random%i" % i] = self.num_outputs["data"]
  self.add_delayed_seq_data = add_delayed_seq_data
  self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol
  if add_delayed_seq_data:
    self.num_outputs["delayed"] = self.num_outputs["data"]

  if _is_bliss(corpus_file):
    iter_f = _iter_bliss
  else:
    iter_f = _iter_txt
  self.orths = []
  iter_f(corpus_file, self.orths.append)
  # Only an estimate, because some sequences might get filtered out later.
  self._estimated_num_seqs = len(self.orths) // self.partition_epoch
  print("  done, loaded %i sequences" % len(self.orths), file=log.v4)
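
# Illustrative sketch (not from the source): the orth_symbols_map_file format
# is one "symbol index" pair per line, with indices forming a dense 0..N-1
# range. File names and symbols below are hypothetical.
with open("vocab.txt", "w") as f:
  f.write("[START] 0\n[END] 1\n[UNKNOWN] 2\nthe 3\n")
lm_data = LmDataset(
  corpus_file="train.corpus.gz",      # Bliss XML or line-based txt
  orth_symbols_map_file="vocab.txt",
  word_based=True,
  auto_replace_unknown_symbol=True,   # map out-of-vocabulary words to unknown_symbol
  add_delayed_seq_data=True)          # adds the shifted "delayed" data-key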
def __init__(self, program_path, patterns_path):
  self.program_json = load_json(program_path)
  self.patterns_json = load_json(patterns_path)
  self.flow_graph = Node('root')
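
# Illustrative sketch (not from the source): both constructor arguments are
# paths to JSON files, parsed via load_json, and the flow graph starts from a
# 'root' node. The class name FlowGraphBuilder and the file names are
# hypothetical.
builder = FlowGraphBuilder("program.json", "patterns.json")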