def analyze_data(config):  # pylint: disable=redefined-outer-name
    """
    Analyze the configured dataset: accumulate target label counts (for
    log-priors), and a running per-dimension mean / std-dev estimate of the
    dense input data, then dump all three statistics to text files.

    :param Config config: reads the config keys ``analyze_dataset``, ``epoch``,
        ``statistics_save_prefix``, ``statistics_dtype``, ``target``, ``data_key``
    """
    dss = config.value('analyze_dataset', 'train')
    # train_data / dev_data / eval_data are module-level datasets set up elsewhere.
    ds = {"train": train_data, "dev": dev_data, "eval": eval_data}[dss]
    epoch = config.int('epoch', 1)
    print("Analyze dataset", dss, "epoch", epoch, file=log.v1)
    ds.init_seq_order(epoch=epoch)
    stat_prefix = config.value('statistics_save_prefix', 'statistics')
    dtype = config.value('statistics_dtype', 'float64')
    target = config.value('target', 'classes')
    data_key = config.value('data_key', 'data')
    assert ds.is_data_sparse(target), "need for prior calculation"
    assert not ds.is_data_sparse(data_key), "needed for mean/var estimation"
    from returnn.util.basic import inplace_increment, progress_bar_with_time, NumbersDict
    priors = numpy.zeros((ds.get_data_dim(target),), dtype=dtype)
    mean = numpy.zeros((ds.get_data_dim(data_key),), dtype=dtype)
    mean_sq = numpy.zeros((ds.get_data_dim(data_key),), dtype=dtype)
    total_targets_len = 0
    total_data_len = 0
    # Note: This is not stable! See :class:`Util.Stats` for a better alternative.
    seq_idx = 0
    while ds.is_less_than_num_seqs(seq_idx):
        progress_bar_with_time(ds.get_complete_frac(seq_idx))
        ds.load_seqs(seq_idx, seq_idx + 1)
        targets = ds.get_data(seq_idx, target)
        # Count label occurrences; targets is a sparse (int) label sequence.
        inplace_increment(priors, targets, 1)
        total_targets_len += targets.shape[0]
        data = ds.get_data(seq_idx, data_key)
        new_total_data_len = total_data_len + data.shape[0]
        f = float(total_data_len) / new_total_data_len
        # Incremental weighted mean update:
        #   mean_new = mean_old * f + mean(seq) * (1 - f), where f = old_len / new_len.
        # Bugfix: the per-sequence contribution must be the sequence *mean*, not the
        # sequence *sum* -- using numpy.sum here inflated the estimate by the
        # sequence length and produced wrong mean/std-dev statistics.
        mean = mean * f + numpy.mean(data, axis=0) * (1.0 - f)
        mean_sq = mean_sq * f + numpy.mean(data * data, axis=0) * (1.0 - f)
        total_data_len = new_total_data_len
        seq_idx += 1
    log_priors = numpy.log(priors)
    # Normalize counts to (log) relative frequencies over all target frames.
    log_priors -= numpy.log(NumbersDict(ds.get_num_timesteps())[target])
    # Biased estimator: sqrt(E[x^2] - E[x]^2). Numerically not stable (see note above).
    std_dev = numpy.sqrt(mean_sq - mean * mean)
    print("Finished. %i total target frames, %i total data frames" % (
        total_targets_len, total_data_len), file=log.v1)
    priors_fn = stat_prefix + ".log_priors.txt"
    mean_fn = stat_prefix + ".mean.txt"
    std_dev_fn = stat_prefix + ".std_dev.txt"
    print("Dump priors to", priors_fn, file=log.v1)
    numpy.savetxt(priors_fn, log_priors)
    print("Dump mean to", mean_fn, file=log.v1)
    numpy.savetxt(mean_fn, mean)
    print("Dump std dev to", std_dev_fn, file=log.v1)
    numpy.savetxt(std_dev_fn, std_dev)
    print("Done.", file=log.v1)
