Code example #1
0
File: __main__.py — Project: e0397123/returnn
def analyze_data(config):  # pylint: disable=redefined-outer-name
    """
  Analyze the dataset selected via the ``analyze_dataset`` config option:
  estimate class priors for the sparse ``target`` key and a per-feature
  mean / std-dev for the dense ``data_key``, then dump all three as
  ``numpy.savetxt`` text files under ``statistics_save_prefix``.

  Relies on the module-level ``train_data`` / ``dev_data`` / ``eval_data``
  datasets and ``log`` being initialized.

  :param Config config:
  """
    dss = config.value('analyze_dataset', 'train')
    ds = {"train": train_data, "dev": dev_data, "eval": eval_data}[dss]
    epoch = config.int('epoch', 1)
    print("Analyze dataset", dss, "epoch", epoch, file=log.v1)
    ds.init_seq_order(epoch=epoch)
    stat_prefix = config.value('statistics_save_prefix', 'statistics')
    dtype = config.value('statistics_dtype', 'float64')
    target = config.value('target', 'classes')
    data_key = config.value('data_key', 'data')
    assert ds.is_data_sparse(target), "need for prior calculation"
    assert not ds.is_data_sparse(data_key), "needed for mean/var estimation"
    from returnn.util.basic import inplace_increment, progress_bar_with_time, NumbersDict

    priors = numpy.zeros((ds.get_data_dim(target), ), dtype=dtype)
    mean = numpy.zeros((ds.get_data_dim(data_key), ), dtype=dtype)
    mean_sq = numpy.zeros((ds.get_data_dim(data_key), ), dtype=dtype)
    total_targets_len = 0
    total_data_len = 0

    # Note: This is not numerically stable! See :class:`Util.Stats` for a better alternative.
    seq_idx = 0
    while ds.is_less_than_num_seqs(seq_idx):
        progress_bar_with_time(ds.get_complete_frac(seq_idx))
        ds.load_seqs(seq_idx, seq_idx + 1)
        targets = ds.get_data(seq_idx, target)
        inplace_increment(priors, targets, 1)
        total_targets_len += targets.shape[0]
        data = ds.get_data(seq_idx, data_key)
        new_total_data_len = total_data_len + data.shape[0]
        f = float(total_data_len) / new_total_data_len
        # Incremental mean update: new_mean = old_mean * f + mean(data) * (1 - f),
        # where (1 - f) == data.shape[0] / new_total_data_len already carries the
        # frame-count weight. Using numpy.mean here (not numpy.sum) is essential:
        # sum(data) * (1 - f) would overweight the new seq by a factor of its length.
        mean = mean * f + numpy.mean(data, axis=0) * (1.0 - f)
        mean_sq = mean_sq * f + numpy.mean(data * data, axis=0) * (1.0 - f)
        total_data_len = new_total_data_len
        seq_idx += 1
    log_priors = numpy.log(priors)
    log_priors -= numpy.log(NumbersDict(ds.get_num_timesteps())[target])
    # Clamp at 0 before sqrt: float rounding can make mean_sq - mean^2 slightly
    # negative for near-constant features, which would yield NaN.
    std_dev = numpy.sqrt(numpy.maximum(mean_sq - mean * mean, 0.0))
    print("Finished. %i total target frames, %i total data frames" %
          (total_targets_len, total_data_len),
          file=log.v1)
    priors_fn = stat_prefix + ".log_priors.txt"
    mean_fn = stat_prefix + ".mean.txt"
    std_dev_fn = stat_prefix + ".std_dev.txt"
    print("Dump priors to", priors_fn, file=log.v1)
    numpy.savetxt(priors_fn, log_priors)
    print("Dump mean to", mean_fn, file=log.v1)
    numpy.savetxt(mean_fn, mean)
    print("Dump std dev to", std_dev_fn, file=log.v1)
    numpy.savetxt(std_dev_fn, std_dev)
    print("Done.", file=log.v1)
