Example #1
0
def main(argv):
  """
  Main entry: compute the word error rate (WER) of hypotheses against references.

  References are taken from ``--refs`` (same format as ``--hyps``), from ``--dataset``,
  or from a dataset selected via the config file (``--config``).
  The final WER is logged and, if ``--out`` is given, also written to that file.

  :param list[str] argv: full command line including the program name at index 0 (e.g. ``sys.argv``)
  """
  arg_parser = argparse.ArgumentParser(
    description='Calculate the word error rate (WER) of hypotheses against references.')
  arg_parser.add_argument('--config', help="filename to config-file. will use dataset 'eval' from it")
  arg_parser.add_argument("--dataset", help="dataset, overwriting config")
  arg_parser.add_argument("--refs", help="same format as hyps. alternative to providing dataset/config")
  arg_parser.add_argument("--hyps", help="hypotheses, dumped via search in py format")
  arg_parser.add_argument('--startseq', type=int, default=0, help='start seq idx (inclusive) (default: 0)')
  arg_parser.add_argument('--endseq', type=int, default=-1, help='end seq idx (inclusive) or -1 (default: -1)')
  arg_parser.add_argument("--key", default="raw", help="data-key, e.g. 'data' or 'classes'. (default: 'raw')")
  arg_parser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  arg_parser.add_argument("--out", help="if provided, will write WER% (as string) to this file")
  arg_parser.add_argument("--expect_full", action="store_true", help="full dataset should be scored")
  args = arg_parser.parse_args(argv[1:])
  # Validate via argparse instead of `assert`: asserts are stripped under `python -O`,
  # and arg_parser.error() prints a proper usage message (and exits with status 2).
  if not (args.config or args.dataset or args.refs):
    arg_parser.error("one of --config, --dataset or --refs is required")
  if not args.hyps:
    # The hypotheses are loaded unconditionally below, so fail early with a clear message.
    arg_parser.error("--hyps is required")

  init(config_filename=args.config, log_verbosity=args.verbosity)
  dataset = None
  refs = None
  if args.refs:
    refs = load_hyps_refs(args.refs)
  elif args.dataset:
    dataset = init_dataset(args.dataset)
  elif config.value("wer_data", "eval") in ["train", "dev", "eval"]:
    # NOTE(review): the condition reads "wer_data", but the dataset is taken from "search_data".
    # An explicit wer_data of "train"/"dev" is therefore ignored in favour of search_data;
    # confirm this is intended.
    dataset = init_dataset(config.opt_typed_value(config.value("search_data", "eval")))
  else:
    dataset = init_dataset(config.opt_typed_value("wer_data"))
  hyps = load_hyps_refs(args.hyps)

  global wer_compute
  wer_compute = WerComputeGraph()
  # CPU only: the WER graph is small, so hide all GPUs from this session.
  with tf_compat.v1.Session(config=tf_compat.v1.ConfigProto(device_count={"GPU": 0})) as _session:
    global session
    session = _session
    session.run(tf_compat.v1.global_variables_initializer())
    try:
      wer = calc_wer_on_dataset(dataset=dataset, refs=refs, options=args, hyps=hyps)
      print("Final WER: %.02f%%" % (wer * 100), file=log.v1)
      if args.out:
        with open(args.out, "w") as output_file:
          output_file.write("%.02f\n" % (wer * 100))
        print("Wrote WER%% to %r." % args.out)
    except KeyboardInterrupt:
      print("KeyboardInterrupt")
      sys.exit(1)
    finally:
      rnn.finalize()
Example #2
0
def _init_dataset():
    """
    Lazily create the global Sprint dataset and, optionally, the custom dataset.

    Does nothing if ``sprintDataset`` already exists. Otherwise it builds ``sprintDataset`` from
    the global ``config`` via ``SprintDatasetBase.from_config``, passing the dict from the config
    option ``sprint_interface_dataset_opts`` as extra kwargs.

    If ``sprint_interface_custom_dataset`` is set, its value must be a callable. It is called as
    ``func(sprint_dataset=sprintDataset)``, and the returned options are passed to
    ``init_dataset`` to create ``customDataset``.
    """
    global sprintDataset, customDataset
    # Compare with None explicitly: an existing dataset object may be falsy
    # (e.g. if it defines __len__/__bool__) and must not be re-created.
    if sprintDataset is not None:
        return
    assert config, "config must be initialized before the Sprint dataset"
    extra_opts = config.typed_value("sprint_interface_dataset_opts", {})
    assert isinstance(extra_opts, dict)
    sprintDataset = SprintDatasetBase.from_config(config, **extra_opts)
    if config.is_true("sprint_interface_custom_dataset"):
        custom_dataset_func = config.typed_value(
            "sprint_interface_custom_dataset")
        assert callable(custom_dataset_func)
        custom_dataset_opts = custom_dataset_func(sprint_dataset=sprintDataset)
        customDataset = init_dataset(custom_dataset_opts)
