def main():
  args = arg_parser.parse_args()
  return_code = 0
  try:
    if args.cwd:
      os.chdir(args.cwd)
    init(
      extra_greeting="Delete old models.",
      configFilename=args.config or None,
      config_updates={"use_tensorflow": True, "need_data": False, "device": "cpu"})
    from rnn import engine, config
    if args.model:
      config.set("model", args.model)
    if args.scores:
      config.set("learning_rate_file", args.scores)
    if args.dry_run:
      config.set("dry_run", True)
    engine.cleanup_old_models(ask_for_confirmation=True)
  except KeyboardInterrupt:
    return_code = 1
    print("KeyboardInterrupt", file=getattr(log, "v3", sys.stderr))
    if getattr(log, "verbose", [False] * 6)[5]:
      sys.excepthook(*sys.exc_info())
  finalize()
  if return_code:
    sys.exit(return_code)
def exit(): print("SprintInterface[pid %i] exit()" % (os.getpid(), )) assert isInitialized global isExited if isExited: print("SprintInterface[pid %i] exit called multiple times" % (os.getpid(), )) return isExited = True if isTrainThreadStarted: engine.stop_train_after_epoch_request = True sprintDataset.finishSprintEpoch( ) # In case this was not called yet. (No PythonSegmentOrdering.) sprintDataset.finalizeSprint( ) # In case this was not called yet. (No PythonSegmentOrdering.) trainThread.join() rnn.finalize() if startTime: print("SprintInterface[pid %i]: elapsed total time: %f" % (os.getpid(), time.time() - startTime), file=log.v3) else: print("SprintInterface[pid %i]: finished (unknown start time)" % os.getpid(), file=log.v3)
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump something from dataset.')
  argparser.add_argument('crnn_config', help="either filename to config-file, or dict for dataset")
  argparser.add_argument('--epoch', type=int, default=1)
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=10, help='end seq idx (inclusive) or -1 (default: 10)')
  argparser.add_argument('--get_num_seqs', action="store_true")
  argparser.add_argument('--type', default='stdout', help="'numpy' or 'stdout'")
  argparser.add_argument('--dump_prefix', default='/tmp/crnn.dump-dataset.')
  argparser.add_argument('--dump_postfix', default='.txt.gz')
  args = argparser.parse_args(argv[1:])
  init(config_str=args.crnn_config)
  try:
    dump_dataset(rnn.train_data, args)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump raw strings from dataset. Same format as in search.')
  argparser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
  argparser.add_argument("--dataset", help="dataset, overwriting config")
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  argparser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
  argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  argparser.add_argument("--out", required=True, help="out-file. py-format as in task=search")
  args = argparser.parse_args(argv[1:])
  assert args.config or args.dataset
  init(config_filename=args.config, log_verbosity=args.verbosity)
  if args.dataset:
    print(args.dataset)
    dataset = init_dataset(args.dataset)
  elif config.value("dump_data", "eval") in ["train", "dev", "eval"]:
    dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval")))
  else:
    dataset = init_dataset(config.opt_typed_value("wer_data"))
  dataset.init_seq_order(epoch=1)
  try:
    with generic_open(args.out, "w") as output_file:
      refs = get_raw_strings(dataset=dataset, options=args)
      output_file.write("{\n")
      for seq_tag, ref in refs:
        output_file.write("%r: %r,\n" % (seq_tag, ref))
      output_file.write("}\n")
    print("Done. Wrote to %r." % args.out)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
def main():
  argparser = argparse.ArgumentParser(description='Dump something from dataset.')
  argparser.add_argument('crnn_config', help="either filename to config-file, or dict for dataset")
  argparser.add_argument('--epoch', type=int, default=1)
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=10, help='end seq idx (inclusive) or -1 (default: 10)')
  argparser.add_argument('--get_num_seqs', action="store_true")
  argparser.add_argument('--type', default='stdout', help="'numpy', 'stdout', 'plot', 'null' (default 'stdout')")
  argparser.add_argument("--stdout_limit", type=float, default=None, help="e.g. inf to disable")
  argparser.add_argument("--stdout_as_bytes", action="store_true")
  argparser.add_argument("--verbosity", type=int, default=4, help="overwrites log_verbosity (default: 4)")
  argparser.add_argument('--dump_prefix', default='/tmp/crnn.dump-dataset.')
  argparser.add_argument('--dump_postfix', default='.txt.gz')
  argparser.add_argument("--key", default="data", help="data-key, e.g. 'data' or 'classes'. (default: 'data')")
  argparser.add_argument('--stats', action="store_true", help="calculate mean/stddev stats")
  argparser.add_argument('--dump_stats', help="file-prefix to dump stats to")
  args = argparser.parse_args()
  init(config_str=args.crnn_config, verbosity=args.verbosity)
  try:
    dump_dataset(rnn.train_data, args)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
def main():
  argparser = argparse.ArgumentParser(description='Dump something from dataset.')
  argparser.add_argument('crnn_config', help="either filename to config-file, or dict for dataset")
  argparser.add_argument("--dataset", help="if given the config, specifies the dataset. e.g. 'dev'")
  argparser.add_argument('--epoch', type=int, default=1)
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=10, help='end seq idx (inclusive) or -1 (default: 10)')
  argparser.add_argument('--get_num_seqs', action="store_true")
  argparser.add_argument('--type', default='stdout', help="'numpy', 'stdout', 'plot', 'null' (default 'stdout')")
  argparser.add_argument("--stdout_limit", type=float, default=None, help="e.g. inf to disable")
  argparser.add_argument("--stdout_as_bytes", action="store_true")
  argparser.add_argument("--verbosity", type=int, default=4, help="overwrites log_verbosity (default: 4)")
  argparser.add_argument('--dump_prefix', default='/tmp/crnn.dump-dataset.')
  argparser.add_argument('--dump_postfix', default='.txt.gz')
  argparser.add_argument("--key", default="data", help="data-key, e.g. 'data' or 'classes'. (default: 'data')")
  argparser.add_argument('--stats', action="store_true", help="calculate mean/stddev stats")
  argparser.add_argument('--dump_stats', help="file-prefix to dump stats to")
  args = argparser.parse_args()
  init(config_str=args.crnn_config, config_dataset=args.dataset, verbosity=args.verbosity)
  try:
    dump_dataset(rnn.train_data, args)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
def main():
  args = arg_parser.parse_args()
  return_code = 0
  try:
    if args.cwd:
      os.chdir(args.cwd)
    init(
      extra_greeting="Delete old models.",
      config_filename=args.config or None,
      config_updates={"use_tensorflow": True, "need_data": False, "device": "cpu"})
    from rnn import engine, config
    if args.model:
      config.set("model", args.model)
    if args.scores:
      config.set("learning_rate_file", args.scores)
    if args.dry_run:
      config.set("dry_run", True)
    engine.cleanup_old_models(ask_for_confirmation=True)
  except KeyboardInterrupt:
    return_code = 1
    print("KeyboardInterrupt", file=getattr(log, "v3", sys.stderr))
    if getattr(log, "verbose", [False] * 6)[5]:
      sys.excepthook(*sys.exc_info())
  finalize()
  if return_code:
    sys.exit(return_code)
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump raw strings from dataset. Same format as in search.')
  argparser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
  argparser.add_argument("--dataset", help="dataset, overwriting config")
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  argparser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
  argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  argparser.add_argument("--out", required=True, help="out-file. py-format as in task=search")
  args = argparser.parse_args(argv[1:])
  assert args.config or args.dataset
  init(config_filename=args.config, log_verbosity=args.verbosity)
  if args.dataset:
    dataset = init_dataset(args.dataset)
  elif config.value("dump_data", "eval") in ["train", "dev", "eval"]:
    dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval")))
  else:
    dataset = init_dataset(config.opt_typed_value("wer_data"))
  dataset.init_seq_order(epoch=1)
  try:
    with open(args.out, "w") as output_file:
      refs = get_raw_strings(dataset=dataset, options=args)
      output_file.write("{\n")
      for seq_tag, ref in refs:
        output_file.write("%r: %r,\n" % (seq_tag, ref))
      output_file.write("}\n")
    print("Done. Wrote to %r." % args.out)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
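# For orientation: the --out file written by main() above is a plain Python dict
# literal mapping sequence tags to raw strings, one "%r: %r," entry per sequence.
# Minimal illustrative sketch of such a file's contents (tags and texts are made up,
# not taken from any real dataset):
_example_raw_strings_out_file = {
  'corpus/recording-0001/1': 'the quick brown fox',
  'corpus/recording-0002/1': 'jumps over the lazy dog',
}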
def exit(): print "Python train exit()" assert isInitialized if isTrainThreadStarted: engine.stop_train_after_epoch_request = True sprintDataset.finishSprintEpoch() # In case this was not called yet. (No PythonSegmentOrdering.) sprintDataset.finalizeSprint() # In case this was not called yet. (No PythonSegmentOrdering.) trainThread.join() rnn.finalize() print >> log.v3, ("elapsed total time: %f" % (time.time() - startTime))
def main(argv):
  argparser = argparse.ArgumentParser(description='Forward something and dump it.')
  argparser.add_argument('crnn_config_file')
  argparser.add_argument('--epoch', type=int, default=1)
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=10, help='end seq idx (inclusive) or -1 (default: 10)')
  args = argparser.parse_args(argv[1:])
  init(configFilename=args.crnn_config_file, commandLineOptions=[])
  dump(rnn.train_data, args)
  rnn.finalize()
def exit(): print "Python train exit()" assert isInitialized if isTrainThreadStarted: engine.stop_train_after_epoch_request = True sprintDataset.finishSprintEpoch( ) # In case this was not called yet. (No PythonSegmentOrdering.) sprintDataset.finalizeSprint( ) # In case this was not called yet. (No PythonSegmentOrdering.) trainThread.join() rnn.finalize() print >> log.v3, ("elapsed total time: %f" % (time.time() - startTime))
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump something from dataset.')
  argparser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
  argparser.add_argument("--dataset", help="dataset, overwriting config")
  argparser.add_argument("--refs", help="same format as hyps. alternative to providing dataset/config")
  argparser.add_argument("--hyps", help="hypotheses, dumped via search in py format")
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  argparser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
  argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  argparser.add_argument("--out", help="if provided, will write WER% (as string) to this file")
  argparser.add_argument("--expect_full", action="store_true", help="full dataset should be scored")
  args = argparser.parse_args(argv[1:])
  assert args.config or args.dataset or args.refs
  init(config_filename=args.config, log_verbosity=args.verbosity)
  dataset = None
  refs = None
  if args.refs:
    refs = load_hyps_refs(args.refs)
  elif args.dataset:
    dataset = init_dataset(args.dataset)
  elif config.value("wer_data", "eval") in ["train", "dev", "eval"]:
    dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval")))
  else:
    dataset = init_dataset(config.opt_typed_value("wer_data"))
  hyps = load_hyps_refs(args.hyps)
  global wer_compute
  wer_compute = WerComputeGraph()
  with TFCompat.v1.Session(config=TFCompat.v1.ConfigProto(device_count={"GPU": 0})) as _session:
    global session
    session = _session
    session.run(TFCompat.v1.global_variables_initializer())
    try:
      wer = calc_wer_on_dataset(dataset=dataset, refs=refs, options=args, hyps=hyps)
      print("Final WER: %.02f%%" % (wer * 100), file=log.v1)
      if args.out:
        with open(args.out, "w") as output_file:
          output_file.write("%.02f\n" % (wer * 100))
        print("Wrote WER%% to %r." % args.out)
    except KeyboardInterrupt:
      print("KeyboardInterrupt")
      sys.exit(1)
    finally:
      rnn.finalize()
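# A minimal sketch of reading back such a py-format hyps/refs file, assuming it is an
# uncompressed Python dict literal as written by the raw-string dump tool above
# (the actual load_hyps_refs implementation used by this tool may differ):
def _example_load_py_format(filename):
  """Return dict seq_tag -> raw string, parsed from a py-format dump file."""
  import ast
  with open(filename, "r") as f:
    content = f.read()
  data = ast.literal_eval(content)  # safe parsing of the dict literal
  assert isinstance(data, dict)
  return data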
def main(argv):
  argparser = argparse.ArgumentParser(description='Collect orth symbols.')
  argparser.add_argument('input', help="CRNN config, Corpus Bliss XML or just txt-data")
  argparser.add_argument("--dump_orth", action="store_true")
  argparser.add_argument("--lexicon")
  args = argparser.parse_args(argv[1:])
  bliss_filename = None
  crnn_config_filename = None
  txt_filename = None
  if is_bliss(args.input):
    bliss_filename = args.input
    print("Read Bliss corpus:", bliss_filename)
  elif is_crnn_config(args.input):
    crnn_config_filename = args.input
    print("Read corpus from Returnn config:", crnn_config_filename)
  else:  # treat just as txt
    txt_filename = args.input
    print("Read corpus from txt-file:", txt_filename)
  init(configFilename=crnn_config_filename)
  if bliss_filename:
    iter_corpus = lambda cb: iter_bliss(bliss_filename, options=args, callback=cb)
  elif txt_filename:
    iter_corpus = lambda cb: iter_txt(txt_filename, options=args, callback=cb)
  else:
    iter_corpus = lambda cb: iter_dataset(rnn.train_data, options=args, callback=cb)
  corpus_stats = CollectCorpusStats(args, iter_corpus)
  if args.lexicon:
    print("Lexicon:", args.lexicon)
    lexicon = Lexicon(args.lexicon)
    print("Words not in lexicon:")
    c = 0
    for w in sorted(corpus_stats.words):
      if w not in lexicon.lemmas:
        print(w)
        c += 1
    print("Count: %i (%f%%)" % (c, 100. * float(c) / len(corpus_stats.words)))
  else:
    print("No lexicon provided (--lexicon).")
  if crnn_config_filename:
    rnn.finalize()
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump network as JSON.')
  argparser.add_argument('crnn_config_file')
  argparser.add_argument('--epoch', default=1, type=int)
  argparser.add_argument('--out', default="/dev/stdout")
  args = argparser.parse_args(argv[1:])
  init(configFilename=args.crnn_config_file, commandLineOptions=[])
  pretrain = pretrainFromConfig(config)
  if pretrain:
    network = pretrain.get_network_for_epoch(args.epoch)
  else:
    network = LayerNetwork.from_config_topology(config)
  json_data = network.to_json_content()
  f = open(args.out, 'w')
  print(json.dumps(json_data, indent=2, sort_keys=True), file=f)
  f.close()
  rnn.finalize()
def main():
  argparser = argparse.ArgumentParser(description='Analyze dataset batches.')
  argparser.add_argument('crnn_config', help="either filename to config-file, or dict for dataset")
  argparser.add_argument("--dataset", help="if given the config, specifies the dataset. e.g. 'dev'")
  argparser.add_argument('--epoch', type=int, default=1)
  argparser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  argparser.add_argument("--verbosity", type=int, default=5, help="overwrites log_verbosity (default: 5)")
  argparser.add_argument("--key", default="data", help="data-key, e.g. 'data' or 'classes'. (default: 'data')")
  argparser.add_argument("--use_pretrain", action="store_true")
  args = argparser.parse_args()
  init(
    config_str=args.crnn_config, config_dataset=args.dataset, epoch=args.epoch,
    use_pretrain=args.use_pretrain, verbosity=args.verbosity)
  try:
    analyze_dataset(args)
  except KeyboardInterrupt:
    print("KeyboardInterrupt")
    sys.exit(1)
  finally:
    rnn.finalize()
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump network as JSON.')
  argparser.add_argument('crnn_config_file')
  argparser.add_argument('--epoch', default=1, type=int)
  argparser.add_argument('--out', default="/dev/stdout")
  args = argparser.parse_args(argv[1:])
  init(configFilename=args.crnn_config_file, commandLineOptions=[])
  pretrain = pretrainFromConfig(config)
  if pretrain:
    network = pretrain.get_network_for_epoch(args.epoch)
  else:
    network = LayerNetwork.from_config_topology(config)
  json_data = network.to_json_content()
  f = open(args.out, 'w')
  print >> f, json.dumps(json_data, indent=2, sort_keys=True)
  f.close()
  rnn.finalize()
def exit(): """ Called by Sprint at exit. """ print("SprintInterface[pid %i] exit()" % (os.getpid(),)) assert isInitialized global isExited if isExited: print("SprintInterface[pid %i] exit called multiple times" % (os.getpid(),)) return isExited = True if isTrainThreadStarted: engine.stop_train_after_epoch_request = True sprintDataset.finish_sprint_epoch() # In case this was not called yet. (No PythonSegmentOrdering.) sprintDataset.finalize_sprint() # In case this was not called yet. (No PythonSegmentOrdering.) trainThread.join() rnn.finalize() if startTime: print("SprintInterface[pid %i]: elapsed total time: %f" % (os.getpid(), time.time() - startTime), file=log.v3) else: print("SprintInterface[pid %i]: finished (unknown start time)" % os.getpid(), file=log.v3)
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump something from dataset.')
  argparser.add_argument('crnn_config_file')
  argparser.add_argument('--epoch', type=int, default=1)
  argparser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  argparser.add_argument('--endseq', type=int, default=10, help='end seq idx (inclusive) or -1 (default: 10)')
  argparser.add_argument('--type', default='stdout', help="'numpy' or 'stdout'")
  argparser.add_argument('--dump_prefix', default='/tmp/crnn.dump-dataset.')
  argparser.add_argument('--dump_postfix', default='.txt.gz')
  args = argparser.parse_args(argv[1:])
  init(configFilename=args.crnn_config_file, commandLineOptions=[])
  dump_dataset(rnn.train_data, args)
  rnn.finalize()
def main(argv):
  parser = argparse.ArgumentParser(description="Dump dataset or subset of dataset in external HDF dataset")
  parser.add_argument('config_file_or_dataset', type=str, help="Config file for CRNN, or directly the dataset init string")
  parser.add_argument('hdf_filename', type=str, help="File name of the HDF dataset, which will be created")
  parser.add_argument('--start_seq', type=int, default=0, help="Start sequence index of the dataset to dump")
  parser.add_argument('--end_seq', type=int, default=float("inf"), help="End sequence index of the dataset to dump")
  parser.add_argument('--epoch', type=int, default=1, help="Optional start epoch for initialization")
  args = parser.parse_args(argv[1:])
  crnn_config = None
  dataset_config_str = None
  if _is_crnn_config(args.config_file_or_dataset):
    crnn_config = args.config_file_or_dataset
  else:
    dataset_config_str = args.config_file_or_dataset
  dataset = init(config_filename=crnn_config, cmd_line_opts=[], dataset_config_str=dataset_config_str)
  hdf_dataset = hdf_dataset_init(args.hdf_filename)
  hdf_dump_from_dataset(dataset, hdf_dataset, args)
  hdf_close(hdf_dataset)
  rnn.finalize()
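# The positional config_file_or_dataset argument above accepts either a CRNN/RETURNN
# config file or a dataset init string, i.e. a Python dict literal describing the
# dataset. Illustrative value only (dataset class and options depend on the setup):
_example_dataset_init_str = "{'class': 'HDFDataset', 'files': ['/path/to/features.hdf']}"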
def main(argv):
  argparser = argparse.ArgumentParser(description='Collect orth symbols.')
  argparser.add_argument('input', help="CRNN config, Corpus Bliss XML or just txt-data")
  argparser.add_argument('--frame_time', type=int, default=10, help='time (in ms) per frame. not needed for Corpus Bliss XML')
  argparser.add_argument('--collect_time', type=int, default=True, help="collect time info. can be slow in some cases")
  argparser.add_argument('--dump_orth_syms', action='store_true', help="dump all orthographies")
  argparser.add_argument('--filter_orth_sym', help="dump orthographies which match this filter")
  argparser.add_argument('--filter_orth_syms_seq', help="dump orthographies which match this filter")
  argparser.add_argument('--max_seq_frame_len', type=int, default=float('inf'), help="collect only orthographies <= this max frame len")
  argparser.add_argument('--max_seq_orth_len', type=int, default=float('inf'), help="collect only orthographies <= this max orth len")
  argparser.add_argument('--add_numbers', type=int, default=True, help="add chars 0-9 to orth symbols")
  argparser.add_argument('--add_lower_alphabet', type=int, default=True, help="add chars a-z to orth symbols")
  argparser.add_argument('--add_upper_alphabet', type=int, default=True, help="add chars A-Z to orth symbols")
  argparser.add_argument('--remove_symbols', default="(){}$", help="remove these chars from orth symbols")
  argparser.add_argument('--output', help='where to store the symbols (default: dont store)')
  args = argparser.parse_args(argv[1:])
  bliss_filename = None
  crnn_config_filename = None
  txt_filename = None
  if is_bliss(args.input):
    bliss_filename = args.input
  elif is_crnn_config(args.input):
    crnn_config_filename = args.input
  else:  # treat just as txt
    txt_filename = args.input
  init(configFilename=crnn_config_filename)
  if bliss_filename:
    iter_corpus = lambda cb: iter_bliss(bliss_filename, options=args, callback=cb)
  elif txt_filename:
    iter_corpus = lambda cb: iter_txt(txt_filename, options=args, callback=cb)
  else:
    iter_corpus = lambda cb: iter_dataset(rnn.train_data, options=args, callback=cb)
  collect_stats(args, iter_corpus)
  if crnn_config_filename:
    rnn.finalize()
def main(argv): argparser = argparse.ArgumentParser(description=__doc__) argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir") argparser.add_argument("--epoch", type=int) argparser.add_argument('--data', default="train", help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'") argparser.add_argument('--do_search', default=False, action='store_true') argparser.add_argument('--beam_size', default=12, type=int) argparser.add_argument('--dump_dir', help="for npy or png") argparser.add_argument("--output_file", help="hdf") argparser.add_argument("--device", help="gpu or cpu (default: automatic)") argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab") argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder") argparser.add_argument("--enc_layer", default="encoder") argparser.add_argument("--batch_size", type=int, default=5000) argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs") argparser.add_argument("--min_seq_len", default="0", help="can also be dict") argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs") argparser.add_argument("--output_format", default="npy", help="npy, png or hdf") argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values") argparser.add_argument("--train_flag", action="store_true") argparser.add_argument("--reset_partition_epoch", type=int, default=1) argparser.add_argument("--reset_seq_ordering", default="sorted_reverse") argparser.add_argument("--reset_epoch_wise_filter", default=None) args = argparser.parse_args(argv[1:]) layers = args.layers assert isinstance(layers, list) config_fn = args.config_file explicit_model_dir = None if os.path.isdir(config_fn): # Assume we gave a model dir. explicit_model_dir = config_fn train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn train_log_dir_configs = sorted(glob(train_log_dir_config_pattern)) assert train_log_dir_configs config_fn = train_log_dir_configs[-1] print("Using this config via model dir:", config_fn) else: assert os.path.isfile(config_fn) model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1]) init_returnn(config_fn=config_fn, args=args) if explicit_model_dir: config.set("model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', '')))) print("Model file prefix:", config.value('model', '')) if args.do_search: raise NotImplementedError min_seq_length = NumbersDict(eval(args.min_seq_len)) assert args.output_format in ["npy", "png", "hdf"] if args.output_format in ["npy", "png"]: assert args.dump_dir if not os.path.exists(args.dump_dir): os.makedirs(args.dump_dir) plt = ticker = None if args.output_format == "png": import matplotlib.pyplot as plt # need to import early? https://stackoverflow.com/a/45582103/133374 import matplotlib.ticker as ticker dataset_str = args.data if dataset_str in ["train", "dev", "eval"]: dataset_str = "config:%s" % dataset_str extra_dataset_kwargs = {} if args.reset_partition_epoch: print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch,)) extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch if args.reset_seq_ordering: print("NOTE: We will use %r seq ordering." 
% (args.reset_seq_ordering,)) extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering if args.reset_epoch_wise_filter: extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter) dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs) if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None: if dataset.epoch_wise_filter: print("NOTE: Resetting epoch_wise_filter to None.") dataset.epoch_wise_filter = None if args.reset_partition_epoch: assert dataset.partition_epoch == args.reset_partition_epoch if args.reset_seq_ordering: assert dataset.seq_ordering == args.reset_seq_ordering init_net(args, layers) network = rnn.engine.network hdf_writer = None if args.output_format == "hdf": assert args.output_file assert len(layers) == 1 sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0])) from HDFDataset import SimpleHDFWriter hdf_writer = SimpleHDFWriter(filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim) extra_fetches = { "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(), "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(), # decoder length "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(), # encoder length "seq_idx": network.get_extern_data("seq_idx"), "seq_tag": network.get_extern_data("seq_tag"), "target_data": network.get_extern_data(network.extern_data.default_input), "target_classes": network.get_extern_data(network.extern_data.default_target), } for l in layers: sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, l)) extra_fetches["rec_%s" % l] = sub_layer.output.get_placeholder_as_batch_major() dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None) # use always epoch 1, such that we have same seqs dataset_batch = dataset.generate_batches( recurrent_net=network.recurrent, batch_size=args.batch_size, max_seqs=rnn.engine.max_seqs, max_seq_length=sys.maxsize, min_seq_length=min_seq_length, max_total_num_seqs=args.num_seqs, used_data_keys=network.used_data_keys) stats = {l: Stats() for l in layers} # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs): """ :param list[int] seq_idx: len is n_batch :param list[str] seq_tag: len is n_batch :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...) :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...) :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...) :param numpy.ndarray output_len: rec layer seq len, i.e. 
decoder length, shape (B,) :param numpy.ndarray encoder_len: encoder seq len, shape (B,) :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in """ n_batch = len(seq_idx) for i in range(n_batch): for l in layers: att_weights = kwargs["rec_%s" % l][i] stats[l].collect(att_weights.flatten()) if args.output_format == "npy": data = {} for i in range(n_batch): data[i] = { 'tag': seq_tag[i], 'data': target_data[i], 'classes': target_classes[i], 'output': output[i], 'output_len': output_len[i], 'encoder_len': encoder_len[i], } for l in [("rec_%s" % l) for l in layers]: assert l in kwargs out = kwargs[l][i] assert out.ndim >= 2 assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i] data[i][l] = out[:output_len[i], :encoder_len[i]] fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1]) np.save(fname, data) elif args.output_format == "png": for i in range(n_batch): for l in layers: extra_postfix = "" if args.dropout is not None: extra_postfix += "_dropout%.2f" % args.dropout elif args.train_flag: extra_postfix += "_train" fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % ( model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix) att_weights = kwargs["rec_%s" % l][i] att_weights = att_weights.squeeze(axis=2) # (out,enc) assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i] att_weights = att_weights[:output_len[i], :encoder_len[i]] print("Seq %i, %s: Dump att weights with shape %r to: %s" % ( seq_idx[i], seq_tag[i], att_weights.shape, fname)) plt.matshow(att_weights) title = seq_tag[i] if dataset.can_serialize_data(network.extern_data.default_target): title += "\n" + dataset.serialize_data( network.extern_data.default_target, target_classes[i][:output_len[i]]) ax = plt.gca() tick_labels = [ dataset.serialize_data(network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype)) for x in target_classes[i][:output_len[i]]] ax.set_yticklabels([''] + tick_labels, fontsize=8) ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) plt.title(title) plt.savefig(fname) plt.close() elif args.output_format == "hdf": assert len(layers) == 1 att_weights = kwargs["rec_%s" % layers[0]] hdf_writer.insert_batch(inputs=att_weights, seq_len={0: output_len, 1: encoder_len}, seq_tag=seq_tag) else: raise Exception("output format %r" % args.output_format) runner = Runner(engine=rnn.engine, dataset=dataset, batches=dataset_batch, train=False, train_flag=bool(args.dropout) or args.train_flag, extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback) runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch) for l in layers: stats[l].dump(stream_prefix="Layer %r " % l) if not runner.finalized: print("Some error occured, not finalized.") sys.exit(1) if hdf_writer: hdf_writer.close() rnn.finalize()
def main():
  argparser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
  argparser.add_argument("--model", required=True, help="or config, or setup")
  argparser.add_argument("--epoch", required=True, type=int)
  argparser.add_argument("--prior", help="none, fixed, softmax (default: none)")
  argparser.add_argument("--prior_scale", type=float, default=1.0)
  argparser.add_argument("--am_scale", type=float, default=1.0)
  argparser.add_argument("--tdp_scale", type=float, default=1.0)
  args = argparser.parse_args()
  cfg_fn = args.model
  if "/" not in cfg_fn:
    cfg_fn = "config-train/%s.config" % cfg_fn
  assert os.path.exists(cfg_fn)
  setup_name = os.path.splitext(os.path.basename(cfg_fn))[0]
  setup_dir = "data-train/%s" % setup_name
  assert os.path.exists(setup_dir)
  Globals.setup_name = setup_name
  Globals.setup_dir = setup_dir
  Globals.epoch = args.epoch
  config_update["epoch"] = args.epoch
  config_update["load_epoch"] = args.epoch
  config_update["model"] = "%s/net-model/network" % setup_dir
  import rnn
  rnn.init(configFilename=cfg_fn, config_updates=config_update, extra_greeting="calc full sum score.")
  Globals.engine = rnn.engine
  Globals.config = rnn.config
  Globals.dataset = rnn.dev_data
  assert Globals.engine and Globals.config and Globals.dataset
  # This will init the network, load the params, etc.
  Globals.engine.init_train_from_config(config=Globals.config, dev_data=Globals.dataset)
  # Do not modify the network here. Not needed.
  softmax_prior = get_softmax_prior()
  prior = args.prior or "none"
  if prior == "none":
    prior_filename = None
  elif prior == "softmax":
    prior_filename = softmax_prior
  elif prior == "fixed":
    prior_filename = "dependencies/prior-fixed-f32.xml"
  else:
    raise Exception("invalid prior %r" % prior)
  print("using prior:", prior)
  if prior_filename:
    assert os.path.exists(prior_filename)
    check_valid_prior(prior_filename)
  print("Do the stuff...")
  print("Reinit dataset.")
  Globals.dataset.init_seq_order(epoch=args.epoch)
  network_update["out_fullsum_scores"]["eval_locals"]["am_scale"] = args.am_scale
  network_update["out_fullsum_scores"]["eval_locals"]["prior_scale"] = args.prior_scale
  network_update["out_fullsum_bw"]["tdp_scale"] = args.tdp_scale
  if prior_filename:
    network_update["out_fullsum_prior"]["init"] = "load_txt_file(%r)" % prior_filename
  else:
    network_update["out_fullsum_prior"]["init"] = 0
  from copy import deepcopy
  Globals.config.typed_dict["network"] = deepcopy(Globals.config.typed_dict["network"])
  Globals.config.typed_dict["network"].update(network_update)
  # Reinit the network, and copy over params.
  from Pretrain import pretrainFromConfig
  pretrain = pretrainFromConfig(Globals.config)  # reinit Pretrain topologies if used
  if pretrain:
    new_network_desc = pretrain.get_network_json_for_epoch(Globals.epoch)
  else:
    new_network_desc = Globals.config.typed_dict["network"]
  assert "output_fullsum" in new_network_desc
  print("Init new network.")
  Globals.engine.maybe_init_new_network(new_network_desc)
  print("Calc scores.")
  calc_fullsum_scores(meta=dict(
    prior=prior, prior_scale=args.prior_scale, am_scale=args.am_scale, tdp_scale=args.tdp_scale))
  rnn.finalize()
  print("Bye.")
def main(argv):
  assert len(argv) >= 2, "usage: %s <config>" % argv[0]
  init(configFilename=argv[1], commandLineOptions=argv[2:])
  iterateEpochs()
  rnn.finalize()
def main(argv): argparser = argparse.ArgumentParser(description=__doc__) argparser.add_argument("config_file", type=str) argparser.add_argument("--epoch", required=False, type=int) argparser.add_argument( '--data', default="train", help= "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'" ) argparser.add_argument('--do_search', default=False, action='store_true') argparser.add_argument('--beam_size', default=12, type=int) argparser.add_argument('--dump_dir', required=True) argparser.add_argument("--device", default="gpu") argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab") argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder") argparser.add_argument("--enc_layer", default="encoder") argparser.add_argument("--batch_size", type=int, default=5000) argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs") argparser.add_argument("--min_seq_len", default="0", help="can also be dict") argparser.add_argument("--output_format", default="npy", help="npy or png") argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values") argparser.add_argument("--train_flag", action="store_true") args = argparser.parse_args(argv[1:]) layers = args.layers assert isinstance(layers, list) model_name = ".".join(args.config_file.split("/")[-1].split(".")[:-1]) init_returnn(config_fn=args.config_file, cmd_line_opts=["--device", args.device], args=args) if args.do_search: raise NotImplementedError min_seq_length = NumbersDict(eval(args.min_seq_len)) if not os.path.exists(args.dump_dir): os.makedirs(args.dump_dir) assert args.output_format in ["npy", "png"] if args.output_format == "png": import matplotlib.pyplot as plt # need to import early? 
https://stackoverflow.com/a/45582103/133374 import matplotlib.ticker as ticker dataset_str = args.data if dataset_str in ["train", "dev", "eval"]: dataset_str = "config:%s" % dataset_str dataset = init_dataset(dataset_str) init_net(args, layers) network = rnn.engine.network extra_fetches = {} for rec_ret_layer in ["rec_%s" % l for l in layers]: extra_fetches[rec_ret_layer] = rnn.engine.network.layers[ rec_ret_layer].output.get_placeholder_as_batch_major() extra_fetches.update({ "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(), "output_len": network.layers[ args.rec_layer].output.get_sequence_lengths(), # decoder length "encoder_len": network.layers[ args.enc_layer].output.get_sequence_lengths(), # encoder length "seq_idx": network.get_extern_data("seq_idx"), "seq_tag": network.get_extern_data("seq_tag"), "target_data": network.get_extern_data(network.extern_data.default_input), "target_classes": network.get_extern_data(network.extern_data.default_target), }) dataset.init_seq_order(epoch=rnn.engine.epoch, seq_list=args.seq_list or None) dataset_batch = dataset.generate_batches( recurrent_net=network.recurrent, batch_size=args.batch_size, max_seqs=rnn.engine.max_seqs, max_seq_length=sys.maxsize, min_seq_length=min_seq_length, used_data_keys=network.used_data_keys) stats = {l: Stats() for l in layers} # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs): for i in range(len(seq_idx)): for l in layers: att_weights = kwargs["rec_%s" % l][i] stats[l].collect(att_weights.flatten()) if args.output_format == "npy": data = {} for i in range(len(seq_idx)): data[i] = { 'tag': seq_tag[i], 'data': target_data[i], 'classes': target_classes[i], 'output': output[i], 'output_len': output_len[i], 'encoder_len': encoder_len[i], } for l in [("rec_%s" % l) for l in layers]: assert l in kwargs out = kwargs[l][i] assert out.ndim >= 2 assert out.shape[0] >= output_len[i] and out.shape[ 1] >= encoder_len[i] data[i][l] = out[:output_len[i], :encoder_len[i]] fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % ( model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1]) np.save(fname, data) elif args.output_format == "png": for i in range(len(seq_idx)): for l in layers: extra_postfix = "" if args.dropout is not None: extra_postfix += "_dropout%.2f" % args.dropout elif args.train_flag: extra_postfix += "_train" fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % ( model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix) att_weights = kwargs["rec_%s" % l][i] att_weights = att_weights.squeeze(axis=2) # (out,enc) assert att_weights.shape[0] >= output_len[ i] and att_weights.shape[1] >= encoder_len[i] att_weights = att_weights[:output_len[i], :encoder_len[i]] print("Seq %i, %s: Dump att weights with shape %r to: %s" % (seq_idx[i], seq_tag[i], att_weights.shape, fname)) plt.matshow(att_weights) title = seq_tag[i] if dataset.can_serialize_data( network.extern_data.default_target): title += "\n" + dataset.serialize_data( network.extern_data.default_target, target_classes[i][:output_len[i]]) ax = plt.gca() tick_labels = [ dataset.serialize_data( network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype)) for x in target_classes[i][:output_len[i]] ] ax.set_yticklabels([''] + tick_labels, fontsize=8) ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) plt.title(title) plt.savefig(fname) plt.close() else: raise NotImplementedError("output format %r" % 
args.output_format) runner = Runner(engine=rnn.engine, dataset=dataset, batches=dataset_batch, train=False, train_flag=args.dropout is not None or args.train_flag, extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback) runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch) for l in layers: stats[l].dump(stream_prefix="Layer %r " % l) if not runner.finalized: print("Some error occured, not finalized.") sys.exit(1) rnn.finalize()
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump network as JSON.')
  argparser.add_argument("crnn_config_file", type=str)
  argparser.add_argument("--epoch", required=False, type=int)
  argparser.add_argument('--data', default="test")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', required=True)
  argparser.add_argument("--device", default="gpu")
  argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from")
  argparser.add_argument("--batch_size", type=int, default=5000)
  args = argparser.parse_args(argv[1:])
  if not os.path.exists(args.dump_dir):
    os.makedirs(args.dump_dir)
  model = ".".join(args.crnn_config_file.split("/")[-1].split(".")[:-1])
  init(configFilename=args.crnn_config_file, commandLineOptions=["--device", args.device], args=args)
  if isinstance(args.layers, str):
    layers = [args.layers]
  else:
    layers = args.layers
  inject_retrieval_code(args, layers)
  network = rnn.engine.network
  assert rnn.eval_data is not None, "provide evaluation data"
  extra_fetches = {}
  for rec_ret_layer in ["rec_%s" % l for l in layers]:
    extra_fetches[rec_ret_layer] = rnn.engine.network.layers[rec_ret_layer].output.placeholder
  extra_fetches.update({
    "output": network.get_default_output_layer().output.get_placeholder_as_batch_major(),
    "output_len": network.get_default_output_layer().output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers["encoder"].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx", mark_data_key_as_used=True),
    "seq_tag": network.get_extern_data("seq_tag", mark_data_key_as_used=True),
    "target_data": network.get_extern_data("data", mark_data_key_as_used=True),
    "target_classes": network.get_extern_data("classes", mark_data_key_as_used=True),
  })
  dataset_batch = rnn.eval_data.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    used_data_keys=network.used_data_keys)

  # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    data = {}
    for i in range(len(seq_idx)):
      data[i] = {
        'tag': seq_tag[i],
        'data': target_data[i],
        'classes': target_classes[i],
        'output': output[i],
        'output_len': output_len[i],
        'encoder_len': encoder_len[i],
      }
      for l in [("rec_%s" % l) for l in layers]:
        assert l in kwargs
        data[i][l] = kwargs[l]
    fname = os.path.join(
      args.dump_dir, '%s_ep%03d_data_%i_%i.npy' % (model, rnn.engine.epoch, seq_idx[0], seq_idx[-1]))
    np.save(fname, data)

  runner = Runner(
    engine=rnn.engine, dataset=rnn.eval_data, batches=dataset_batch,
    train=False, extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback)
  runner.run(report_prefix="att-weights ")
  assert runner.finalized
  rnn.finalize()
def main(argv): argparser = argparse.ArgumentParser(description=__doc__) argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir") argparser.add_argument("--epoch", type=int) argparser.add_argument( '--data', default="train", help= "e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'" ) argparser.add_argument('--do_search', default=False, action='store_true') argparser.add_argument('--beam_size', default=12, type=int) argparser.add_argument('--dump_dir', help="for npy or png") argparser.add_argument("--output_file", help="hdf") argparser.add_argument("--device", help="gpu or cpu (default: automatic)") argparser.add_argument("--layers", default=[], action="append", help="Layer of subnet to grab") argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder") argparser.add_argument("--enc_layer", default="encoder") argparser.add_argument("--batch_size", type=int, default=5000) argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs") argparser.add_argument("--min_seq_len", default="0", help="can also be dict") argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs") argparser.add_argument("--output_format", default="npy", help="npy, png or hdf") argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values") argparser.add_argument("--train_flag", action="store_true") argparser.add_argument("--reset_partition_epoch", type=int, default=None) argparser.add_argument("--reset_seq_ordering", default="default") argparser.add_argument("--reset_epoch_wise_filter", default=None) argparser.add_argument('--hmm_fac_fo', default=False, action='store_true') argparser.add_argument('--encoder_sa', default=False, action='store_true') argparser.add_argument('--tf_log_dir', help="for npy or png", default=None) argparser.add_argument("--instead_save_encoder_decoder", action="store_true") args = argparser.parse_args(argv[1:]) layers = args.layers assert isinstance(layers, list) config_fn = args.config_file explicit_model_dir = None if os.path.isdir(config_fn): # Assume we gave a model dir. explicit_model_dir = config_fn train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn train_log_dir_configs = sorted(glob(train_log_dir_config_pattern)) assert train_log_dir_configs config_fn = train_log_dir_configs[-1] print("Using this config via model dir:", config_fn) else: assert os.path.isfile(config_fn) model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1]) init_returnn(config_fn=config_fn, args=args) if explicit_model_dir: config.set( "model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', '')))) print("Model file prefix:", config.value('model', '')) min_seq_length = NumbersDict(eval(args.min_seq_len)) assert args.output_format in ["npy", "png", "hdf"] if args.output_format in ["npy", "png"]: assert args.dump_dir if not os.path.exists(args.dump_dir): os.makedirs(args.dump_dir) plt = ticker = None if args.output_format == "png": import matplotlib.pyplot as plt # need to import early? https://stackoverflow.com/a/45582103/133374 import matplotlib.ticker as ticker dataset_str = args.data if dataset_str in ["train", "dev", "eval"]: dataset_str = "config:%s" % dataset_str extra_dataset_kwargs = {} if args.reset_partition_epoch: print("NOTE: We are resetting partition epoch to %i." 
% (args.reset_partition_epoch, )) extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch if args.reset_seq_ordering: print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering, )) extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering if args.reset_epoch_wise_filter: extra_dataset_kwargs["epoch_wise_filter"] = eval( args.reset_epoch_wise_filter) dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs) if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None: if dataset.epoch_wise_filter: print("NOTE: Resetting epoch_wise_filter to None.") dataset.epoch_wise_filter = None if args.reset_partition_epoch: assert dataset.partition_epoch == args.reset_partition_epoch if args.reset_seq_ordering: assert dataset.seq_ordering == args.reset_seq_ordering init_net(args, layers) network = rnn.engine.network hdf_writer = None if args.output_format == "hdf": assert args.output_file assert len(layers) == 1 sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0])) from HDFDataset import SimpleHDFWriter hdf_writer = SimpleHDFWriter(filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim) extra_fetches = { "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(), "output_len": network.layers[ args.rec_layer].output.get_sequence_lengths(), # decoder length "encoder_len": network.layers[ args.enc_layer].output.get_sequence_lengths(), # encoder length "seq_idx": network.get_extern_data("seq_idx"), "seq_tag": network.get_extern_data("seq_tag"), "target_data": network.get_extern_data(network.extern_data.default_input), "target_classes": network.get_extern_data(network.extern_data.default_target), } if args.instead_save_encoder_decoder: extra_fetches["encoder"] = network.layers["encoder"] extra_fetches["decoder"] = network.layers["output"].get_sub_layer( "decoder") else: for l in layers: sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, l)) extra_fetches[ "rec_%s" % l] = sub_layer.output.get_placeholder_as_batch_major() if args.do_search: o_layer = rnn.engine.network.get_layer("output") extra_fetches["beam_scores_" + l] = o_layer.get_search_choices().beam_scores dataset.init_seq_order( epoch=1, seq_list=args.seq_list or None) # use always epoch 1, such that we have same seqs dataset_batch = dataset.generate_batches( recurrent_net=network.recurrent, batch_size=args.batch_size, max_seqs=rnn.engine.max_seqs, max_seq_length=sys.maxsize, min_seq_length=min_seq_length, max_total_num_seqs=args.num_seqs, used_data_keys=network.used_data_keys) stats = {l: Stats() for l in layers} # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs): """ :param list[int] seq_idx: len is n_batch :param list[str] seq_tag: len is n_batch :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...) :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...) :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...) :param numpy.ndarray output_len: rec layer seq len, i.e. 
decoder length, shape (B,) :param numpy.ndarray encoder_len: encoder seq len, shape (B,) :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in """ n_batch = len(seq_idx) for i in range(n_batch): if not args.instead_save_encoder_decoder: for l in layers: att_weights = kwargs["rec_%s" % l][i] stats[l].collect(att_weights.flatten()) if args.output_format == "npy": data = {} if args.do_search: assert not args.instead_save_encoder_decoder, "Not implemented" # find axis with correct beam size axis_beam_size = n_batch * args.beam_size corr_axis = None num_axes = len(kwargs["rec_%s" % layers[0]].shape) for a in range(len(kwargs["rec_%s" % layers[0]].shape)): if kwargs["rec_%s" % layers[0]].shape[a] == axis_beam_size: corr_axis = a break assert corr_axis is not None, "Att Weights Decoding: correct axis not found! maybe check beam size." # set dimensions correctly for l, l_raw in zip([("rec_%s" % l) for l in layers], layers): swap = list(range(num_axes)) del swap[corr_axis] swap.insert(0, corr_axis) kwargs[l] = np.transpose(kwargs[l], swap) for i in range(n_batch): # The first beam contains the score with the highest beam i_beam = args.beam_size * i data[i] = { 'tag': seq_tag[i], 'data': target_data[i], 'classes': target_classes[i], 'output': output[i_beam], 'output_len': output_len[i_beam], 'encoder_len': encoder_len[i], } #if args.hmm_fac_fo is False: for l, l_raw in zip([("rec_%s" % l) for l in layers], layers): assert l in kwargs out = kwargs[l][i_beam] # Do search for multihead # out is [I, H, 1, J] is new version # out is [I, J, H, 1] for old version if len(out.shape) == 3 and min(out.shape) > 1: out = np.transpose( out, axes=(0, 3, 1, 2)) # [I, J, H, 1] new version out = np.squeeze(out, axis=-1) else: out = np.squeeze(out, axis=1) assert out.shape[0] >= output_len[ i_beam] and out.shape[1] >= encoder_len[i] data[i][l] = out[:output_len[i_beam], :encoder_len[i]] fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % ( model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1]) np.save(fname, data) else: for i in range(n_batch): data[i] = { 'tag': seq_tag[i], 'data': target_data[i], 'classes': target_classes[i], 'output': output[i], 'output_len': output_len[i], 'encoder_len': encoder_len[i], } #if args.hmm_fac_fo is False: if args.instead_save_encoder_decoder: out = kwargs["encoder"][i] data[i]["encoder"] = out[:encoder_len[i]] out_2 = kwargs["decoder"][i] data[i]["decoder"] = out_2[:output_len[i]] else: for l in [("rec_%s" % l) for l in layers]: assert l in kwargs out = kwargs[l][i] # [] # multi-head attention if len(out.shape) == 3 and min(out.shape) > 1: # Multihead attention out = np.transpose( out, axes=(1, 2, 0)) # (I, J, H) new version #out = np.transpose(out, axes=(2, 0, 1)) # [I, J, H] old version else: # RNN out = np.squeeze(out, axis=1) assert out.ndim >= 2 assert out.shape[0] >= output_len[i] and out.shape[ 1] >= encoder_len[i] data[i][l] = out[:output_len[i], :encoder_len[i]] fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % ( model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1]) np.save(fname, data) elif args.output_format == "png": for i in range(n_batch): for l in layers: extra_postfix = "" if args.dropout is not None: extra_postfix += "_dropout%.2f" % args.dropout elif args.train_flag: extra_postfix += "_train" fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % ( model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix) att_weights = kwargs["rec_%s" % l][i] att_weights = att_weights.squeeze(axis=2) # (out,enc) assert 
att_weights.shape[0] >= output_len[ i] and att_weights.shape[1] >= encoder_len[i] att_weights = att_weights[:output_len[i], :encoder_len[i]] print("Seq %i, %s: Dump att weights with shape %r to: %s" % (seq_idx[i], seq_tag[i], att_weights.shape, fname)) plt.matshow(att_weights) title = seq_tag[i] if dataset.can_serialize_data( network.extern_data.default_target): title += "\n" + dataset.serialize_data( network.extern_data.default_target, target_classes[i][:output_len[i]]) ax = plt.gca() tick_labels = [ dataset.serialize_data( network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype)) for x in target_classes[i][:output_len[i]] ] ax.set_yticklabels([''] + tick_labels, fontsize=8) ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) plt.title(title) plt.savefig(fname) plt.close() elif args.output_format == "hdf": assert len(layers) == 1 att_weights = kwargs["rec_%s" % layers[0]] hdf_writer.insert_batch(inputs=att_weights, seq_len={ 0: output_len, 1: encoder_len }, seq_tag=seq_tag) else: raise Exception("output format %r" % args.output_format) runner = Runner(engine=rnn.engine, dataset=dataset, batches=dataset_batch, train=False, train_flag=bool(args.dropout) or args.train_flag, extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback, eval=False) runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch) if not args.instead_save_encoder_decoder: for l in layers: stats[l].dump(stream_prefix="Layer %r " % l) if not runner.finalized: print("Some error occured, not finalized.") sys.exit(1) if hdf_writer: hdf_writer.close() rnn.finalize()