Code example #2
0
def get_raw_strings(dataset, options):
    """
  Iterate over the dataset and collect the raw string for ``options.key``
  of every sequence between ``options.startseq`` and ``options.endseq``.

  :param Dataset dataset:
  :param options: argparse.Namespace
  :return: list of (seq tag, string)
  :rtype: list[(str,str)]
  """
    collected = []
    t0 = time.time()
    len_stats = Stats()
    idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    use_progress_bar = util.is_tty() and not log.verbose[5]
    print("Iterating over %r." % dataset, file=log.v2)
    while dataset.is_less_than_num_seqs(idx) and idx <= options.endseq:
        dataset.load_seqs(idx, idx + 1)
        frac_done = dataset.get_complete_frac(idx)
        elapsed = time.time() - t0
        # Total seq count may be exact, estimated, or unknown.
        try:
            total_str = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                total_str = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                total_str = "?"
        progress_prefix = "%i/%s" % (idx, total_str)
        progress = "%s (%.02f%%)" % (progress_prefix, frac_done * 100)
        if frac_done > 0:
            # Linear extrapolation of the remaining time.
            progress += " (%s)" % hms(elapsed / frac_done - elapsed)
        tag = dataset.get_tag(idx)
        assert isinstance(tag, str)
        raw = dataset.get_data(idx, options.key)
        if isinstance(raw, numpy.ndarray):
            assert raw.shape == () or (raw.ndim == 1
                                       and raw.dtype == numpy.uint8)
            # Scalar array: unwrap the contained str/bytes object.
            # uint8 vector: reinterpret as a raw byte string.
            raw = raw.flatten()[0] if raw.shape == () else raw.tobytes()
        if isinstance(raw, bytes):
            raw = raw.decode("utf8")
        assert isinstance(raw, str)
        len_stats.collect([len(raw)])
        collected.append((tag, raw))
        if use_progress_bar:
            util.progress_bar_with_time(frac_done, prefix=progress_prefix)
        elif log.verbose[5]:
            print(progress_prefix,
                  "seq tag %r, ref len %i chars" % (tag, len(raw)))
        idx += 1
    print("Done. Num seqs %i. Total time %s." %
          (idx, hms(time.time() - t0)),
          file=log.v1)
    print("More seqs which we did not dumped: %s." %
          (dataset.is_less_than_num_seqs(idx), ),
          file=log.v1)
    len_stats.dump(stream_prefix="Seq-length %r " % (options.key, ),
                   stream=log.v2)
    return collected
