def main(argv): """ Main entry. """ arg_parser = argparse.ArgumentParser(description='Dump something from dataset.') arg_parser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it") arg_parser.add_argument("--dataset", help="dataset, overwriting config") arg_parser.add_argument("--refs", help="same format as hyps. alternative to providing dataset/config") arg_parser.add_argument("--hyps", help="hypotheses, dumped via search in py format") arg_parser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)') arg_parser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)') arg_parser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')") arg_parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)") arg_parser.add_argument("--out", help="if provided, will write WER% (as string) to this file") arg_parser.add_argument("--expect_full", action="store_true", help="full dataset should be scored") args = arg_parser.parse_args(argv[1:]) assert args.config or args.dataset or args.refs init(config_filename=args.config, log_verbosity=args.verbosity) dataset = None refs = None if args.refs: refs = load_hyps_refs(args.refs) elif args.dataset: dataset = init_dataset(args.dataset) elif config.value("wer_data", "eval") in ["train", "dev", "eval"]: dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval"))) else: dataset = init_dataset(config.opt_typed_value("wer_data")) hyps = load_hyps_refs(args.hyps) global wer_compute wer_compute = WerComputeGraph() with tf_compat.v1.Session(config=tf_compat.v1.ConfigProto(device_count={"GPU": 0})) as _session: global session session = _session session.run(tf_compat.v1.global_variables_initializer()) try: wer = calc_wer_on_dataset(dataset=dataset, refs=refs, options=args, hyps=hyps) print("Final WER: %.02f%%" % (wer * 100), file=log.v1) if args.out: with open(args.out, "w") as output_file: output_file.write("%.02f\n" % (wer * 100)) print("Wrote WER%% to %r." % args.out) except KeyboardInterrupt: print("KeyboardInterrupt") sys.exit(1) finally: rnn.finalize()
def _init_dataset():
    """
    Initializes the global sprintDataset (and optionally customDataset)
    from the global config. Does nothing if already initialized.
    """
    global sprintDataset, customDataset
    if sprintDataset:
        return
    assert config
    extra_opts = config.typed_value("sprint_interface_dataset_opts", {})
    assert isinstance(extra_opts, dict)
    sprintDataset = SprintDatasetBase.from_config(config, **extra_opts)
    if config.is_true("sprint_interface_custom_dataset"):
        custom_dataset_func = config.typed_value("sprint_interface_custom_dataset")
        assert callable(custom_dataset_func)
        custom_dataset_opts = custom_dataset_func(sprint_dataset=sprintDataset)
        customDataset = init_dataset(custom_dataset_opts)
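
# Config sketch (hedged): per the code above, "sprint_interface_custom_dataset"
# is a callable in the RETURNN config that receives the SprintDatasetBase
# instance and returns opts accepted by init_dataset(). Such a function would
# live in the config file, not in this module; the body below is a hypothetical
# example and the HDF path is not from the source.
def sprint_interface_custom_dataset(sprint_dataset):
    # Hypothetical: ignore sprint_dataset and read features from a fixed HDF file.
    return {"class": "HDFDataset", "files": ["/path/to/features.hdf"]}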
def generate_hdf_from_other(opts, suffix=".hdf"):
    """
    :param dict[str] opts:
    :param str suffix:
    :return: hdf filename
    :rtype: str
    """
    # See test_hdf_dump.py and tools/hdf_dump.py.
    from returnn.util.basic import make_hashable

    cache_key = make_hashable(opts)
    if cache_key in _hdf_cache:
        return _hdf_cache[cache_key]
    fn = get_test_tmp_file(suffix=suffix)
    from returnn.datasets.basic import init_dataset

    dataset = init_dataset(opts)
    hdf_dataset = HDFDatasetWriter(fn)
    hdf_dataset.dump_from_dataset(dataset)
    hdf_dataset.close()
    _hdf_cache[cache_key] = fn
    return fn
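
# Usage sketch (hedged): the opts dict below uses Task12AXDataset, a small
# synthetic RETURNN dataset, purely as an illustrative example; any opts
# accepted by init_dataset() should work.
def _demo_generate_hdf():
    fn = generate_hdf_from_other({"class": "Task12AXDataset", "num_seqs": 5})
    print("Generated HDF file: %s" % fn)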
def demo(): """ Demo. """ print("SprintDataset demo.") from argparse import ArgumentParser from returnn.util.basic import progress_bar_with_time from returnn.log import log from returnn.config import Config from returnn.datasets.basic import init_dataset arg_parser = ArgumentParser() arg_parser.add_argument("--config", help="config with ExternSprintDataset", required=True) arg_parser.add_argument("--sprint_cache_dataset", help="kwargs dict for SprintCacheDataset", required=True) arg_parser.add_argument("--max_num_seqs", default=sys.maxsize, type=int) arg_parser.add_argument("--action", default="compare", help="compare or benchmark") args = arg_parser.parse_args() log.initialize(verbosity=[4]) sprint_cache_dataset_kwargs = eval(args.sprint_cache_dataset) assert isinstance(sprint_cache_dataset_kwargs, dict) sprint_cache_dataset = SprintCacheDataset(**sprint_cache_dataset_kwargs) print("SprintCacheDataset: %r" % sprint_cache_dataset) config = Config() config.load_file(args.config) dataset = init_dataset(config.typed_value("train")) print("Dataset via config: %r" % dataset) assert sprint_cache_dataset.num_inputs == dataset.num_inputs assert tuple(sprint_cache_dataset.num_outputs["classes"]) == tuple(dataset.num_outputs["classes"]) sprint_cache_dataset.init_seq_order(epoch=1) if args.action == "compare": print("Iterating through dataset...") seq_idx = 0 dataset.init_seq_order(epoch=1) while seq_idx < args.max_num_seqs: if not dataset.is_less_than_num_seqs(seq_idx): break dataset.load_seqs(seq_idx, seq_idx + 1) tag = dataset.get_tag(seq_idx) assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs" dataset_seq = sprint_cache_dataset.get_dataset_seq_for_name(tag) data = dataset.get_data(seq_idx, "data") targets = dataset.get_data(seq_idx, "classes") assert data.shape == dataset_seq.features["data"].shape assert targets.shape == dataset_seq.features["classes"].shape assert numpy.allclose(data, dataset_seq.features["data"]) assert numpy.allclose(targets, dataset_seq.features["classes"]) seq_idx += 1 progress_bar_with_time(dataset.get_complete_frac(seq_idx)) print("Finished through dataset. Num seqs: %i" % seq_idx) print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs) elif args.action == "benchmark": print("Iterating through dataset...") start_time = time.time() seq_tags = [] seq_idx = 0 dataset.init_seq_order(epoch=1) while seq_idx < args.max_num_seqs: if not dataset.is_less_than_num_seqs(seq_idx): break dataset.load_seqs(seq_idx, seq_idx + 1) tag = dataset.get_tag(seq_idx) assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs" seq_tags.append(tag) dataset.get_data(seq_idx, "data") dataset.get_data(seq_idx, "classes") seq_idx += 1 progress_bar_with_time(dataset.get_complete_frac(seq_idx)) print("Finished through dataset. Num seqs: %i, time: %f" % (seq_idx, time.time() - start_time)) print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs) if hasattr(dataset, "exit_handler"): dataset.exit_handler() else: print("No way to stop any background tasks.") del dataset start_time = time.time() print("Iterating through SprintCacheDataset...") for i, tag in enumerate(seq_tags): sprint_cache_dataset.get_dataset_seq_for_name(tag) progress_bar_with_time(float(i) / len(seq_tags)) print("Finished through SprintCacheDataset. time: %f" % (time.time() - start_time,)) else: raise Exception("invalid action: %r" % args.action)
def init(config_str, config_dataset, use_pretrain, epoch, verbosity):
    """
    :param str config_str: either filename to config-file, or dict for dataset
    :param str|None config_dataset:
    :param bool use_pretrain: might overwrite config options, or even the dataset
    :param int epoch:
    :param int verbosity:
    """
    rnn.init_better_exchook()
    rnn.init_thread_join_hack()
    dataset_opts = None
    config_filename = None
    if config_str.strip().startswith("{"):
        print("Using dataset %s." % config_str)
        dataset_opts = eval(config_str.strip())
    elif config_str.endswith(".hdf"):
        dataset_opts = {"class": "HDFDataset", "files": [config_str]}
        print("Using dataset %r." % dataset_opts)
        assert os.path.exists(config_str)
    else:
        config_filename = config_str
        print("Using config file %r." % config_filename)
        assert os.path.exists(config_filename)
    rnn.init_config(config_filename=config_filename, default_config={"cache_size": "0"})
    global config
    config = rnn.config
    config.set("log", None)
    config.set("log_verbosity", verbosity)
    rnn.init_log()
    print("Returnn %s starting up." % __file__, file=log.v2)
    rnn.returnn_greeting()
    rnn.init_faulthandler()
    rnn.init_config_json_network()
    util.BackendEngine.select_engine(config=config)
    if not dataset_opts:
        if config_dataset:
            dataset_opts = "config:%s" % config_dataset
        else:
            dataset_opts = "config:train"

    if use_pretrain:
        from returnn.pretrain import pretrain_from_config

        pretrain = pretrain_from_config(config)
        if pretrain:
            print("Using pretrain %s, epoch %i" % (pretrain, epoch), file=log.v2)
            net_dict = pretrain.get_network_json_for_epoch(epoch=epoch)
            if "#config" in net_dict:
                config_overwrites = net_dict["#config"]
                print("Pretrain overwrites these config options:", file=log.v2)
                assert isinstance(config_overwrites, dict)
                for key, value in sorted(config_overwrites.items()):
                    assert isinstance(key, str)
                    orig_value = config.typed_dict.get(key, None)
                    if isinstance(orig_value, dict) and isinstance(value, dict):
                        diff_str = "\n" + util.dict_diff_str(orig_value, value)
                    elif isinstance(value, dict):
                        diff_str = "\n%r ->\n%s" % (orig_value, pformat(value))
                    else:
                        diff_str = " %r -> %r" % (orig_value, value)
                    print("Config key %r for epoch %i:%s" % (key, epoch, diff_str), file=log.v2)
                    config.set(key, value)
            else:
                print("No config overwrites for this epoch.", file=log.v2)
        else:
            print("No pretraining used.", file=log.v2)
    elif config.typed_dict.get("pretrain", None):
        print("Not using pretrain.", file=log.v2)

    dataset_default_opts = {}
    Dataset.kwargs_update_from_config(config, dataset_default_opts)
    print("Using dataset:", dataset_opts, file=log.v2)
    global dataset
    dataset = init_dataset(dataset_opts, default_kwargs=dataset_default_opts)
    assert isinstance(dataset, Dataset)
    dataset.init_seq_order(epoch=epoch)
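
# Usage sketches (all argument values illustrative, not from the source):
# config_str may be a config filename, an HDF file, or an inline dataset dict.
#   init(config_str="returnn.config", config_dataset="dev", use_pretrain=False, epoch=1, verbosity=4)
#   init(config_str="{'class': 'HDFDataset', 'files': ['data.hdf']}",
#        config_dataset=None, use_pretrain=False, epoch=1, verbosity=4)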