def get_raw_strings(dataset, options):
    """
    Iterate over the dataset and collect the raw string of every sequence
    from ``options.startseq`` up to ``options.endseq``.

    :param Dataset dataset:
    :param options: argparse.Namespace
    :return: list of (seq tag, string)
    :rtype: list[(str,str)]
    """
    collected = []
    t0 = time.time()
    len_stats = Stats()
    idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    # Interactive progress bar only on a TTY and without verbose per-seq logging.
    interactive = util.is_tty() and not log.verbose[5]
    print("Iterating over %r." % dataset, file=log.v2)
    while dataset.is_less_than_num_seqs(idx) and idx <= options.endseq:
        dataset.load_seqs(idx, idx + 1)
        frac_done = dataset.get_complete_frac(idx)
        elapsed = time.time() - t0
        # num_seqs may be unknown (NotImplementedError) or only estimated (None -> TypeError).
        try:
            total_s = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                total_s = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                total_s = "?"
        prefix = "%i/%s" % (idx, total_s)
        progress = "%s (%.02f%%)" % (prefix, frac_done * 100)
        if frac_done > 0:
            # Append a remaining-time estimate (note: only `prefix` is printed below).
            progress += " (%s)" % hms(elapsed / frac_done - elapsed)
        tag = dataset.get_tag(idx)
        assert isinstance(tag, str)
        raw = dataset.get_data(idx, options.key)
        if isinstance(raw, numpy.ndarray):
            assert raw.shape == () or (raw.ndim == 1 and raw.dtype == numpy.uint8)
            # Scalar array: take the entry itself (str or bytes); uint8 vector: raw bytes.
            raw = raw.flatten()[0] if raw.shape == () else raw.tobytes()
        if isinstance(raw, bytes):
            raw = raw.decode("utf8")
        assert isinstance(raw, str)
        len_stats.collect([len(raw)])
        collected.append((tag, raw))
        if interactive:
            util.progress_bar_with_time(frac_done, prefix=prefix)
        elif log.verbose[5]:
            print(prefix, "seq tag %r, ref len %i chars" % (tag, len(raw)))
        idx += 1
    print("Done. Num seqs %i. Total time %s." % (idx, hms(time.time() - t0)), file=log.v1)
    print("More seqs which we did not dumped: %s." % (dataset.is_less_than_num_seqs(idx), ), file=log.v1)
    len_stats.dump(stream_prefix="Seq-length %r " % (options.key, ), stream=log.v2)
    return collected