Code example #3
0
def calc_wer_on_dataset(dataset, refs, options, hyps):
    """
  Compute the word error rate (WER) of ``hyps`` against reference strings,
  which come either from ``dataset`` (read via ``options.key``) or from the
  ``refs`` dict. Each ref/hyp pair is fed into ``wer_compute.step(session, ...)``
  (presumably a module-level TF computation + session — defined elsewhere in
  this file; TODO confirm).

  :param Dataset|None dataset:
  :param dict[str,str]|None refs: seq tag -> ref string (words delimited by space)
  :param options: argparse.Namespace
  :param dict[str,str] hyps: seq tag -> hyp string (words delimited by space)
  :return: WER
  :rtype: float
  """
    assert dataset or refs
    start_time = time.time()
    seq_len_stats = {"refs": Stats(), "hyps": Stats()}
    seq_idx = options.startseq
    if options.endseq < 0:
        # Negative endseq means "no limit". Note: this mutates the caller's options.
        options.endseq = float("inf")
    wer = 1.0
    # Tags we have a hypothesis for but have not (yet) seen a reference for.
    remaining_hyp_seq_tags = set(hyps.keys())
    interactive = util.is_tty() and not log.verbose[5]
    collected = {"hyps": [], "refs": []}
    # Batch size for wer_compute.step(); currently one pair per step.
    max_num_collected = 1
    if dataset:
        dataset.init_seq_order(epoch=1)
    else:
        # No dataset: rebind refs from a dict to a list of (tag, ref) pairs
        # sorted by ref length, indexed by seq_idx below.
        refs = sorted(refs.items(), key=lambda item: len(item[1]))
    while True:
        if seq_idx > options.endseq:
            break
        if dataset:
            # Dataset mode: load the seq and extract the reference string.
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            seq_tag = dataset.get_tag(seq_idx)
            assert isinstance(seq_tag, str)
            ref = dataset.get_data(seq_idx, options.key)
            if isinstance(ref, numpy.ndarray):
                # Scalar object array wrapping the actual str/bytes.
                assert ref.shape == ()
                ref = ref.flatten()[0]  # get the entry itself (str or bytes)
            if isinstance(ref, bytes):
                ref = ref.decode("utf8")
            assert isinstance(ref, str)
            # Total seq count may be exact, estimated, or unknown.
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
        else:
            # Refs-list mode: iterate over the sorted (tag, ref) pairs.
            if seq_idx >= len(refs):
                break
            complete_frac = (seq_idx + 1) / float(len(refs))
            seq_tag, ref = refs[seq_idx]
            assert isinstance(seq_tag, str)
            assert isinstance(ref, str)
            num_seqs_s = str(len(refs))

        # Build a progress string including the running WER and an ETA.
        start_elapsed = time.time() - start_time
        progress_prefix = "%i/%s (WER %.02f%%)" % (seq_idx, num_seqs_s,
                                                   wer * 100)
        progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
        if complete_frac > 0:
            total_time_estimated = start_elapsed / complete_frac
            remaining_estimated = total_time_estimated - start_elapsed
            progress += " (%s)" % hms(remaining_estimated)

        # Raises KeyError if there is no hypothesis for this tag.
        remaining_hyp_seq_tags.remove(seq_tag)
        hyp = hyps[seq_tag]
        seq_len_stats["hyps"].collect([len(hyp)])
        seq_len_stats["refs"].collect([len(ref)])
        collected["hyps"].append(hyp)
        collected["refs"].append(ref)

        # Flush the collected batch through the WER computation.
        if len(collected["hyps"]) >= max_num_collected:
            wer = wer_compute.step(session, **collected)
            del collected["hyps"][:]
            del collected["refs"][:]

        if interactive:
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(
                progress_prefix, "seq tag %r, ref/hyp len %i/%i chars" %
                (seq_tag, len(ref), len(hyp)))
        seq_idx += 1
    # Flush any final partial batch.
    if len(collected["hyps"]) > 0:
        wer = wer_compute.step(session, **collected)
    print("Done. Num seqs %i. Total time %s." %
          (seq_idx, hms(time.time() - start_time)),
          file=log.v1)
    print("Remaining num hyp seqs %i." % (len(remaining_hyp_seq_tags), ),
          file=log.v1)
    if dataset:
        print("More seqs which we did not dumped: %s." %
              dataset.is_less_than_num_seqs(seq_idx),
              file=log.v1)
    for key in ["hyps", "refs"]:
        seq_len_stats[key].dump(stream_prefix="Seq-length %r %r " %
                                (key, options.key),
                                stream=log.v2)
    if options.expect_full:
        assert not remaining_hyp_seq_tags, "There are still remaining hypotheses."
    return wer