Example #3
0
def generate_hdf_from_other(opts, suffix=".hdf"):
    """
    Dump the dataset described by ``opts`` into a temporary HDF file and return its filename.

    Results are memoized in ``_hdf_cache``, keyed by ``make_hashable(opts)``, so repeated calls
    with equal options return the same file without dumping again.

    :param dict[str] opts: dataset options, passed to ``init_dataset``
    :param str suffix: filename suffix for the temporary file
    :return: hdf filename
    :rtype: str
    """
    # See test_hdf_dump.py and tools/hdf_dump.py.
    from returnn.util.basic import make_hashable
    cache_key = make_hashable(opts)
    if cache_key in _hdf_cache:
        return _hdf_cache[cache_key]
    fn = get_test_tmp_file(suffix=suffix)
    from returnn.datasets.basic import init_dataset
    dataset = init_dataset(opts)
    hdf_dataset = HDFDatasetWriter(fn)
    try:
        hdf_dataset.dump_from_dataset(dataset)
    finally:
        # Close the writer even if dumping fails, so the file handle is not leaked.
        hdf_dataset.close()
    # Only cache after a successful dump, so a failed attempt is retried next time.
    _hdf_cache[cache_key] = fn
    return fn
Example #4
0
File: sprint.py Project: ishine/returnn
def demo():
  """
  Demo / consistency check for SprintCacheDataset.

  Command line options (parsed from ``sys.argv``):

  * ``--config``: config file whose ``train`` dataset is used as the reference dataset
  * ``--sprint_cache_dataset``: Python expression evaluating to the kwargs dict for SprintCacheDataset
  * ``--max_num_seqs``: upper bound on the number of sequences to iterate
  * ``--action``:

    * ``compare``: checks that both datasets yield the same ``data``/``classes`` per seq tag.
    * ``benchmark``: times iterating the config dataset, then looking up the same tags
      in the SprintCacheDataset.
  """
  print("SprintDataset demo.")
  from argparse import ArgumentParser
  from returnn.util.basic import progress_bar_with_time
  from returnn.log import log
  from returnn.config import Config
  from returnn.datasets.basic import init_dataset
  arg_parser = ArgumentParser()
  arg_parser.add_argument("--config", help="config with ExternSprintDataset", required=True)
  arg_parser.add_argument("--sprint_cache_dataset", help="kwargs dict for SprintCacheDataset", required=True)
  arg_parser.add_argument("--max_num_seqs", default=sys.maxsize, type=int)
  # Restrict to the known actions so a typo fails immediately,
  # instead of after the (expensive) dataset setup below.
  arg_parser.add_argument(
    "--action", default="compare", choices=["compare", "benchmark"], help="compare or benchmark")
  args = arg_parser.parse_args()
  log.initialize(verbosity=[4])
  # NOTE(security): eval() executes arbitrary code given on the command line.
  # This is acceptable for a local demo tool, but never pass untrusted input here.
  sprint_cache_dataset_kwargs = eval(args.sprint_cache_dataset)
  assert isinstance(sprint_cache_dataset_kwargs, dict)
  sprint_cache_dataset = SprintCacheDataset(**sprint_cache_dataset_kwargs)
  print("SprintCacheDataset: %r" % sprint_cache_dataset)
  config = Config()
  config.load_file(args.config)
  dataset = init_dataset(config.typed_value("train"))
  print("Dataset via config: %r" % dataset)
  assert sprint_cache_dataset.num_inputs == dataset.num_inputs
  assert tuple(sprint_cache_dataset.num_outputs["classes"]) == tuple(dataset.num_outputs["classes"])
  sprint_cache_dataset.init_seq_order(epoch=1)

  if args.action == "compare":
    print("Iterating through dataset...")
    seq_idx = 0
    dataset.init_seq_order(epoch=1)
    while seq_idx < args.max_num_seqs:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag = dataset.get_tag(seq_idx)
      assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
      dataset_seq = sprint_cache_dataset.get_dataset_seq_for_name(tag)
      data = dataset.get_data(seq_idx, "data")
      targets = dataset.get_data(seq_idx, "classes")
      assert data.shape == dataset_seq.features["data"].shape
      assert targets.shape == dataset_seq.features["classes"].shape
      assert numpy.allclose(data, dataset_seq.features["data"])
      assert numpy.allclose(targets, dataset_seq.features["classes"])
      seq_idx += 1
      progress_bar_with_time(dataset.get_complete_frac(seq_idx))

    print("Finished through dataset. Num seqs: %i" % seq_idx)
    print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)

  elif args.action == "benchmark":
    print("Iterating through dataset...")
    start_time = time.time()
    seq_tags = []
    seq_idx = 0
    dataset.init_seq_order(epoch=1)
    while seq_idx < args.max_num_seqs:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag = dataset.get_tag(seq_idx)
      assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
      seq_tags.append(tag)
      dataset.get_data(seq_idx, "data")
      dataset.get_data(seq_idx, "classes")
      seq_idx += 1
      progress_bar_with_time(dataset.get_complete_frac(seq_idx))
    print("Finished through dataset. Num seqs: %i, time: %f" % (seq_idx, time.time() - start_time))
    print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)
    # Stop any background work of the reference dataset before timing the cache dataset.
    if hasattr(dataset, "exit_handler"):
      dataset.exit_handler()
    else:
      print("No way to stop any background tasks.")
    del dataset

    start_time = time.time()
    print("Iterating through SprintCacheDataset...")
    for i, tag in enumerate(seq_tags):
      sprint_cache_dataset.get_dataset_seq_for_name(tag)
      progress_bar_with_time(float(i) / len(seq_tags))
    print("Finished through SprintCacheDataset. time: %f" % (time.time() - start_time,))

  else:
    # Defensive only: argparse `choices` already rejects unknown actions.
    raise Exception("invalid action: %r" % args.action)