def calc_wer_on_dataset(dataset, refs, options, hyps):
    """
    Compute the word error rate of ``hyps`` against references, taken either
    from ``dataset`` (seq-order driven) or from the ``refs`` dict.
    Uses the module-level ``wer_compute`` graph and TF ``session``.

    :param Dataset|None dataset:
    :param dict[str,str]|None refs: seq tag -> ref string (words delimited by space)
    :param options: argparse.Namespace
    :param dict[str,str] hyps: seq tag -> hyp string (words delimited by space)
    :return: WER
    :rtype: float
    """
    assert dataset or refs
    start_time = time.time()
    seq_len_stats = {"refs": Stats(), "hyps": Stats()}
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    wer = 1.0
    # Track which hypotheses were never matched by a reference seq tag.
    remaining_hyp_seq_tags = set(hyps.keys())
    interactive = util.is_tty() and not log.verbose[5]
    # Pairs are accumulated and fed to wer_compute in mini-batches of this size.
    collected = {"hyps": [], "refs": []}
    max_num_collected = 1
    if dataset:
        dataset.init_seq_order(epoch=1)
    else:
        # Without a dataset, iterate the refs dict sorted by reference length.
        refs = sorted(refs.items(), key=lambda item: len(item[1]))
    while True:
        if seq_idx > options.endseq:
            break
        if dataset:
            # Reference comes from the dataset.
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            seq_tag = dataset.get_tag(seq_idx)
            assert isinstance(seq_tag, str)
            ref = dataset.get_data(seq_idx, options.key)
            if isinstance(ref, numpy.ndarray):
                assert ref.shape == ()
                ref = ref.flatten()[0]  # get the entry itself (str or bytes)
            if isinstance(ref, bytes):
                ref = ref.decode("utf8")
            assert isinstance(ref, str)
            # num_seqs may be unknown (NotImplementedError) or estimated-only (None).
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
        else:
            # Reference comes from the sorted (tag, ref) list.
            if seq_idx >= len(refs):
                break
            complete_frac = (seq_idx + 1) / float(len(refs))
            seq_tag, ref = refs[seq_idx]
            assert isinstance(seq_tag, str)
            assert isinstance(ref, str)
            num_seqs_s = str(len(refs))
        start_elapsed = time.time() - start_time
        progress_prefix = "%i/%s (WER %.02f%%)" % (seq_idx, num_seqs_s, wer * 100)
        progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
        if complete_frac > 0:
            # Remaining-time estimate (note: only progress_prefix is printed below).
            total_time_estimated = start_elapsed / complete_frac
            remaining_estimated = total_time_estimated - start_elapsed
            progress += " (%s)" % hms(remaining_estimated)
        # KeyError here means a reference seq tag without a hypothesis.
        remaining_hyp_seq_tags.remove(seq_tag)
        hyp = hyps[seq_tag]
        seq_len_stats["hyps"].collect([len(hyp)])
        seq_len_stats["refs"].collect([len(ref)])
        collected["hyps"].append(hyp)
        collected["refs"].append(ref)
        if len(collected["hyps"]) >= max_num_collected:
            # Feed the accumulated batch into the WER computation graph; `wer`
            # is the running (accumulated) WER so far.
            wer = wer_compute.step(session, **collected)
            del collected["hyps"][:]
            del collected["refs"][:]
        if interactive:
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(
                progress_prefix,
                "seq tag %r, ref/hyp len %i/%i chars" % (seq_tag, len(ref), len(hyp)))
        seq_idx += 1
    # Flush any not-yet-submitted pairs.
    if len(collected["hyps"]) > 0:
        wer = wer_compute.step(session, **collected)
    print("Done. Num seqs %i. Total time %s." % (seq_idx, hms(time.time() - start_time)), file=log.v1)
    print("Remaining num hyp seqs %i." % (len(remaining_hyp_seq_tags), ), file=log.v1)
    if dataset:
        print("More seqs which we did not dumped: %s." % dataset.is_less_than_num_seqs(seq_idx), file=log.v1)
    for key in ["hyps", "refs"]:
        seq_len_stats[key].dump(stream_prefix="Seq-length %r %r " % (key, options.key), stream=log.v2)
    if options.expect_full:
        assert not remaining_hyp_seq_tags, "There are still remaining hypotheses."
    return wer
