def dump(options):
    """
    Forward the dataset through the network and collect statistics over the output
    of the layer ``options.layer`` (each seq cut to its length), then dump them.

    Exits the process with status 1 if the forwarding did not finalize.

    :param options: argparse.Namespace; uses ``epoch``, ``layer`` and ``dump_stats``
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(options.epoch)

    collector = Stats()
    layer_out = engine.network.get_layer(options.layer).output.copy_as_batch_major()

    def _extra_fetches_cb(inputs, **kwargs):
        """
        Feed every seq of the fetched batch, cut to its length, into the stats collector.

        :param numpy.ndarray inputs: batch-major output of the layer
        :param kwargs: "seq_len_%i" -> seq lengths, for every axis in ``size_placeholder``
        """
        num_seqs = inputs.shape[0]
        lens_per_axis = {}
        for axis in layer_out.size_placeholder.keys():
            lens_per_axis[axis] = kwargs["seq_len_%i" % axis]
        assert all(len(lens) == num_seqs for lens in lens_per_axis.values())
        assert set(lens_per_axis.keys()) == {0}  # not implemented otherwise
        for b in range(num_seqs):
            collector.collect(inputs[b, :lens_per_axis[0][b]])

    # Fetch the layer output plus one seq-length tensor per dynamic axis.
    extra_fetches = {'inputs': layer_out.placeholder}
    extra_fetches.update(
        ("seq_len_%i" % axis, size) for axis, size in layer_out.size_placeholder.items())

    batch_gen = dataset.generate_batches(
        recurrent_net=True,  # Want seq lengths
        batch_size=config.typed_value('batch_size', 1),
        max_seqs=config.int('max_seqs', -1),
        used_data_keys=engine.network.get_used_data_keys())
    forwarder = Runner(
        engine=engine,
        dataset=dataset,
        batches=batch_gen,
        train=False,
        eval=False,
        extra_fetches=extra_fetches,
        extra_fetches_callback=_extra_fetches_cb)
    forwarder.run(report_prefix="forward")
    if not forwarder.finalized:
        print("Error happened. Exit now.")
        sys.exit(1)

    collector.dump(output_file_prefix=options.dump_stats, stream_prefix="Layer %r " % options.layer)
def get_raw_strings(dataset, options):
    """
    Iterate over the dataset and collect the raw string of data key ``options.key`` for each seq.

    Seqs with index in [``options.startseq``, ``options.endseq``] are read.
    ``endseq < 0`` means "until the end"; in that case ``options.endseq`` is overwritten with ``inf``.

    :param Dataset dataset:
    :param options: argparse.Namespace; uses ``startseq``, ``endseq`` and ``key``
    :return: list of (seq tag, string)
    :rtype: list[(str,str)]
    """
    refs = []
    start_time = time.time()
    seq_len_stats = Stats()
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    interactive = util.is_tty() and not log.verbose[5]
    print("Iterating over %r." % dataset, file=log.v2)
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
        dataset.load_seqs(seq_idx, seq_idx + 1)
        complete_frac = dataset.get_complete_frac(seq_idx)
        try:
            num_seqs_s = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                num_seqs_s = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                num_seqs_s = "?"
        progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
        seq_tag = dataset.get_tag(seq_idx)
        assert isinstance(seq_tag, str)
        ref = dataset.get_data(seq_idx, options.key)
        if isinstance(ref, numpy.ndarray):
            # Either a scalar (object) array holding str/bytes, or a 1D uint8 array of raw bytes.
            assert ref.shape == () or (ref.ndim == 1 and ref.dtype == numpy.uint8)
            if ref.shape == ():
                ref = ref.flatten()[0]  # get the entry itself (str or bytes)
            else:
                ref = ref.tobytes()
        if isinstance(ref, bytes):
            ref = ref.decode("utf8")
        assert isinstance(ref, str)
        seq_len_stats.collect([len(ref)])
        refs.append((seq_tag, ref))
        if interactive:
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(progress_prefix, "seq tag %r, ref len %i chars" % (seq_tag, len(ref)))
        seq_idx += 1
    # Report the number of collected seqs, not the end index (they differ when startseq > 0).
    print("Done. Num seqs %i. Total time %s." % (len(refs), hms(time.time() - start_time)), file=log.v1)
    print("More seqs which we did not dumped: %s." % (dataset.is_less_than_num_seqs(seq_idx), ), file=log.v1)
    seq_len_stats.dump(stream_prefix="Seq-length %r " % (options.key, ), stream=log.v2)
    return refs
def main(argv):
    """
    Main entry.

    Loads a RETURNN config (or the newest ``train-*/*.config`` of a model dir), forwards a dataset
    through the network, and dumps the outputs of some sub layers of the rec layer
    (by default ``att_weights``) as npy, png or hdf.

    :param list[str] argv: command line; ``argv[0]`` is the program name
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir")
    argparser.add_argument("--epoch", type=int)
    argparser.add_argument(
        '--data', default="train",
        help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', help="for npy or png")
    argparser.add_argument("--output_file", help="hdf")
    argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
    # With action="append", argparse appends user values to a non-empty default instead of replacing it,
    # so a default of ["att_weights"] would always be grabbed. The default is applied after parsing instead.
    argparser.add_argument("--layers", default=None, action="append", help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
    argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
    argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs")
    argparser.add_argument("--output_format", default="npy", help="npy, png or hdf")
    argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    argparser.add_argument("--reset_partition_epoch", type=int, default=1)
    argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
    argparser.add_argument("--reset_epoch_wise_filter", default=None)
    args = argparser.parse_args(argv[1:])
    if not args.layers:
        args.layers = ["att_weights"]  # keep args.layers consistent for code reading it later
    layers = args.layers
    assert isinstance(layers, list)

    config_fn = args.config_file
    explicit_model_dir = None
    if os.path.isdir(config_fn):
        # Assume we gave a model dir.
        explicit_model_dir = config_fn
        train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
        train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
        assert train_log_dir_configs
        config_fn = train_log_dir_configs[-1]
        print("Using this config via model dir:", config_fn)
    else:
        assert os.path.isfile(config_fn)
    model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=config_fn, args=args)
    if explicit_model_dir:
        config.set(
            "model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', ''))))
    print("Model file prefix:", config.value('model', ''))

    if args.do_search:
        raise NotImplementedError
    # NOTE(security): eval of a command line argument; only use with trusted input.
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    assert args.output_format in ["npy", "png", "hdf"]
    if args.output_format in ["npy", "png"]:
        assert args.dump_dir
        if not os.path.exists(args.dump_dir):
            os.makedirs(args.dump_dir)
    plt = ticker = None
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker

    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        # NOTE(security): eval of a command line argument; only use with trusted input.
        extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    init_net(args, layers)
    network = rnn.engine.network

    hdf_writer = None
    if args.output_format == "hdf":
        assert args.output_file
        assert len(layers) == 1
        sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
        from returnn.datasets.hdf import SimpleHDFWriter
        hdf_writer = SimpleHDFWriter(
            filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim)

    extra_fetches = {
        "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx": network.get_extern_data("seq_idx"),
        "seq_tag": network.get_extern_data("seq_tag"),
        "target_data": network.get_extern_data(network.extern_data.default_input),
        "target_classes": network.get_extern_data(network.extern_data.default_target),
    }
    for layer in layers:
        sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, layer))
        extra_fetches["rec_%s" % layer] = sub_layer.output.get_placeholder_as_batch_major()
    dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)  # use always epoch 1, such that we have same seqs
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        max_total_num_seqs=args.num_seqs,
        used_data_keys=network.used_data_keys)

    stats = {layer: Stats() for layer in layers}

    # (**dict[str,numpy.ndarray|str|list[numpy.ndarray|str])->None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
        """
        :param list[int] seq_idx: len is n_batch
        :param list[str] seq_tag: len is n_batch
        :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
        :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
        :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
        :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
        :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
        :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
        """
        n_batch = len(seq_idx)
        for i in range(n_batch):
            # noinspection PyShadowingNames
            for layer in layers:
                att_weights = kwargs["rec_%s" % layer][i]
                stats[layer].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            for i in range(n_batch):
                data[i] = {
                    'tag': seq_tag[i],
                    'data': target_data[i],
                    'classes': target_classes[i],
                    'output': output[i],
                    'output_len': output_len[i],
                    'encoder_len': encoder_len[i],
                }
                # noinspection PyShadowingNames
                for layer in [("rec_%s" % layer) for layer in layers]:
                    assert layer in kwargs
                    out = kwargs[layer][i]
                    assert out.ndim >= 2
                    assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
                    data[i][layer] = out[:output_len[i], :encoder_len[i]]
            fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
            np.save(fname, data)
        elif args.output_format == "png":
            for i in range(n_batch):
                # noinspection PyShadowingNames
                for layer in layers:
                    extra_postfix = ""
                    if args.dropout is not None:
                        extra_postfix += "_dropout%.2f" % args.dropout
                    elif args.train_flag:
                        extra_postfix += "_train"
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], layer, extra_postfix)
                    att_weights = kwargs["rec_%s" % layer][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" % (
                        seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target, target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        elif args.output_format == "hdf":
            assert len(layers) == 1
            att_weights = kwargs["rec_%s" % layers[0]]
            hdf_writer.insert_batch(
                inputs=att_weights, seq_len={0: output_len, 1: encoder_len}, seq_tag=seq_tag)
        else:
            raise Exception("output format %r" % args.output_format)

    try:
        runner = Runner(
            engine=rnn.engine, dataset=dataset, batches=dataset_batch,
            train=False, train_flag=bool(args.dropout) or args.train_flag,
            extra_fetches=extra_fetches,
            extra_fetches_callback=fetch_callback)
        runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
        for layer in layers:
            stats[layer].dump(stream_prefix="Layer %r " % layer)
        if not runner.finalized:
            print("Some error occured, not finalized.")
            sys.exit(1)
    finally:
        # Also close the HDF file on errors (incl. the sys.exit above), so it is not left unclosed.
        if hdf_writer:
            hdf_writer.close()
    rnn.finalize()
def calc_wer_on_dataset(dataset, refs, options, hyps):
    """
    Compute the WER of the hypotheses ``hyps`` against references taken either from ``dataset``
    (data key ``options.key``) or, if no dataset is given, from ``refs``.

    Only seqs with index in [``options.startseq``, ``options.endseq``] are scored.
    ``endseq < 0`` means "until the end"; in that case ``options.endseq`` is overwritten with ``inf``.
    In the ``refs`` case, the indices refer to the refs sorted by reference string length.

    :param Dataset|None dataset:
    :param dict[str,str]|None refs: seq tag -> ref string (words delimited by space)
    :param options: argparse.Namespace; uses ``startseq``, ``endseq``, ``key``, ``expect_full``
    :param dict[str,str] hyps: seq tag -> hyp string (words delimited by space)
    :return: WER
    :rtype: float
    :raises KeyError: if a reference seq tag has no remaining hypothesis in ``hyps``
    """
    assert dataset or refs
    start_time = time.time()
    seq_len_stats = {"refs": Stats(), "hyps": Stats()}
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    wer = 1.0
    remaining_hyp_seq_tags = set(hyps.keys())
    interactive = util.is_tty() and not log.verbose[5]
    collected = {"hyps": [], "refs": []}
    max_num_collected = 1  # number of (hyp, ref) pairs accumulated before each wer_compute.step call
    if dataset:
        dataset.init_seq_order(epoch=1)
    else:
        # From here on, refs is a list of (seq tag, ref string), sorted by ref length.
        refs = sorted(refs.items(), key=lambda item: len(item[1]))
    while True:
        if seq_idx > options.endseq:
            break
        if dataset:
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            seq_tag = dataset.get_tag(seq_idx)
            assert isinstance(seq_tag, str)
            ref = dataset.get_data(seq_idx, options.key)
            if isinstance(ref, numpy.ndarray):
                assert ref.shape == ()
                ref = ref.flatten()[0]  # get the entry itself (str or bytes)
            if isinstance(ref, bytes):
                ref = ref.decode("utf8")
            assert isinstance(ref, str)
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
        else:
            if seq_idx >= len(refs):
                break
            complete_frac = (seq_idx + 1) / float(len(refs))
            seq_tag, ref = refs[seq_idx]
            assert isinstance(seq_tag, str)
            assert isinstance(ref, str)
            num_seqs_s = str(len(refs))

        progress_prefix = "%i/%s (WER %.02f%%)" % (seq_idx, num_seqs_s, wer * 100)

        if seq_tag not in remaining_hyp_seq_tags:
            raise KeyError(
                "Seq tag %r: no hypothesis available (missing in hyps, or seq tag seen twice)." % (seq_tag,))
        remaining_hyp_seq_tags.remove(seq_tag)
        hyp = hyps[seq_tag]
        seq_len_stats["hyps"].collect([len(hyp)])
        seq_len_stats["refs"].collect([len(ref)])
        collected["hyps"].append(hyp)
        collected["refs"].append(ref)

        if len(collected["hyps"]) >= max_num_collected:
            wer = wer_compute.step(session, **collected)
            del collected["hyps"][:]
            del collected["refs"][:]

        if interactive:
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(progress_prefix, "seq tag %r, ref/hyp len %i/%i chars" % (seq_tag, len(ref), len(hyp)))
        seq_idx += 1
    if len(collected["hyps"]) > 0:
        wer = wer_compute.step(session, **collected)
    # Report the number of scored seqs, not the end index (they differ when startseq > 0).
    print("Done. Num seqs %i. Total time %s." % (
        seq_idx - options.startseq, hms(time.time() - start_time)), file=log.v1)
    print("Remaining num hyp seqs %i." % (len(remaining_hyp_seq_tags), ), file=log.v1)
    if dataset:
        print("More seqs which we did not dumped: %s." % dataset.is_less_than_num_seqs(seq_idx), file=log.v1)
    for key in ["hyps", "refs"]:
        seq_len_stats[key].dump(stream_prefix="Seq-length %r %r " % (key, options.key), stream=log.v2)
    if options.expect_full:
        assert not remaining_hyp_seq_tags, "There are still remaining hypotheses."
    return wer
def dump_dataset(dataset, options):
    """
    Iterate over the dataset and dump or print (parts of) each seq, depending on ``options.type``
    (numpy, stdout, print_tag, dump_tag, dump_seq_len, print_shape, plot, interactive, null).

    With ``options.get_num_seqs``, only the number of seqs is printed (counted by iterating
    if ``dataset.num_seqs`` is not available), and nothing is dumped.

    A dump file opened for "dump_tag" / "dump_seq_len" is closed even if dumping fails.

    :type dataset: Dataset.Dataset
    :param options: argparse.Namespace
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(epoch=options.epoch)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    if options.get_num_seqs:
        print("Get num seqs.")
        print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
        try:
            print("num_seqs: %r" % dataset.num_seqs)
        except Exception as exc:
            print("num_seqs exception %r, which is valid, so we count." % exc)
            seq_idx = 0
            if dataset.get_target_list():
                default_target = dataset.get_target_list()[0]
            else:
                default_target = None
            while dataset.is_less_than_num_seqs(seq_idx):
                dataset.load_seqs(seq_idx, seq_idx + 1)
                if seq_idx % 10000 == 0:
                    if default_target:
                        targets = dataset.get_targets(default_target, seq_idx)
                        postfix = " (targets = %r...)" % (targets[:10], )
                    else:
                        postfix = ""
                    print("%i ...%s" % (seq_idx, postfix))
                seq_idx += 1
            print("accumulated num seqs: %i" % seq_idx)
        print("Done.")
        return

    dump_file = None
    if options.type == "numpy":
        print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
    elif options.type == "stdout":
        print("Dump to stdout", file=log.v3)
        if options.stdout_limit is not None:
            util.set_pretty_print_default_limit(options.stdout_limit)
            numpy.set_printoptions(
                threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
        if options.stdout_as_bytes:
            util.set_pretty_print_as_bytes(options.stdout_as_bytes)
    elif options.type == "print_tag":
        print("Dump seq tag to stdout", file=log.v3)
    elif options.type == "dump_tag":
        dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
        print("Dump seq tag to file: %s" % (dump_file.name, ), file=log.v3)
    elif options.type == "dump_seq_len":
        dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
        print("Dump seq lens to file: %s" % (dump_file.name, ), file=log.v3)
        dump_file.write("{\n")
    elif options.type == "print_shape":
        print("Dump shape to stdout", file=log.v3)
    elif options.type == "plot":
        print("Plot.", file=log.v3)
    elif options.type == "interactive":
        print("Interactive debug shell.", file=log.v3)
    elif options.type == "null":
        if options.dump_stats:
            print("No dump (except stats).")
        else:
            print("No dump.")
    else:
        raise Exception("unknown dump option type %r" % options.type)

    try:
        start_time = time.time()
        stats = Stats() if (options.stats or options.dump_stats) else None
        seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
        seq_idx = options.startseq
        if options.endseq < 0:
            options.endseq = float("inf")
        while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            start_elapsed = time.time() - start_time
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
            progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            data = None
            if complete_frac > 0:
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)
            if options.type == "print_tag":
                print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
            elif options.type == "dump_tag":
                print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
                dump_file.write("%s\n" % dataset.get_tag(seq_idx))
            elif options.type == "dump_seq_len":
                seq_len = dataset.get_seq_length(seq_idx)[options.key]
                print(
                    "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                    dataset.get_tag(seq_idx), "%r len:" % options.key, seq_len)
                dump_file.write("%r: %r,\n" % (dataset.get_tag(seq_idx), seq_len))
            else:
                data = dataset.get_data(seq_idx, options.key)
                if options.type == "numpy":
                    numpy.savetxt("%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
                elif options.type == "stdout":
                    print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
                    print("seq %s data:" % progress, pretty_print(data))
                elif options.type == "print_shape":
                    print("seq %s data shape:" % progress, data.shape)
                elif options.type == "plot":
                    plot(data)
                for target in dataset.get_target_list():
                    targets = dataset.get_targets(target, seq_idx)
                    if options.type == "numpy":
                        numpy.savetxt(
                            "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
                            targets, fmt='%i')
                    elif options.type == "stdout":
                        extra = ""
                        if target in dataset.labels and len(dataset.labels[target]) > 1:
                            assert dataset.can_serialize_data(target)
                            extra += " (%r)" % dataset.serialize_data(key=target, data=targets)
                        print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
                    elif options.type == "print_shape":
                        print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
                if options.type == "interactive":
                    from returnn.util.debug import debug_shell
                    debug_shell(locals())
            seq_len = dataset.get_seq_length(seq_idx)
            for key in dataset.get_data_keys():
                seq_len_stats[key].collect([seq_len[key]])
            if stats:
                stats.collect(data)
            if options.type == "null":
                util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

            seq_idx += 1

        print("Done. Total time %s. More seqs which we did not dumped: %s" % (
            hms_fraction(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v2)
        for key in dataset.get_data_keys():
            seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
        if stats:
            stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)
        if options.type == "dump_seq_len":
            dump_file.write("}\n")
        if dump_file:
            print("Dumped to file:", dump_file.name, file=log.v2)
    finally:
        # Do not leak the dump file if anything above raises.
        if dump_file:
            dump_file.close()
def analyze_dataset(options):
    """
    Iterate over the batches which the (module-level) ``dataset`` generates with the batching
    settings from the config, and report batch statistics: number of batches (steps), number of seqs,
    and per data key the padded frames (max seq length * num seqs of each batch) vs the used frames.

    The summary is printed, and ``dataset.finish_epoch()`` is called, even if the iteration raises.

    :param options: argparse.Namespace; uses ``epoch``, ``key`` and ``endseq``
        (``endseq < 0`` means no limit; it is overwritten with ``inf`` in that case)
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    term_width, _ = util.terminal_size()
    show_progress_bar = log.verbose[3] and not log.verbose[5] and term_width >= 0

    t_start = time.time()
    batch_size_stats = Stats()
    if options.endseq < 0:
        options.endseq = float("inf")

    data_keys = dataset.get_data_keys()
    cfg_batch_size = config.typed_value('batch_size', 1)
    cfg_max_seqs = config.int('max_seqs', -1)
    cfg_seq_drop = config.float('seq_drop', 0.0)
    cfg_max_seq_length = config.typed_value('max_seq_length', None) or config.float('max_seq_length', 0)
    cfg_max_pad_size = config.typed_value("max_pad_size", None)

    batches = dataset.generate_batches(
        recurrent_net=True,
        batch_size=cfg_batch_size,
        max_seqs=cfg_max_seqs,
        max_seq_length=cfg_max_seq_length,
        max_pad_size=cfg_max_pad_size,
        seq_drop=cfg_seq_drop,
        used_data_keys=data_keys)

    def _num_seqs_str():
        """
        :return: the num seqs of the dataset, or an estimate ("~N"), or "?" if neither is known
        :rtype: str
        """
        try:
            return str(dataset.num_seqs)
        except NotImplementedError:
            try:
                return "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                return "?"

    num_steps = 0
    seqs_total = 0
    frames_total = NumbersDict()
    used_frames_total = NumbersDict()

    try:
        while batches.has_more():
            # See FeedDictDataProvider.
            [batch] = batches.peek_next_n(1)
            assert isinstance(batch, Batch)
            if batch.start_seq > options.endseq:
                break
            dataset.load_seqs(batch.start_seq, batch.end_seq)
            frac_done = batches.completed_frac()
            elapsed = time.time() - t_start
            prefix = "%i/%s" % (batch.start_seq, _num_seqs_str())
            progress = "%s (%.02f%%)" % (prefix, frac_done * 100)
            if frac_done > 0:
                progress += " (%s)" % hms(elapsed / frac_done - elapsed)

            frame_lens = [seq.frame_length for seq in batch.seqs]
            seqs_in_batch = len(batch.seqs)
            padded_frames = NumbersDict.max(frame_lens) * seqs_in_batch
            used_frames = sum(frame_lens, NumbersDict())
            seqs_total += seqs_in_batch
            batch_size_stats.collect(numpy.array([seqs_in_batch]))
            frames_total += padded_frames
            used_frames_total += used_frames

            print(
                "%s, batch %i, num seqs %i, frames %s, used %s (%s)" % (
                    progress, num_steps, seqs_in_batch, padded_frames, used_frames,
                    used_frames / padded_frames),
                file=log.v5)
            if show_progress_bar:
                util.progress_bar_with_time(frac_done, prefix=prefix)

            num_steps += 1
            batches.advance(1)

    finally:
        print("Done. Total time %s. More seqs which we did not dumped: %s" % (
            hms(time.time() - t_start), batches.has_more()), file=log.v2)
        print("Dataset epoch %i, order %r." % (dataset.epoch, dataset.seq_ordering))
        print("Num batches (steps): %i" % num_steps, file=log.v1)
        print("Num seqs: %i" % seqs_total, file=log.v1)
        batch_size_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
        for key in data_keys:
            print("Data key %r:" % key, file=log.v1)
            print(" Num frames: %s" % frames_total[key], file=log.v1)
            print(" Num used frames: %s" % used_frames_total[key], file=log.v1)
            print(" Fraction used frames: %s" % (used_frames_total / frames_total)[key], file=log.v1)
        dataset.finish_epoch()