Code example #4
0
File: dump-dataset.py — Project: twistedmove/returnn
def dump_dataset(dataset, options):
    """
  Dump the dataset in the format selected by ``options.type`` (one of
  "numpy", "stdout", "print_tag", "dump_tag", "dump_seq_len", "print_shape",
  "plot", "interactive", "null"), or just count seqs if
  ``options.get_num_seqs`` is set. Also collects and prints per-key
  seq-length statistics and optional data statistics.

  :type dataset: Dataset.Dataset
  :param options: argparse.Namespace
  """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(epoch=options.epoch)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    if options.get_num_seqs:
        # Counting-only mode: report num_seqs, falling back to loading the
        # seqs one by one when the dataset cannot report a count directly.
        print("Get num seqs.")
        print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
        try:
            print("num_seqs: %r" % dataset.num_seqs)
        except Exception as exc:
            print("num_seqs exception %r, which is valid, so we count." % exc)
            seq_idx = 0
            if dataset.get_target_list():
                default_target = dataset.get_target_list()[0]
            else:
                default_target = None
            while dataset.is_less_than_num_seqs(seq_idx):
                dataset.load_seqs(seq_idx, seq_idx + 1)
                if seq_idx % 10000 == 0:
                    # Occasional progress output with a peek at the targets.
                    if default_target:
                        targets = dataset.get_targets(default_target, seq_idx)
                        postfix = " (targets = %r...)" % (targets[:10], )
                    else:
                        postfix = ""
                    print("%i ...%s" % (seq_idx, postfix))
                seq_idx += 1
            print("accumulated num seqs: %i" % seq_idx)
        print("Done.")
        return

    # Announce the dump mode and open output files / set print options as needed.
    dump_file = None
    if options.type == "numpy":
        print("Dump files: %r*%r" %
              (options.dump_prefix, options.dump_postfix),
              file=log.v3)
    elif options.type == "stdout":
        print("Dump to stdout", file=log.v3)
        if options.stdout_limit is not None:
            util.set_pretty_print_default_limit(options.stdout_limit)
            numpy.set_printoptions(
                threshold=sys.maxsize if options.stdout_limit ==
                float("inf") else int(options.stdout_limit))
        if options.stdout_as_bytes:
            util.set_pretty_print_as_bytes(options.stdout_as_bytes)
    elif options.type == "print_tag":
        print("Dump seq tag to stdout", file=log.v3)
    elif options.type == "dump_tag":
        dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
        print("Dump seq tag to file: %s" % (dump_file.name, ), file=log.v3)
    elif options.type == "dump_seq_len":
        dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
        print("Dump seq lens to file: %s" % (dump_file.name, ), file=log.v3)
        # The seq-lens file is written as a Python dict literal; closed with "}" below.
        dump_file.write("{\n")
    elif options.type == "print_shape":
        print("Dump shape to stdout", file=log.v3)
    elif options.type == "plot":
        print("Plot.", file=log.v3)
    elif options.type == "interactive":
        print("Interactive debug shell.", file=log.v3)
    elif options.type == "null":
        if options.dump_stats:
            print("No dump (except stats).")
        else:
            print("No dump.")
    else:
        raise Exception("unknown dump option type %r" % options.type)

    start_time = time.time()
    stats = Stats() if (options.stats or options.dump_stats) else None
    seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
    seq_idx = options.startseq
    if options.endseq < 0:
        # Negative endseq means "no limit". Note: mutates the caller's options.
        options.endseq = float("inf")
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
        dataset.load_seqs(seq_idx, seq_idx + 1)
        complete_frac = dataset.get_complete_frac(seq_idx)
        start_elapsed = time.time() - start_time
        # Total seq count may be exact, estimated, or unknown.
        try:
            num_seqs_s = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                num_seqs_s = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                num_seqs_s = "?"
        progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
        progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
        data = None
        if complete_frac > 0:
            # Linear extrapolation of the remaining time.
            total_time_estimated = start_elapsed / complete_frac
            remaining_estimated = total_time_estimated - start_elapsed
            progress += " (%s)" % hms(remaining_estimated)
        if options.type == "print_tag":
            print(
                "seq %s tag:" %
                (progress if log.verbose[2] else progress_prefix),
                dataset.get_tag(seq_idx))
        elif options.type == "dump_tag":
            print(
                "seq %s tag:" %
                (progress if log.verbose[2] else progress_prefix),
                dataset.get_tag(seq_idx))
            dump_file.write("%s\n" % dataset.get_tag(seq_idx))
        elif options.type == "dump_seq_len":
            seq_len = dataset.get_seq_length(seq_idx)[options.key]
            print(
                "seq %s tag:" %
                (progress if log.verbose[2] else progress_prefix),
                dataset.get_tag(seq_idx), "%r len:" % options.key, seq_len)
            dump_file.write("%r: %r,\n" % (dataset.get_tag(seq_idx), seq_len))
        else:
            # All remaining modes actually fetch the data of the main key.
            data = dataset.get_data(seq_idx, options.key)
            if options.type == "numpy":
                numpy.savetxt(
                    "%s%i.data%s" %
                    (options.dump_prefix, seq_idx, options.dump_postfix), data)
            elif options.type == "stdout":
                print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
                print("seq %s data:" % progress, pretty_print(data))
            elif options.type == "print_shape":
                print("seq %s data shape:" % progress, data.shape)
            elif options.type == "plot":
                plot(data)
            # Also dump every target key of the dataset.
            for target in dataset.get_target_list():
                targets = dataset.get_targets(target, seq_idx)
                if options.type == "numpy":
                    numpy.savetxt("%s%i.targets.%s%s" %
                                  (options.dump_prefix, seq_idx, target,
                                   options.dump_postfix),
                                  targets,
                                  fmt='%i')
                elif options.type == "stdout":
                    extra = ""
                    if target in dataset.labels and len(
                            dataset.labels[target]) > 1:
                        # Show a human-readable serialization next to the raw values.
                        assert dataset.can_serialize_data(target)
                        extra += " (%r)" % dataset.serialize_data(key=target,
                                                                  data=targets)
                    print("seq %i target %r: %s%s" %
                          (seq_idx, target, pretty_print(targets), extra))
                elif options.type == "print_shape":
                    print("seq %i target %r shape:" % (seq_idx, target),
                          targets.shape)
            if options.type == "interactive":
                from returnn.util.debug import debug_shell
                debug_shell(locals())
        # Collect per-key seq-length stats for every seq, regardless of mode.
        seq_len = dataset.get_seq_length(seq_idx)
        for key in dataset.get_data_keys():
            seq_len_stats[key].collect([seq_len[key]])
        if stats:
            stats.collect(data)
        if options.type == "null":
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

        seq_idx += 1

    print("Done. Total time %s. More seqs which we did not dumped: %s" %
          (hms_fraction(time.time() - start_time),
           dataset.is_less_than_num_seqs(seq_idx)),
          file=log.v2)
    for key in dataset.get_data_keys():
        seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key,
                                stream=log.v2)
    if stats:
        stats.dump(output_file_prefix=options.dump_stats,
                   stream_prefix="Data %r " % options.key,
                   stream=log.v1)
    if options.type == "dump_seq_len":
        # Close the dict literal opened above.
        dump_file.write("}\n")
    if dump_file:
        print("Dumped to file:", dump_file.name, file=log.v2)
        dump_file.close()