def init(config_str, config_dataset, use_pretrain, epoch, verbosity):
    """
    Initialize RETURNN (config, logging, backend engine) and the global ``dataset``.

    :param str config_str: one of:

        * a Python dict literal (starting with "{") with the dataset options,
        * a filename ending in ".hdf", which is loaded as an HDFDataset,
        * a filename of a config file.
    :param str|None config_dataset: name of the config dataset to use (as "config:<name>")
        when no dataset options were given directly via ``config_str``; defaults to "train"
    :param bool use_pretrain: might overwrite config options, or even the dataset
    :param int epoch: epoch for the pretrain network and for ``dataset.init_seq_order``
    :param int verbosity: log verbosity
    :raises FileNotFoundError: if the given ".hdf" file or config file does not exist
    """
    rnn.init_better_exchook()
    rnn.init_thread_join_hack()
    dataset_opts = None
    config_filename = None
    if config_str.strip().startswith("{"):
        print("Using dataset %s." % config_str)
        # NOTE(security): eval() executes arbitrary code from the command line.
        # Only pass trusted input here.
        dataset_opts = eval(config_str.strip())
    elif config_str.endswith(".hdf"):
        dataset_opts = {"class": "HDFDataset", "files": [config_str]}
        print("Using dataset %r." % dataset_opts)
        # Explicit check instead of `assert`, which would be stripped under `python -O`.
        if not os.path.exists(config_str):
            raise FileNotFoundError("HDF file not found: %r" % config_str)
    else:
        config_filename = config_str
        print("Using config file %r." % config_filename)
        if not os.path.exists(config_filename):
            raise FileNotFoundError("Config file not found: %r" % config_filename)
    rnn.init_config(config_filename=config_filename,
                    default_config={"cache_size": "0"})
    global config
    config = rnn.config
    # Log only to stdout, with the requested verbosity.
    config.set("log", None)
    config.set("log_verbosity", verbosity)
    rnn.init_log()
    print("Returnn %s starting up." % __file__, file=log.v2)
    rnn.returnn_greeting()
    rnn.init_faulthandler()
    rnn.init_config_json_network()
    util.BackendEngine.select_engine(config=config)
    if not dataset_opts:
        # No dataset given directly: refer to a dataset defined in the config.
        dataset_opts = "config:%s" % (config_dataset or "train")
    if use_pretrain:
        from returnn.pretrain import pretrain_from_config
        pretrain = pretrain_from_config(config)
        if pretrain:
            print("Using pretrain %s, epoch %i" % (pretrain, epoch),
                  file=log.v2)
            net_dict = pretrain.get_network_json_for_epoch(epoch=epoch)
            if "#config" in net_dict:
                # The pretrain net dict for this epoch may override config options;
                # print a diff for each key, then apply it.
                config_overwrites = net_dict["#config"]
                print("Pretrain overwrites these config options:", file=log.v2)
                assert isinstance(config_overwrites, dict)
                for key, value in sorted(config_overwrites.items()):
                    assert isinstance(key, str)
                    orig_value = config.typed_dict.get(key, None)
                    if isinstance(orig_value, dict) and isinstance(
                            value, dict):
                        diff_str = "\n" + util.dict_diff_str(orig_value, value)
                    elif isinstance(value, dict):
                        diff_str = "\n%r ->\n%s" % (orig_value, pformat(value))
                    else:
                        diff_str = " %r -> %r" % (orig_value, value)
                    print("Config key %r for epoch %i:%s" %
                          (key, epoch, diff_str),
                          file=log.v2)
                    config.set(key, value)
            else:
                print("No config overwrites for this epoch.", file=log.v2)
        else:
            print("No pretraining used.", file=log.v2)
    elif config.typed_dict.get("pretrain", None):
        print("Not using pretrain.", file=log.v2)
    # Dataset defaults are derived from the (possibly pretrain-overwritten) config.
    dataset_default_opts = {}
    Dataset.kwargs_update_from_config(config, dataset_default_opts)
    print("Using dataset:", dataset_opts, file=log.v2)
    global dataset
    dataset = init_dataset(dataset_opts, default_kwargs=dataset_default_opts)
    assert isinstance(dataset, Dataset)
    dataset.init_seq_order(epoch=epoch)