def main(argv):
    """
    Main entry: parse the command line, initialize RETURNN, run the dump and finalize.

    Options which the parser does not know are not an error; they are passed on to ``init``
    as ``command_line_options``.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    parser = argparse.ArgumentParser(description='Forward something and dump it.')
    # (names, keyword arguments) for every option, registered in exactly this order.
    option_table = (
        (('returnn_config',), dict()),
        (("--dataset",), dict(
            help="if given the config, specifies the dataset. e.g. 'train'", default="train")),
        (("--reset_partition_epoch",), dict(type=int, default=1)),
        (("--reset_seq_ordering",), dict(default="sorted_reverse")),
        (("--reset_epoch_wise_filter",), dict(default=None)),
        (("--layer",), dict(required=True)),
        (('--epoch',), dict(type=int, default=1, help="for the dataset")),
        (("--load",), dict(help="model to load")),
        (('--stats',), dict(
            action="store_true", help="calculate mean/stddev stats over stats_layer")),
        (('--dump_stats',), dict(help="file-prefix to dump stats to")),
    )
    for names, settings in option_table:
        parser.add_argument(*names, **settings)

    known, leftover = parser.parse_known_args(argv[1:])
    init(config_filename=known.returnn_config, command_line_options=leftover, args=known)
    dump(known)
    rnn.finalize()
def exit():
    """
    Called by Sprint at exit.

    Marks the interface as exited (further calls only print a message and return),
    asks a started train thread to stop and waits for it, calls ``rnn.finalize()``
    and finally logs the elapsed time if the start time is known.
    """
    global isExited
    pid = os.getpid()
    print("SprintInterface[pid %i] exit()" % (pid, ))
    assert isInitialized
    if isExited:
        print("SprintInterface[pid %i] exit called multiple times" % (pid, ))
        return
    isExited = True

    if isTrainThreadStarted:
        engine.stop_train_after_epoch_request = True
        # Both calls are made here in case they were not called yet. (No PythonSegmentOrdering.)
        sprintDataset.finish_sprint_epoch()
        sprintDataset.finalize_sprint()
        trainThread.join()
    rnn.finalize()

    if not startTime:
        print("SprintInterface[pid %i]: finished (unknown start time)" % pid, file=log.v3)
        return
    print("SprintInterface[pid %i]: elapsed total time: %f" % (pid, time.time() - startTime), file=log.v3)
def main(argv):
    """
    Main entry: initialize from the given config, iterate over the epochs, finalize.

    :param list[str] argv: ``argv[1]`` is the config filename, everything after it is
      passed to ``init`` as command line options
    """
    assert len(argv) >= 2, "usage: %s <config>" % argv[0]
    config_filename = argv[1]
    extra_options = argv[2:]
    init(config_filename=config_filename, command_line_options=extra_options)
    iterate_epochs()
    rnn.finalize()
def main():
    """
    Main entry: parse ``sys.argv``, initialize RETURNN and dump ``rnn.train_data``.

    A keyboard interrupt is reported and turned into exit code 1.
    ``rnn.finalize()`` runs in any case once the dump was started.
    """
    parser = argparse.ArgumentParser(description='Dump something from dataset.')
    add = parser.add_argument
    add('returnn_config', help="either filename to config-file, or dict for dataset")
    add("--dataset", help="if given the config, specifies the dataset. e.g. 'dev'")
    add('--epoch', type=int, default=1)
    add('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
    add('--endseq', type=int, default=10, help='end seq idx (inclusive) or -1 (default: 10)')
    add('--get_num_seqs', action="store_true")
    add('--type', default='stdout', help="'numpy', 'stdout', 'plot', 'null' (default 'stdout')")
    add("--stdout_limit", type=float, default=None, help="e.g. inf to disable")
    add("--stdout_as_bytes", action="store_true")
    add("--verbosity", type=int, default=4, help="overwrites log_verbosity (default: 4)")
    add('--dump_prefix', default='/tmp/returnn.dump-dataset.')
    add('--dump_postfix', default='.txt.gz')
    add("--key", default="data", help="data-key, e.g. 'data' or 'classes'. (default: 'data')")
    add('--stats', action="store_true", help="calculate mean/stddev stats")
    add('--dump_stats', help="file-prefix to dump stats to")
    options = parser.parse_args()

    init(
        config_str=options.returnn_config,
        config_dataset=options.dataset,
        verbosity=options.verbosity)
    try:
        dump_dataset(rnn.train_data, options)
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
        sys.exit(1)
    finally:
        rnn.finalize()
def main(argv):
    """
    Main entry: write the raw strings of a dataset to ``--out`` as a Python dict literal,
    one ``seq_tag: string`` entry per line.

    Either ``--config`` or ``--dataset`` must be given.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    parser = argparse.ArgumentParser(
        description='Dump raw strings from dataset. Same format as in search.')
    parser.add_argument(
        '--config', help="filename to config-file. will use dataset 'eval' from it")
    parser.add_argument("--dataset", help="dataset, overwriting config")
    parser.add_argument(
        '--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
    parser.add_argument(
        '--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
    parser.add_argument(
        "--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
    parser.add_argument(
        "--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
    parser.add_argument(
        "--out", required=True, help="out-file. py-format as in task=search")
    options = parser.parse_args(argv[1:])
    assert options.config or options.dataset

    init(config_filename=options.config, log_verbosity=options.verbosity)

    # An explicit --dataset wins over whatever the config provides.
    if options.dataset:
        dataset_opts = options.dataset
    elif config.value("dump_data", "eval") in ["train", "dev", "eval"]:
        # NOTE(review): the check reads "dump_data" but the dataset is looked up via
        # "search_data" -- confirm that this is intended.
        dataset_opts = config.opt_typed_value(config.value("search_data", "eval"))
    else:
        dataset_opts = config.opt_typed_value("wer_data")
    dataset = init_dataset(dataset_opts)
    dataset.init_seq_order(epoch=1)

    try:
        with generic_open(options.out, "w") as out_file:
            raw_strings = get_raw_strings(dataset=dataset, options=options)
            out_file.write("{\n")
            for tag, raw in raw_strings:
                out_file.write("%r: %r,\n" % (tag, raw))
            out_file.write("}\n")
        print("Done. Wrote to %r." % options.out)
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
        sys.exit(1)
    finally:
        rnn.finalize()
def main(argv):
    """
    Main entry: initialize RETURNN from the config and dump for ``rnn.train_data``.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    parser = argparse.ArgumentParser(description='Forward something and dump it.')
    # Options are registered in exactly this order.
    for flag, settings in (
        ('returnn_config', {}),
        ('--epoch', {"type": int, "default": 1}),
        ('--startseq', {"type": int, "default": 0, "help": 'start seq idx (inclusive) (default: 0)'}),
        ('--endseq', {"type": int, "default": 10, "help": 'end seq idx (inclusive) or -1 (default: 10)'}),
    ):
        parser.add_argument(flag, **settings)
    options = parser.parse_args(argv[1:])

    # No further command line options are handed to RETURNN itself.
    init(config_filename=options.returnn_config, command_line_options=[])
    dump(rnn.train_data, options)
    rnn.finalize()
def main(argv):
    """
    Main entry: collect corpus statistics and, if a lexicon is given, list the corpus words
    which are missing in the lexicon.

    The input is classified via ``is_bliss`` / ``is_returnn_config``; anything else is treated
    as a plain txt file. ``rnn.finalize()`` is only called when the input was a RETURNN config.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    arg_parser = argparse.ArgumentParser(description='Collect orth symbols.')
    arg_parser.add_argument('input', help="RETURNN config, Corpus Bliss XML or just txt-data")
    arg_parser.add_argument("--dump_orth", action="store_true")
    arg_parser.add_argument("--lexicon")
    args = arg_parser.parse_args(argv[1:])

    bliss_filename = None
    crnn_config_filename = None
    txt_filename = None
    if is_bliss(args.input):
        bliss_filename = args.input
        print("Read Bliss corpus:", bliss_filename)
    elif is_returnn_config(args.input):
        crnn_config_filename = args.input
        print("Read corpus from RETURNN config:", crnn_config_filename)
    else:  # treat just as txt
        txt_filename = args.input
        print("Read corpus from txt-file:", txt_filename)
    # Called for every input kind; config_filename is None unless a RETURNN config was given.
    init(config_filename=crnn_config_filename)

    if bliss_filename:
        def _iter_corpus(cb):
            return iter_bliss(bliss_filename, callback=cb)
    elif txt_filename:
        def _iter_corpus(cb):
            return iter_txt(txt_filename, callback=cb)
    else:
        def _iter_corpus(cb):
            return iter_dataset(rnn.train_data, callback=cb)
    corpus_stats = CollectCorpusStats(args, _iter_corpus)

    if args.lexicon:
        print("Lexicon:", args.lexicon)
        lexicon = Lexicon(args.lexicon)
        print("Words not in lexicon:")
        num_missing = 0
        for word in sorted(corpus_stats.words):
            if word not in lexicon.lemmas:
                print(word)
                num_missing += 1
        num_words = len(corpus_stats.words)
        # An empty corpus has no words at all; report 0% instead of dividing by zero.
        percentage = 100. * float(num_missing) / num_words if num_words else 0.
        print("Count: %i (%f%%)" % (num_missing, percentage))
    else:
        print("No lexicon provided (--lexicon).")

    if crnn_config_filename:
        rnn.finalize()
def main(argv):
    """
    Main entry: calculate the word error rate (WER) of hypotheses against references.

    The references come from ``--refs``, or otherwise from a dataset (``--dataset`` or the config).
    The result is printed and, if ``--out`` is given, also written to that file.
    Sets the module globals ``wer_compute`` and ``session``.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    global wer_compute, session

    arg_parser = argparse.ArgumentParser(description='Dump something from dataset.')
    arg_parser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
    arg_parser.add_argument("--dataset", help="dataset, overwriting config")
    arg_parser.add_argument("--refs", help="same format as hyps. alternative to providing dataset/config")
    arg_parser.add_argument("--hyps", help="hypotheses, dumped via search in py format")
    arg_parser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
    arg_parser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
    arg_parser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
    arg_parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
    # argparse applies %-formatting to help strings, so a literal "%" must be written as "%%".
    # With a single "%" here, printing the help failed with a ValueError.
    arg_parser.add_argument("--out", help="if provided, will write WER%% (as string) to this file")
    arg_parser.add_argument("--expect_full", action="store_true", help="full dataset should be scored")
    args = arg_parser.parse_args(argv[1:])
    assert args.config or args.dataset or args.refs

    init(config_filename=args.config, log_verbosity=args.verbosity)
    dataset = None
    refs = None
    if args.refs:
        refs = load_hyps_refs(args.refs)
    elif args.dataset:
        dataset = init_dataset(args.dataset)
    elif config.value("wer_data", "eval") in ["train", "dev", "eval"]:
        # NOTE(review): the check reads "wer_data" but the dataset is looked up via
        # "search_data" -- confirm that this is intended.
        dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval")))
    else:
        dataset = init_dataset(config.opt_typed_value("wer_data"))
    hyps = load_hyps_refs(args.hyps)

    wer_compute = WerComputeGraph()
    # device_count GPU=0: the session is created without any GPU device.
    with tf_compat.v1.Session(config=tf_compat.v1.ConfigProto(device_count={"GPU": 0})) as _session:
        session = _session
        session.run(tf_compat.v1.global_variables_initializer())
        try:
            wer = calc_wer_on_dataset(dataset=dataset, refs=refs, options=args, hyps=hyps)
            print("Final WER: %.02f%%" % (wer * 100), file=log.v1)
            if args.out:
                with open(args.out, "w") as output_file:
                    output_file.write("%.02f\n" % (wer * 100))
                print("Wrote WER%% to %r." % args.out)
        except KeyboardInterrupt:
            print("KeyboardInterrupt")
            sys.exit(1)
        finally:
            rnn.finalize()
def main(argv):
    """
    Main entry: dump a dataset (or a range of its sequences) into a new HDF dataset.

    The first positional argument is used as the RETURNN config filename if
    ``_is_crnn_config`` accepts it, otherwise as the dataset init string.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    arg_parser = argparse.ArgumentParser(
        description="Dump dataset or subset of dataset into external HDF dataset")
    positional = (
        ('config_file_or_dataset', "Config file for RETURNN, or directly the dataset init string"),
        ('hdf_filename', "File name of the HDF dataset, which will be created"),
    )
    for dest, help_text in positional:
        arg_parser.add_argument(dest, type=str, help=help_text)
    optional = (
        ('--start_seq', 0, "Start sequence index of the dataset to dump"),
        ('--end_seq', float("inf"), "End sequence index of the dataset to dump"),
        ('--epoch', 1, "Optional start epoch for initialization"),
    )
    for flag, default, help_text in optional:
        arg_parser.add_argument(flag, type=int, default=default, help=help_text)
    options = arg_parser.parse_args(argv[1:])

    source = options.config_file_or_dataset
    if _is_crnn_config(source):
        returnn_config, dataset_config_str = source, None
    else:
        returnn_config, dataset_config_str = None, source

    dataset = init(
        config_filename=returnn_config, cmd_line_opts=[], dataset_config_str=dataset_config_str)
    hdf_dataset = hdf_dataset_init(options.hdf_filename)
    hdf_dump_from_dataset(dataset, hdf_dataset, options)
    hdf_close(hdf_dataset)

    rnn.finalize()
def main():
    """
    Main entry: clean up old models via ``engine.cleanup_old_models``, asking for confirmation.

    Options are read from ``sys.argv``. A keyboard interrupt is reported and leads to
    exit code 1; ``finalize()`` is called before exiting in both cases.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config")
    parser.add_argument("--cwd", help="will change to this dir")
    parser.add_argument("--model", help="model filenames")
    parser.add_argument("--scores", help="learning_rate_control file, e.g. newbob.data")
    parser.add_argument("--dry_run", action="store_true")
    opts = parser.parse_args()

    exit_status = 0
    try:
        if opts.cwd:
            os.chdir(opts.cwd)
        init(
            extra_greeting="Delete old models.",
            config_filename=opts.config or None,
            config_updates={"use_tensorflow": True, "need_data": False, "device": "cpu"})
        from returnn.__main__ import engine, config
        # Command line values overwrite the corresponding config entries, if given.
        for config_key, value in (("model", opts.model), ("learning_rate_file", opts.scores)):
            if value:
                config.set(config_key, value)
        if opts.dry_run:
            config.set("dry_run", True)
        engine.cleanup_old_models(ask_for_confirmation=True)
    except KeyboardInterrupt:
        exit_status = 1
        print("KeyboardInterrupt", file=getattr(log, "v3", sys.stderr))
        # Only show the traceback if verbosity index 5 is enabled in the log.
        if getattr(log, "verbose", [False] * 6)[5]:
            sys.excepthook(*sys.exc_info())

    finalize()
    if exit_status:
        sys.exit(exit_status)
def main():
    """
    Main entry: parse ``sys.argv``, initialize RETURNN and analyze the dataset batches.

    A keyboard interrupt is reported and turned into exit code 1.
    ``rnn.finalize()`` runs in any case once the analysis was started.
    """
    arg_parser = argparse.ArgumentParser(description='Analyze dataset batches.')
    arg_parser.add_argument(
        'returnn_config', help="either filename to config-file, or dict for dataset")
    arg_parser.add_argument(
        "--dataset", help="if given the config, specifies the dataset. e.g. 'dev'")
    arg_parser.add_argument('--epoch', type=int, default=1)
    # The help texts below state the actual defaults (they used to claim 10 and 4).
    arg_parser.add_argument(
        '--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
    arg_parser.add_argument(
        "--verbosity", type=int, default=5, help="overwrites log_verbosity (default: 5)")
    arg_parser.add_argument(
        "--key", default="data", help="data-key, e.g. 'data' or 'classes'. (default: 'data')")
    arg_parser.add_argument("--use_pretrain", action="store_true")
    args = arg_parser.parse_args()

    init(
        config_str=args.returnn_config, config_dataset=args.dataset, epoch=args.epoch,
        use_pretrain=args.use_pretrain, verbosity=args.verbosity)
    try:
        analyze_dataset(args)
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
        sys.exit(1)
    finally:
        rnn.finalize()
def main(argv):
    """
    Main entry: dump the network (of the given epoch, if pretraining is configured) as JSON.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    arg_parser = argparse.ArgumentParser(description='Dump network as JSON.')
    arg_parser.add_argument('returnn_config_file')
    arg_parser.add_argument('--epoch', default=1, type=int)
    arg_parser.add_argument('--out', default="/dev/stdout")
    args = arg_parser.parse_args(argv[1:])

    init(config_filename=args.returnn_config_file, command_line_options=[])

    pretrain = pretrain_from_config(config)
    if pretrain:
        network = pretrain.get_network_for_epoch(args.epoch)
    else:
        # NOTE(review): the result must provide to_json_content() like the pretrain network
        # does -- confirm this for network_json_from_config.
        network = network_json_from_config(config)

    json_data = network.to_json_content()
    # Serialize before opening the file, so that a failure does not leave a truncated output file.
    json_text = json.dumps(json_data, indent=2, sort_keys=True)
    # The context manager closes the file even if writing fails.
    with open(args.out, 'w') as f:
        print(json_text, file=f)

    rnn.finalize()
def main(argv):
    """
    Main entry: run the network over a dataset and dump the outputs of the selected sub layers
    of the rec layer (``--layers``, by default "att_weights") as npy, png or hdf.

    Steps: parse the arguments, resolve the config (a model dir may be given instead of a config file),
    initialize RETURNN, the dataset and the network, then run a ``Runner`` whose callback
    writes the fetched values in the requested ``--output_format``.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir")
    argparser.add_argument("--epoch", type=int)
    argparser.add_argument(
        '--data', default="train",
        help=
        "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'"
    )
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', help="for npy or png")
    argparser.add_argument("--output_file", help="hdf")
    argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
    # NOTE(review): with action="append", argparse appends to the default list, so "att_weights"
    # stays in the list even when --layers is given. The same holds for --seq_list (empty default).
    argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
    argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
    argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs")
    argparser.add_argument("--output_format", default="npy", help="npy, png or hdf")
    argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    argparser.add_argument("--reset_partition_epoch", type=int, default=1)
    argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
    argparser.add_argument("--reset_epoch_wise_filter", default=None)
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    config_fn = args.config_file
    explicit_model_dir = None
    if os.path.isdir(config_fn):
        # Assume we gave a model dir.
        explicit_model_dir = config_fn
        train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
        train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
        assert train_log_dir_configs
        # Of all matching configs, take the last one in sorted order.
        config_fn = train_log_dir_configs[-1]
        print("Using this config via model dir:", config_fn)
    else:
        assert os.path.isfile(config_fn)
    # Base name of the config file without its last extension; used in the dump file names.
    model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=config_fn, args=args)
    if explicit_model_dir:
        # Keep the model base name from the config but look for it in the given model dir.
        config.set(
            "model", "%s/%s" %
            (explicit_model_dir, os.path.basename(config.value('model', ''))))
    print("Model file prefix:", config.value('model', ''))

    if args.do_search:
        raise NotImplementedError
    # NOTE(review): eval() of a command line string -- only safe with trusted input.
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    assert args.output_format in ["npy", "png", "hdf"]
    if args.output_format in ["npy", "png"]:
        assert args.dump_dir
        if not os.path.exists(args.dump_dir):
            os.makedirs(args.dump_dir)
    # matplotlib is only imported when png output is requested.
    plt = ticker = None
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker

    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." %
              (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        # NOTE(review): eval() of a command line string -- only safe with trusted input.
        extra_dataset_kwargs["epoch_wise_filter"] = eval(
            args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    # Without an explicit --reset_epoch_wise_filter, any filter of the dataset is removed.
    if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    # Check that the dataset really took over the requested settings.
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    init_net(args, layers)
    network = rnn.engine.network

    hdf_writer = None
    if args.output_format == "hdf":
        # HDF output needs a target file and supports exactly one layer.
        assert args.output_file
        assert len(layers) == 1
        sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
        from returnn.datasets.hdf import SimpleHDFWriter
        hdf_writer = SimpleHDFWriter(filename=args.output_file,
                                     dim=sub_layer.output.dim,
                                     ndim=sub_layer.output.ndim)

    # The keys here are the keyword argument names with which fetch_callback is called.
    extra_fetches = {
        "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len": network.layers[
            args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len": network.layers[
            args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx": network.get_extern_data("seq_idx"),
        "seq_tag": network.get_extern_data("seq_tag"),
        "target_data": network.get_extern_data(network.extern_data.default_input),
        "target_classes": network.get_extern_data(network.extern_data.default_target),
    }

    # One additional fetch "rec_<layer>" per requested sub layer.
    for layer in layers:
        sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, layer))
        extra_fetches[
            "rec_%s" % layer] = sub_layer.output.get_placeholder_as_batch_major()
    dataset.init_seq_order(
        epoch=1, seq_list=args.seq_list or None)  # use always epoch 1, such that we have same seqs
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        max_total_num_seqs=args.num_seqs,
        used_data_keys=network.used_data_keys)

    stats = {layer: Stats() for layer in layers}

    # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        """
        Collects stats over the fetched sub layer values and writes them in the requested output format.

        :param list[int] seq_idx: len is n_batch
        :param list[str] seq_tag: len is n_batch
        :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
        :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
        :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
        :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
        :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
        :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
        """
        n_batch = len(seq_idx)
        # Stats are collected for every sequence and layer, independent of the output format.
        for i in range(n_batch):
            # noinspection PyShadowingNames
            for layer in layers:
                att_weights = kwargs["rec_%s" % layer][i]
                stats[layer].collect(att_weights.flatten())
        if args.output_format == "npy":
            # One dict per batch, indexed by the position of the sequence in the batch.
            data = {}
            for i in range(n_batch):
                data[i] = {
                    'tag': seq_tag[i],
                    'data': target_data[i],
                    'classes': target_classes[i],
                    'output': output[i],
                    'output_len': output_len[i],
                    'encoder_len': encoder_len[i],
                }
                # noinspection PyShadowingNames
                for layer in [("rec_%s" % layer) for layer in layers]:
                    assert layer in kwargs
                    out = kwargs[layer][i]
                    assert out.ndim >= 2
                    assert out.shape[0] >= output_len[i] and out.shape[
                        1] >= encoder_len[i]
                    # Cut off everything beyond the decoder / encoder length of this sequence.
                    data[i][layer] = out[:output_len[i], :encoder_len[i]]
                # NOTE(review): the nesting of the save was reconstructed from whitespace-collapsed
                # source; here the same file is rewritten after every sequence -- confirm.
                fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                    model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
                np.save(fname, data)
        elif args.output_format == "png":
            # One image per sequence and layer.
            for i in range(n_batch):
                # noinspection PyShadowingNames
                for layer in layers:
                    extra_postfix = ""
                    if args.dropout is not None:
                        extra_postfix += "_dropout%.2f" % args.dropout
                    elif args.train_flag:
                        extra_postfix += "_train"
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], layer,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % layer][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[
                        i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    # NOTE(review): the extent of this branch was reconstructed from
                    # whitespace-collapsed source; the tick labels are assumed to belong to it.
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        # One y tick label per target entry, serialized individually.
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        elif args.output_format == "hdf":
            assert len(layers) == 1
            att_weights = kwargs["rec_%s" % layers[0]]
            # The whole batch is inserted at once, with the decoder and encoder lengths.
            hdf_writer.insert_batch(inputs=att_weights,
                                    seq_len={
                                        0: output_len,
                                        1: encoder_len
                                    },
                                    seq_tag=seq_tag)
        else:
            raise Exception("output format %r" % args.output_format)

    # The train flag is set via --train_flag or a non-zero --dropout (bool(0.0) is False).
    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=bool(args.dropout) or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
    for layer in layers:
        stats[layer].dump(stream_prefix="Layer %r " % layer)
    if not runner.finalized:
        # Exits here without closing the HDF writer and without rnn.finalize().
        print("Some error occured, not finalized.")
        sys.exit(1)

    if hdf_writer:
        hdf_writer.close()
    rnn.finalize()
def main(argv):
    """
    Main entry: collect stats over the orthographies of a corpus.

    The input is classified via ``is_bliss`` / ``is_crnn_config``; anything else is treated
    as a plain txt file. ``rnn.finalize()`` is only called when the input was a RETURNN config.

    :param list[str] argv: full command line; ``argv[0]`` (the program name) is skipped
    """
    parser = argparse.ArgumentParser(description='Collect orth symbols.')
    add = parser.add_argument
    add('input', help="RETURNN config, Corpus Bliss XML or just txt-data")
    add('--frame_time', type=int, default=10,
        help='time (in ms) per frame. not needed for Corpus Bliss XML')
    add('--collect_time', type=int, default=True,
        help="collect time info. can be slow in some cases")
    add('--dump_orth_syms', action='store_true', help="dump all orthographies")
    add('--filter_orth_sym', help="dump orthographies which match this filter")
    add('--filter_orth_syms_seq', help="dump orthographies which match this filter")
    add('--max_seq_frame_len', type=int, default=float('inf'),
        help="collect only orthographies <= this max frame len")
    add('--max_seq_orth_len', type=int, default=float('inf'),
        help="collect only orthographies <= this max orth len")
    add('--add_numbers', type=int, default=True, help="add chars 0-9 to orth symbols")
    add('--add_lower_alphabet', type=int, default=True, help="add chars a-z to orth symbols")
    add('--add_upper_alphabet', type=int, default=True, help="add chars A-Z to orth symbols")
    add('--remove_symbols', default="(){}$", help="remove these chars from orth symbols")
    add('--output', help='where to store the symbols (default: dont store)')
    args = parser.parse_args(argv[1:])

    source = args.input
    bliss_filename = crnn_config_filename = txt_filename = None
    if is_bliss(source):
        bliss_filename = source
    elif is_crnn_config(source):
        crnn_config_filename = source
    else:
        # treat just as txt
        txt_filename = source
    # Called for every input kind; config_filename is None unless a RETURNN config was given.
    init(config_filename=crnn_config_filename)

    if bliss_filename:
        def iter_corpus(cb):
            return iter_bliss(bliss_filename, options=args, callback=cb)
    elif txt_filename:
        def iter_corpus(cb):
            return iter_txt(txt_filename, options=args, callback=cb)
    else:
        def iter_corpus(cb):
            return iter_dataset(rnn.train_data, options=args, callback=cb)
    collect_stats(args, iter_corpus)

    if crnn_config_filename:
        rnn.finalize()