def dump_dataset(dataset, options):
    """
    Dump the dataset in the format selected by ``options.type``
    (one of: "numpy", "stdout", "print_tag", "dump_tag", "dump_seq_len",
    "print_shape", "plot", "interactive", "null"), and always collect
    per-key sequence-length statistics.

    :type dataset: Dataset.Dataset
    :param options: argparse.Namespace
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(epoch=options.epoch)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()
    if options.get_num_seqs:
        # Only report/count the number of seqs, then return without dumping.
        print("Get num seqs.")
        print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
        try:
            print("num_seqs: %r" % dataset.num_seqs)
        except Exception as exc:
            # num_seqs is not always available; fall back to counting by iteration.
            print("num_seqs exception %r, which is valid, so we count." % exc)
            seq_idx = 0
            if dataset.get_target_list():
                default_target = dataset.get_target_list()[0]
            else:
                default_target = None
            while dataset.is_less_than_num_seqs(seq_idx):
                dataset.load_seqs(seq_idx, seq_idx + 1)
                if seq_idx % 10000 == 0:
                    # Periodic progress output while counting.
                    if default_target:
                        targets = dataset.get_targets(default_target, seq_idx)
                        postfix = " (targets = %r...)" % (targets[:10], )
                    else:
                        postfix = ""
                    print("%i ...%s" % (seq_idx, postfix))
                seq_idx += 1
            print("accumulated num seqs: %i" % seq_idx)
        print("Done.")
        return
    # --- Per-type setup (and announce what will be dumped). ---
    dump_file = None
    if options.type == "numpy":
        print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
    elif options.type == "stdout":
        print("Dump to stdout", file=log.v3)
        if options.stdout_limit is not None:
            util.set_pretty_print_default_limit(options.stdout_limit)
            numpy.set_printoptions(
                threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
        if options.stdout_as_bytes:
            util.set_pretty_print_as_bytes(options.stdout_as_bytes)
    elif options.type == "print_tag":
        print("Dump seq tag to stdout", file=log.v3)
    elif options.type == "dump_tag":
        dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
        print("Dump seq tag to file: %s" % (dump_file.name, ), file=log.v3)
    elif options.type == "dump_seq_len":
        dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
        print("Dump seq lens to file: %s" % (dump_file.name, ), file=log.v3)
        # The seq-lens file is written as a Python dict literal; closed with "}" below.
        dump_file.write("{\n")
    elif options.type == "print_shape":
        print("Dump shape to stdout", file=log.v3)
    elif options.type == "plot":
        print("Plot.", file=log.v3)
    elif options.type == "interactive":
        print("Interactive debug shell.", file=log.v3)
    elif options.type == "null":
        if options.dump_stats:
            print("No dump (except stats).")
        else:
            print("No dump.")
    else:
        raise Exception("unknown dump option type %r" % options.type)
    # --- Main iteration over sequences. ---
    start_time = time.time()
    stats = Stats() if (options.stats or options.dump_stats) else None
    seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
        dataset.load_seqs(seq_idx, seq_idx + 1)
        complete_frac = dataset.get_complete_frac(seq_idx)
        start_elapsed = time.time() - start_time
        # num_seqs may be unknown (NotImplementedError) or estimated-only (None).
        try:
            num_seqs_s = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                num_seqs_s = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                num_seqs_s = "?"
        progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
        progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
        data = None
        if complete_frac > 0:
            # Append a remaining-time estimate to the progress string.
            total_time_estimated = start_elapsed / complete_frac
            remaining_estimated = total_time_estimated - start_elapsed
            progress += " (%s)" % hms(remaining_estimated)
        if options.type == "print_tag":
            print(
                "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                dataset.get_tag(seq_idx))
        elif options.type == "dump_tag":
            print(
                "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                dataset.get_tag(seq_idx))
            dump_file.write("%s\n" % dataset.get_tag(seq_idx))
        elif options.type == "dump_seq_len":
            seq_len = dataset.get_seq_length(seq_idx)[options.key]
            print(
                "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                dataset.get_tag(seq_idx), "%r len:" % options.key, seq_len)
            dump_file.write("%r: %r,\n" % (dataset.get_tag(seq_idx), seq_len))
        else:
            # All remaining types need the actual data of the selected key.
            data = dataset.get_data(seq_idx, options.key)
            if options.type == "numpy":
                numpy.savetxt(
                    "%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
            elif options.type == "stdout":
                print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
                print("seq %s data:" % progress, pretty_print(data))
            elif options.type == "print_shape":
                print("seq %s data shape:" % progress, data.shape)
            elif options.type == "plot":
                plot(data)
            # Additionally dump every target stream of the dataset.
            for target in dataset.get_target_list():
                targets = dataset.get_targets(target, seq_idx)
                if options.type == "numpy":
                    numpy.savetxt("%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix), targets, fmt='%i')
                elif options.type == "stdout":
                    extra = ""
                    if target in dataset.labels and len(dataset.labels[target]) > 1:
                        # Also show the human-readable (serialized) target labels.
                        assert dataset.can_serialize_data(target)
                        extra += " (%r)" % dataset.serialize_data(key=target, data=targets)
                    print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
                elif options.type == "print_shape":
                    print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
            if options.type == "interactive":
                # Drop into a debug shell with all locals of this iteration.
                from returnn.util.debug import debug_shell
                debug_shell(locals())
        # Always collect seq-length stats for every data key.
        seq_len = dataset.get_seq_length(seq_idx)
        for key in dataset.get_data_keys():
            seq_len_stats[key].collect([seq_len[key]])
        if stats:
            stats.collect(data)
        if options.type == "null":
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        seq_idx += 1
    print("Done. Total time %s. More seqs which we did not dumped: %s" % (
        hms_fraction(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v2)
    for key in dataset.get_data_keys():
        seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
    if stats:
        stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)
    if options.type == "dump_seq_len":
        # Close the dict literal opened with "{" above.
        dump_file.write("}\n")
    if dump_file:
        print("Dumped to file:", dump_file.name, file=log.v2)
        dump_file.close()
def demo():
    """
    Command-line demo: load a dataset from a RETURNN config and either compare
    it sequence-by-sequence against a :class:`SprintCacheDataset` ("compare")
    or time iteration over both ("benchmark").
    """
    print("SprintDataset demo.")
    from argparse import ArgumentParser
    from returnn.util.basic import progress_bar_with_time
    from returnn.log import log
    from returnn.config import Config
    from returnn.datasets.basic import init_dataset
    parser = ArgumentParser()
    parser.add_argument("--config", help="config with ExternSprintDataset", required=True)
    parser.add_argument("--sprint_cache_dataset", help="kwargs dict for SprintCacheDataset", required=True)
    parser.add_argument("--max_num_seqs", default=sys.maxsize, type=int)
    parser.add_argument("--action", default="compare", help="compare or benchmark")
    opts = parser.parse_args()
    log.initialize(verbosity=[4])
    # NOTE(review): eval() on a command-line string -- acceptable for a local
    # demo tool, but never expose this to untrusted input.
    cache_ds_kwargs = eval(opts.sprint_cache_dataset)
    assert isinstance(cache_ds_kwargs, dict)
    cache_ds = SprintCacheDataset(**cache_ds_kwargs)
    print("SprintCacheDataset: %r" % cache_ds)
    config = Config()
    config.load_file(opts.config)
    dataset = init_dataset(config.typed_value("train"))
    print("Dataset via config: %r" % dataset)
    # Both datasets must describe the same input/output dims to be comparable.
    assert cache_ds.num_inputs == dataset.num_inputs
    assert tuple(cache_ds.num_outputs["classes"]) == tuple(dataset.num_outputs["classes"])
    cache_ds.init_seq_order(epoch=1)
    if opts.action == "compare":
        # Check that both datasets yield identical features/targets per seq tag.
        print("Iterating through dataset...")
        idx = 0
        dataset.init_seq_order(epoch=1)
        while idx < opts.max_num_seqs and dataset.is_less_than_num_seqs(idx):
            dataset.load_seqs(idx, idx + 1)
            tag = dataset.get_tag(idx)
            assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
            cached_seq = cache_ds.get_dataset_seq_for_name(tag)
            feats = dataset.get_data(idx, "data")
            labels = dataset.get_data(idx, "classes")
            assert feats.shape == cached_seq.features["data"].shape
            assert labels.shape == cached_seq.features["classes"].shape
            assert numpy.allclose(feats, cached_seq.features["data"])
            assert numpy.allclose(labels, cached_seq.features["classes"])
            idx += 1
            progress_bar_with_time(dataset.get_complete_frac(idx))
        print("Finished through dataset. Num seqs: %i" % idx)
        print("SprintCacheDataset has num seqs: %i." % cache_ds.num_seqs)
    elif opts.action == "benchmark":
        # Time a full pass over the config dataset, then over the cache dataset.
        print("Iterating through dataset...")
        start_time = time.time()
        seq_tags = []
        idx = 0
        dataset.init_seq_order(epoch=1)
        while idx < opts.max_num_seqs and dataset.is_less_than_num_seqs(idx):
            dataset.load_seqs(idx, idx + 1)
            tag = dataset.get_tag(idx)
            assert not tag.startswith("seq-"), "dataset does not provide tag-names for seqs"
            seq_tags.append(tag)
            dataset.get_data(idx, "data")
            dataset.get_data(idx, "classes")
            idx += 1
            progress_bar_with_time(dataset.get_complete_frac(idx))
        print("Finished through dataset. Num seqs: %i, time: %f" % (idx, time.time() - start_time))
        print("SprintCacheDataset has num seqs: %i." % cache_ds.num_seqs)
        # Stop background workers of the config dataset before the second pass.
        if hasattr(dataset, "exit_handler"):
            dataset.exit_handler()
        else:
            print("No way to stop any background tasks.")
        del dataset
        start_time = time.time()
        print("Iterating through SprintCacheDataset...")
        for i, tag in enumerate(seq_tags):
            cache_ds.get_dataset_seq_for_name(tag)
            progress_bar_with_time(float(i) / len(seq_tags))
        print("Finished through SprintCacheDataset. time: %f" % (time.time() - start_time,))
    else:
        raise Exception("invalid action: %r" % opts.action)
def analyze_dataset(options):
    """
    Iterate over the module-level ``dataset`` in batches (built with the
    module-level ``config`` batching options) and report batch/frame
    statistics, including the fraction of padding per data key.

    :param options: argparse.Namespace
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()
    terminal_width, _ = util.terminal_size()
    # Show the bar only with verbosity 3 (but not 5) on a real terminal.
    show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)
    start_time = time.time()
    num_seqs_stats = Stats()
    if options.endseq < 0:
        options.endseq = float("inf")
    recurrent = True
    used_data_keys = dataset.get_data_keys()
    # Batching options come from the module-level config, mirroring training.
    batch_size = config.typed_value('batch_size', 1)
    max_seqs = config.int('max_seqs', -1)
    seq_drop = config.float('seq_drop', 0.0)
    max_seq_length = config.typed_value(
        'max_seq_length', None) or config.float('max_seq_length', 0)
    max_pad_size = config.typed_value("max_pad_size", None)
    batches = dataset.generate_batches(
        recurrent_net=recurrent,
        batch_size=batch_size,
        max_seqs=max_seqs,
        max_seq_length=max_seq_length,
        max_pad_size=max_pad_size,
        seq_drop=seq_drop,
        used_data_keys=used_data_keys)
    step = 0
    total_num_seqs = 0
    total_num_frames = NumbersDict()
    total_num_used_frames = NumbersDict()
    try:
        while batches.has_more():
            # See FeedDictDataProvider.
            batch, = batches.peek_next_n(1)
            assert isinstance(batch, Batch)
            if batch.start_seq > options.endseq:
                break
            dataset.load_seqs(batch.start_seq, batch.end_seq)
            complete_frac = batches.completed_frac()
            start_elapsed = time.time() - start_time
            # num_seqs may be unknown (NotImplementedError) or estimated-only (None).
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
            progress_prefix = "%i/%s" % (batch.start_seq, num_seqs_s)
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            if complete_frac > 0:
                # Append a remaining-time estimate to the progress string.
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)
            # Padded size of the batch: longest seq length times number of seqs.
            batch_max_time = NumbersDict.max(
                [seq.frame_length for seq in batch.seqs]) * len(batch.seqs)
            # Actually used (unpadded) frames.
            batch_num_used_frames = sum(
                [seq.frame_length for seq in batch.seqs], NumbersDict())
            total_num_seqs += len(batch.seqs)
            num_seqs_stats.collect(numpy.array([len(batch.seqs)]))
            total_num_frames += batch_max_time
            total_num_used_frames += batch_num_used_frames
            print("%s, batch %i, num seqs %i, frames %s, used %s (%s)" % (
                progress, step, len(
                    batch.seqs), batch_max_time, batch_num_used_frames,
                batch_num_used_frames / batch_max_time), file=log.v5)
            if show_interactive_process_bar:
                util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
            step += 1
            batches.advance(1)
    finally:
        # Always report totals, even when interrupted.
        print("Done. Total time %s. More seqs which we did not dumped: %s" % (
            hms(time.time() - start_time), batches.has_more()), file=log.v2)
        print("Dataset epoch %i, order %r." % (dataset.epoch, dataset.seq_ordering))
        print("Num batches (steps): %i" % step, file=log.v1)
        print("Num seqs: %i" % total_num_seqs, file=log.v1)
        num_seqs_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
        for key in used_data_keys:
            print("Data key %r:" % key, file=log.v1)
            print(" Num frames: %s" % total_num_frames[key], file=log.v1)
            print(" Num used frames: %s" % total_num_used_frames[key], file=log.v1)
            print(" Fraction used frames: %s" % (total_num_used_frames / total_num_frames)[key], file=log.v1)
        dataset.finish_epoch()