Code example #5
0
File: sprint.py — Project: ishine/returnn
def demo():
  """
  Demo / self-check tool: compare or benchmark a :class:`SprintCacheDataset`
  against the dataset configured under ``"train"`` in the given config file.
  Parses its own command-line args (``--config``, ``--sprint_cache_dataset``,
  ``--max_num_seqs``, ``--action``).
  """
  print("SprintDataset demo.")
  from argparse import ArgumentParser
  from returnn.util.basic import progress_bar_with_time
  from returnn.log import log
  from returnn.config import Config
  from returnn.datasets.basic import init_dataset
  arg_parser = ArgumentParser()
  arg_parser.add_argument("--config", help="config with ExternSprintDataset", required=True)
  arg_parser.add_argument("--sprint_cache_dataset", help="kwargs dict for SprintCacheDataset", required=True)
  arg_parser.add_argument("--max_num_seqs", default=sys.maxsize, type=int)
  arg_parser.add_argument("--action", default="compare", help="compare or benchmark")
  args = arg_parser.parse_args()
  log.initialize(verbosity=[4])
  # NOTE(review): eval() on a command-line argument executes arbitrary code.
  # Acceptable for a local demo tool, but never expose this to untrusted input.
  sprint_cache_dataset_kwargs = eval(args.sprint_cache_dataset)
  assert isinstance(sprint_cache_dataset_kwargs, dict)
  sprint_cache_dataset = SprintCacheDataset(**sprint_cache_dataset_kwargs)
  print("SprintCacheDataset: %r" % sprint_cache_dataset)
  config = Config()
  config.load_file(args.config)
  dataset = init_dataset(config.typed_value("train"))
  print("Dataset via config: %r" % dataset)
  # Both datasets must agree on input/output dimensions to be comparable.
  assert sprint_cache_dataset.num_inputs == dataset.num_inputs
  assert tuple(sprint_cache_dataset.num_outputs["classes"]) == tuple(dataset.num_outputs["classes"])
  sprint_cache_dataset.init_seq_order(epoch=1)

  if args.action == "compare":
    # Compare mode: for every seq of the config dataset, look up the same
    # seq by tag in the SprintCacheDataset and check features/targets match.
    print("Iterating through dataset...")
    seq_idx = 0
    dataset.init_seq_order(epoch=1)
    while seq_idx < args.max_num_seqs:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag = dataset.get_tag(seq_idx)
      assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
      dataset_seq = sprint_cache_dataset.get_dataset_seq_for_name(tag)
      data = dataset.get_data(seq_idx, "data")
      targets = dataset.get_data(seq_idx, "classes")
      assert data.shape == dataset_seq.features["data"].shape
      assert targets.shape == dataset_seq.features["classes"].shape
      assert numpy.allclose(data, dataset_seq.features["data"])
      assert numpy.allclose(targets, dataset_seq.features["classes"])
      seq_idx += 1
      # seq_idx was already incremented: reports progress incl. this seq.
      progress_bar_with_time(dataset.get_complete_frac(seq_idx))

    print("Finished through dataset. Num seqs: %i" % seq_idx)
    print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)

  elif args.action == "benchmark":
    # Benchmark mode: time a full pass over the config dataset, then a
    # tag-based pass over the SprintCacheDataset with the same seqs.
    print("Iterating through dataset...")
    start_time = time.time()
    seq_tags = []
    seq_idx = 0
    dataset.init_seq_order(epoch=1)
    while seq_idx < args.max_num_seqs:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag = dataset.get_tag(seq_idx)
      assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
      seq_tags.append(tag)
      dataset.get_data(seq_idx, "data")
      dataset.get_data(seq_idx, "classes")
      seq_idx += 1
      progress_bar_with_time(dataset.get_complete_frac(seq_idx))
    print("Finished through dataset. Num seqs: %i, time: %f" % (seq_idx, time.time() - start_time))
    print("SprintCacheDataset has num seqs: %i." % sprint_cache_dataset.num_seqs)
    # Stop background workers (e.g. ExternSprintDataset subprocesses) if possible.
    if hasattr(dataset, "exit_handler"):
      dataset.exit_handler()
    else:
      print("No way to stop any background tasks.")
    del dataset

    start_time = time.time()
    print("Iterating through SprintCacheDataset...")
    for i, tag in enumerate(seq_tags):
      sprint_cache_dataset.get_dataset_seq_for_name(tag)
      progress_bar_with_time(float(i) / len(seq_tags))
    print("Finished through SprintCacheDataset. time: %f" % (time.time() - start_time,))

  else:
    raise Exception("invalid action: %r" % args.action)
Code example #6
0
def analyze_dataset(options):
    """
  Iterate over the batches of the module-level ``dataset`` (batching
  parameters read from the module-level ``config``) and report batch/frame
  statistics: number of batches, seqs, padded vs. used frames per data key.
  Does not modify any data.

  :param options: argparse.Namespace
  """
    print("Epoch: %i" % options.epoch, file=log.v3)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    terminal_width, _ = util.terminal_size()
    # Show the interactive bar only at mid verbosity and on a real terminal.
    show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5])
                                    and terminal_width >= 0)

    start_time = time.time()
    num_seqs_stats = Stats()
    if options.endseq < 0:
        # Negative endseq means "no limit". Note: mutates the caller's options.
        options.endseq = float("inf")

    # Batching parameters, taken from the module-level config.
    recurrent = True
    used_data_keys = dataset.get_data_keys()
    batch_size = config.typed_value('batch_size', 1)
    max_seqs = config.int('max_seqs', -1)
    seq_drop = config.float('seq_drop', 0.0)
    max_seq_length = config.typed_value(
        'max_seq_length', None) or config.float('max_seq_length', 0)
    max_pad_size = config.typed_value("max_pad_size", None)

    batches = dataset.generate_batches(recurrent_net=recurrent,
                                       batch_size=batch_size,
                                       max_seqs=max_seqs,
                                       max_seq_length=max_seq_length,
                                       max_pad_size=max_pad_size,
                                       seq_drop=seq_drop,
                                       used_data_keys=used_data_keys)

    step = 0
    total_num_seqs = 0
    # Padded frames (batch shape) vs. frames actually covered by seqs.
    total_num_frames = NumbersDict()
    total_num_used_frames = NumbersDict()

    try:
        while batches.has_more():
            # See FeedDictDataProvider.
            batch, = batches.peek_next_n(1)
            assert isinstance(batch, Batch)
            if batch.start_seq > options.endseq:
                break
            dataset.load_seqs(batch.start_seq, batch.end_seq)
            complete_frac = batches.completed_frac()
            start_elapsed = time.time() - start_time
            # Total seq count may be exact, estimated, or unknown.
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
            progress_prefix = "%i/%s" % (batch.start_seq, num_seqs_s)
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            if complete_frac > 0:
                # Linear extrapolation of the remaining time.
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)

            # Padded batch size: longest seq length times number of seqs.
            batch_max_time = NumbersDict.max(
                [seq.frame_length for seq in batch.seqs]) * len(batch.seqs)
            batch_num_used_frames = sum(
                [seq.frame_length for seq in batch.seqs], NumbersDict())
            total_num_seqs += len(batch.seqs)
            num_seqs_stats.collect(numpy.array([len(batch.seqs)]))
            total_num_frames += batch_max_time
            total_num_used_frames += batch_num_used_frames

            print("%s, batch %i, num seqs %i, frames %s, used %s (%s)" %
                  (progress, step, len(
                      batch.seqs), batch_max_time, batch_num_used_frames,
                   batch_num_used_frames / batch_max_time),
                  file=log.v5)
            if show_interactive_process_bar:
                util.progress_bar_with_time(complete_frac,
                                            prefix=progress_prefix)

            step += 1
            batches.advance(1)

    finally:
        # Always report the summary, even if interrupted mid-iteration.
        print("Done. Total time %s. More seqs which we did not dumped: %s" %
              (hms(time.time() - start_time), batches.has_more()),
              file=log.v2)
        print("Dataset epoch %i, order %r." %
              (dataset.epoch, dataset.seq_ordering))
        print("Num batches (steps): %i" % step, file=log.v1)
        print("Num seqs: %i" % total_num_seqs, file=log.v1)
        num_seqs_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
        for key in used_data_keys:
            print("Data key %r:" % key, file=log.v1)
            print("  Num frames: %s" % total_num_frames[key], file=log.v1)
            print("  Num used frames: %s" % total_num_used_frames[key],
                  file=log.v1)
            print("  Fraction used frames: %s" %
                  (total_num_used_frames / total_num_frames)[key],
                  file=log.v1)
        dataset.finish_epoch()