def print_process(self):
    """
    Report the progress of the current run.

    Prints a summary line to ``log.v5`` if verbosity level 5 is enabled,
    and calls ``progress_bar`` if the parent task is interactive.
    Does nothing if neither of the two is the case.
    """
    # Use the boolean verbosity flag here, like everywhere else in this function.
    # (``log.v[5]`` was used before, in contrast to the ``log.verbose[5]`` check below.)
    if not self.parent.interactive and not log.verbose[5]:
        return
    start_elapsed = time.time() - self.parent.start_time
    complete = self.parent.batches.completed_frac()
    assert complete > 0
    # Linear extrapolation of the total time from the completed fraction so far.
    total_time_estimated = start_elapsed / complete
    remaining_estimated = total_time_estimated - start_elapsed
    if log.verbose[5]:
        mem_usage = self.device_mem_usage_str(self.alloc_devices)
        info = [
            self.parent.report_prefix,
            "batch %i" % self.run_start_batch_idx,
        ]
        if self.eval_info:  # Such as score.
            info += ["%s %s" % item for item in sorted(self.eval_info.items())]
        info += [
            "elapsed %s" % hms(start_elapsed),
            "exp. remaining %s" % hms(remaining_estimated),
            "complete %.02f%%" % (complete * 100),
        ]
        if mem_usage:
            info += ["memory %s" % mem_usage]
        # Empty entries (e.g. an empty report prefix) are dropped by the filter.
        print(", ".join(filter(None, info)), file=log.v5)
    if self.parent.interactive:
        progress_bar(complete, hms(remaining_estimated))
def cb(frame_len, orth):
    """
    Callback for a single corpus entry. Accumulates the statistics in ``Stats``.

    Entries whose frame length or number of orthography symbols reach the
    limits in ``options`` are skipped.

    :param frame_len: frame length of the entry, added to ``Stats.total_frame_len``
    :param orth: orthography of the entry, passed to ``parse_orthography``
    """
    if frame_len >= options.max_seq_frame_len:
        return
    symbols = parse_orthography(orth)
    num_symbols = len(symbols)
    if num_symbols >= options.max_seq_orth_len:
        return

    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
        print("Orth:", "".join(symbols), file=log.v3)
    if options.filter_orth_sym and options.filter_orth_sym in symbols:
        print("Found orth:", "".join(symbols), file=log.v3)
    if options.filter_orth_syms_seq:
        wanted_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
        if found_sub_seq(wanted_seq, symbols):
            print("Found orth:", "".join(symbols), file=log.v3)

    Stats.orth_syms_set.update(symbols)
    Stats.total_orth_len += num_symbols

    # Report some progress, but at most every 2 seconds.
    if time.time() - Stats.process_last_time <= 2:
        return
    Stats.process_last_time = time.time()
    if options.collect_time:
        total_time = Stats.total_frame_len * (options.frame_time / 1000.0)
        print("Collect process, total frame len so far:", hms(total_time), file=log.v3)
    else:
        print(
            "Collect process, total orth len so far:",
            human_size(Stats.total_orth_len),
            file=log.v3)
def finish(self):
    """
    End fold.

    Prints the time elapsed since ``self.start_time``,
    followed by the fold end marker for the active CI environment (if any),
    and flushes stdout.
    """
    duration = time.time() - self.start_time
    print("%s: Elapsed time: %s" % (self.name, hms(duration)))
    if travis_env:
        print("travis_fold:end:%s" % folds[-1])
    # The GitHub group is only closed for the outermost fold.
    if github_env and len(folds) == 1:
        print("::endgroup::")
    sys.stdout.flush()
def get_raw_strings(dataset, options):
    """
    Iterate over the dataset and collect the raw string stored under ``options.key`` for each sequence.

    :param Dataset dataset:
    :param options: argparse.Namespace. Reads ``startseq``, ``endseq`` and ``key``.
      A negative ``endseq`` is replaced in-place by ``float("inf")``.
    :return: list of (seq tag, string)
    :rtype: list[(str,str)]
    """
    refs = []
    start_time = time.time()
    seq_len_stats = Stats()
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    interactive = util.is_tty() and not log.verbose[5]
    print("Iterating over %r." % dataset, file=log.v2)
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
        dataset.load_seqs(seq_idx, seq_idx + 1)
        complete_frac = dataset.get_complete_frac(seq_idx)
        try:
            num_seqs_s = str(dataset.num_seqs)
        except NotImplementedError:
            try:
                num_seqs_s = "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                num_seqs_s = "?"
        progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
        seq_tag = dataset.get_tag(seq_idx)
        assert isinstance(seq_tag, str)
        ref = dataset.get_data(seq_idx, options.key)
        if isinstance(ref, numpy.ndarray):
            # Either a scalar array holding the string object, or the raw bytes as uint8 vector.
            assert ref.shape == () or (ref.ndim == 1 and ref.dtype == numpy.uint8)
            if ref.shape == ():
                ref = ref.flatten()[0]  # get the entry itself (str or bytes)
            else:
                ref = ref.tobytes()
        if isinstance(ref, bytes):
            ref = ref.decode("utf8")
        assert isinstance(ref, str)
        seq_len_stats.collect([len(ref)])
        refs.append((seq_tag, ref))
        if interactive:
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(progress_prefix, "seq tag %r, ref len %i chars" % (seq_tag, len(ref)))
        seq_idx += 1
    print("Done. Num seqs %i. Total time %s." % (seq_idx, hms(time.time() - start_time)), file=log.v1)
    print("More seqs which we did not dumped: %s." % (dataset.is_less_than_num_seqs(seq_idx),), file=log.v1)
    seq_len_stats.dump(stream_prefix="Seq-length %r " % (options.key,), stream=log.v2)
    return refs
def calc_wer_on_dataset(dataset, refs, options, hyps):
    """
    Compute the WER of the hypotheses against the references,
    where the references are taken either from ``dataset`` or from ``refs``.

    :param Dataset|None dataset:
    :param dict[str,str]|None refs: seq tag -> ref string (words delimited by space)
    :param options: argparse.Namespace. Reads ``startseq``, ``endseq``, ``key`` and ``expect_full``.
      A negative ``endseq`` is replaced in-place by ``float("inf")``.
    :param dict[str,str] hyps: seq tag -> hyp string (words delimited by space)
    :return: WER
    :rtype: float
    """
    assert dataset or refs
    start_time = time.time()
    seq_len_stats = {"refs": Stats(), "hyps": Stats()}
    seq_idx = options.startseq
    if options.endseq < 0:
        options.endseq = float("inf")
    wer = 1.0
    remaining_hyp_seq_tags = set(hyps)
    interactive = util.is_tty() and not log.verbose[5]
    collected = {"hyps": [], "refs": []}
    max_num_collected = 1
    if dataset:
        dataset.init_seq_order(epoch=1)
    else:
        # Without dataset, iterate over the given refs, sorted by the ref string length.
        refs = sorted(refs.items(), key=lambda item: len(item[1]))

    while True:
        if seq_idx > options.endseq:
            break
        if dataset:
            if not dataset.is_less_than_num_seqs(seq_idx):
                break
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            seq_tag = dataset.get_tag(seq_idx)
            assert isinstance(seq_tag, str)
            ref = dataset.get_data(seq_idx, options.key)
            if isinstance(ref, numpy.ndarray):
                assert ref.shape == ()
                ref = ref.flatten()[0]  # get the entry itself (str or bytes)
            if isinstance(ref, bytes):
                ref = ref.decode("utf8")
            assert isinstance(ref, str)
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
        else:
            if seq_idx >= len(refs):
                break
            complete_frac = (seq_idx + 1) / float(len(refs))
            seq_tag, ref = refs[seq_idx]
            assert isinstance(seq_tag, str)
            assert isinstance(ref, str)
            num_seqs_s = str(len(refs))

        progress_prefix = "%i/%s (WER %.02f%%)" % (seq_idx, num_seqs_s, wer * 100)
        # Raises KeyError if there is no (remaining) hypothesis for this seq tag.
        remaining_hyp_seq_tags.remove(seq_tag)
        hyp = hyps[seq_tag]
        seq_len_stats["hyps"].collect([len(hyp)])
        seq_len_stats["refs"].collect([len(ref)])
        collected["hyps"].append(hyp)
        collected["refs"].append(ref)

        if len(collected["hyps"]) >= max_num_collected:
            wer = wer_compute.step(session, **collected)
            del collected["hyps"][:]
            del collected["refs"][:]

        if interactive:
            util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
        elif log.verbose[5]:
            print(
                progress_prefix,
                "seq tag %r, ref/hyp len %i/%i chars" % (seq_tag, len(ref), len(hyp)))
        seq_idx += 1

    # Flush whatever is left in the collected buffers.
    if len(collected["hyps"]) > 0:
        wer = wer_compute.step(session, **collected)
    print("Done. Num seqs %i. Total time %s." % (seq_idx, hms(time.time() - start_time)), file=log.v1)
    print("Remaining num hyp seqs %i." % (len(remaining_hyp_seq_tags),), file=log.v1)
    if dataset:
        print("More seqs which we did not dumped: %s." % dataset.is_less_than_num_seqs(seq_idx), file=log.v1)
    for key in ["hyps", "refs"]:
        seq_len_stats[key].dump(stream_prefix="Seq-length %r %r " % (key, options.key), stream=log.v2)
    if options.expect_full:
        assert not remaining_hyp_seq_tags, "There are still remaining hypotheses."
    return wer
def dump_dataset(dataset, options):
    """
    Iterate over the dataset and dump each sequence in the way selected by ``options.type``.

    With ``options.get_num_seqs``, only the number of sequences is reported
    (counted by iterating, if ``dataset.num_seqs`` raises), and nothing is dumped.

    :type dataset: Dataset.Dataset
    :param options: argparse.Namespace.
      A negative ``endseq`` is replaced in-place by ``float("inf")``.
    :raises Exception: if ``options.type`` is unknown
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(epoch=options.epoch)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    if options.get_num_seqs:
        print("Get num seqs.")
        print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
        try:
            print("num_seqs: %r" % dataset.num_seqs)
        except Exception as exc:
            print("num_seqs exception %r, which is valid, so we count." % exc)
            seq_idx = 0
            if dataset.get_target_list():
                default_target = dataset.get_target_list()[0]
            else:
                default_target = None
            while dataset.is_less_than_num_seqs(seq_idx):
                dataset.load_seqs(seq_idx, seq_idx + 1)
                if seq_idx % 10000 == 0:
                    if default_target:
                        targets = dataset.get_targets(default_target, seq_idx)
                        postfix = " (targets = %r...)" % (targets[:10],)
                    else:
                        postfix = ""
                    print("%i ...%s" % (seq_idx, postfix))
                seq_idx += 1
            print("accumulated num seqs: %i" % seq_idx)
        print("Done.")
        return

    dump_file = None
    # Everything from here on is wrapped such that an opened dump file
    # is always closed again, also when dumping fails in the middle.
    try:
        if options.type == "numpy":
            print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
        elif options.type == "stdout":
            print("Dump to stdout", file=log.v3)
            if options.stdout_limit is not None:
                util.set_pretty_print_default_limit(options.stdout_limit)
                numpy.set_printoptions(
                    threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
            if options.stdout_as_bytes:
                util.set_pretty_print_as_bytes(options.stdout_as_bytes)
        elif options.type == "print_tag":
            print("Dump seq tag to stdout", file=log.v3)
        elif options.type == "dump_tag":
            dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
            print("Dump seq tag to file: %s" % (dump_file.name,), file=log.v3)
        elif options.type == "dump_seq_len":
            dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
            print("Dump seq lens to file: %s" % (dump_file.name,), file=log.v3)
            dump_file.write("{\n")
        elif options.type == "print_shape":
            print("Dump shape to stdout", file=log.v3)
        elif options.type == "plot":
            print("Plot.", file=log.v3)
        elif options.type == "interactive":
            print("Interactive debug shell.", file=log.v3)
        elif options.type == "null":
            if options.dump_stats:
                print("No dump (except stats).")
            else:
                print("No dump.")
        else:
            raise Exception("unknown dump option type %r" % options.type)

        start_time = time.time()
        stats = Stats() if (options.stats or options.dump_stats) else None
        seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
        seq_idx = options.startseq
        if options.endseq < 0:
            options.endseq = float("inf")
        while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
            dataset.load_seqs(seq_idx, seq_idx + 1)
            complete_frac = dataset.get_complete_frac(seq_idx)
            start_elapsed = time.time() - start_time
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
            progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            # Stays None for the dump types which only look at the seq tag / seq length.
            data = None
            if complete_frac > 0:
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)
            if options.type == "print_tag":
                print(
                    "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                    dataset.get_tag(seq_idx))
            elif options.type == "dump_tag":
                print(
                    "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                    dataset.get_tag(seq_idx))
                dump_file.write("%s\n" % dataset.get_tag(seq_idx))
            elif options.type == "dump_seq_len":
                seq_len = dataset.get_seq_length(seq_idx)[options.key]
                print(
                    "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
                    dataset.get_tag(seq_idx),
                    "%r len:" % options.key, seq_len)
                dump_file.write("%r: %r,\n" % (dataset.get_tag(seq_idx), seq_len))
            else:
                data = dataset.get_data(seq_idx, options.key)
                if options.type == "numpy":
                    numpy.savetxt(
                        "%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
                elif options.type == "stdout":
                    print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
                    print("seq %s data:" % progress, pretty_print(data))
                elif options.type == "print_shape":
                    print("seq %s data shape:" % progress, data.shape)
                elif options.type == "plot":
                    plot(data)
                for target in dataset.get_target_list():
                    targets = dataset.get_targets(target, seq_idx)
                    if options.type == "numpy":
                        numpy.savetxt(
                            "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
                            targets, fmt='%i')
                    elif options.type == "stdout":
                        extra = ""
                        if target in dataset.labels and len(dataset.labels[target]) > 1:
                            assert dataset.can_serialize_data(target)
                            extra += " (%r)" % dataset.serialize_data(key=target, data=targets)
                        print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
                    elif options.type == "print_shape":
                        print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
                if options.type == "interactive":
                    from returnn.util.debug import debug_shell
                    debug_shell(locals())
            seq_len = dataset.get_seq_length(seq_idx)
            for key in dataset.get_data_keys():
                seq_len_stats[key].collect([seq_len[key]])
            if stats:
                stats.collect(data)
            if options.type == "null":
                util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
            seq_idx += 1

        print(
            "Done. Total time %s. More seqs which we did not dumped: %s" % (
                hms_fraction(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)),
            file=log.v2)
        for key in dataset.get_data_keys():
            seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
        if stats:
            stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)
        if options.type == "dump_seq_len":
            dump_file.write("}\n")
        if dump_file:
            print("Dumped to file:", dump_file.name, file=log.v2)
    finally:
        if dump_file:
            dump_file.close()
def work(self):
    """
    Start the optimization.

    Runs ``self.num_iterations`` iterations. In each iteration, the population is
    (re)filled, every individual is trained by a pool of ``self.num_threads`` worker threads,
    and afterwards the population is sorted by cost and reduced.
    A ``KeyboardInterrupt`` cancels the search. In any case, the best settings found are printed at the end.
    """
    print("Starting hyper param search. Using %i threads." % self.num_threads, file=log.v1)
    from returnn.tf.util.basic import get_available_gpu_devices
    from returnn.log import wrap_log_streams, StreamDummy
    from threading import Thread, Condition
    from returnn.util.basic import progress_bar, hms, is_tty

    class Outstanding:
        """
        Queue of outstanding work.

        Only class attributes are used, i.e. this is the state shared
        between the main thread and the worker threads. Access is guarded by ``cond``.
        """
        cond = Condition()
        threads = []  # type: typing.List[WorkerThread]
        population = []
        exit = False
        exception = None

    class WorkerThread(Thread):
        """
        Worker thread. Trains individuals taken from ``Outstanding.population`` until nothing is left.
        """

        def __init__(self, gpu_ids):
            """
            Note that the thread is started right away.

            :param set[int] gpu_ids:
            """
            super(WorkerThread, self).__init__(name="Hyper param tune train thread")
            self.gpu_ids = gpu_ids
            self.trainer = None  # type: typing.Optional[_IndividualTrainer]
            self.finished = False
            self.start()

        def cancel(self, join=False):
            """
            Sets the cancel flag on the current trainer and its runner (if any).

            :param bool join: whether to join the thread afterwards
            """
            with Outstanding.cond:
                if self.trainer:
                    self.trainer.cancel_flag = True
                    if self.trainer.runner:
                        self.trainer.runner.cancel_flag = True
            # Join outside of the lock, as the thread itself acquires the lock in run().
            if join:
                self.join()

        def get_complete_frac(self):
            """
            :return: the completed fraction reported by the data provider of the current runner,
              or 0.0 if there is no trainer or runner (yet)
            :rtype: float
            """
            with Outstanding.cond:
                if self.trainer and self.trainer.runner:
                    return self.trainer.runner.data_provider.get_complete_frac()
            return 0.0

        # noinspection PyMethodParameters
        def run(self_thread):
            """
            Run thread.

            The first argument is named ``self_thread`` such that ``self`` still refers
            to the instance of the enclosing ``work`` method.
            """
            try:
                while True:
                    with Outstanding.cond:
                        if Outstanding.exit or Outstanding.exception:
                            return
                        if not Outstanding.population:
                            # Nothing left to do. Wake up the main thread such that it can check the finished flags.
                            self_thread.finished = True
                            Outstanding.cond.notify_all()
                            return
                        # noinspection PyShadowingNames
                        individual = Outstanding.population.pop(0)
                        self_thread.trainer = _IndividualTrainer(
                            optim=self, individual=individual, gpu_ids=self_thread.gpu_ids)
                        self_thread.name = "Hyper param tune train thread on %r" % individual.name
                    # The training itself runs without holding the lock.
                    self_thread.trainer.run()
            except Exception as exc:
                with Outstanding.cond:
                    # Only the first exception is kept. It is re-raised later by the main thread.
                    if not Outstanding.exception:
                        Outstanding.exception = exc or True
                    Outstanding.cond.notify_all()
                for thread in Outstanding.threads:
                    if thread is not self_thread:
                        thread.cancel()
                if not isinstance(exc, CancelTrainingException):
                    with Outstanding.cond:  # So that we don't mix up multiple on sys.stderr.
                        # This would normally dump it on sys.stderr so it's fine.
                        sys.excepthook(*sys.exc_info())

    best_individuals = []
    population = []
    num_gpus = len(get_available_gpu_devices())
    print("Num available GPUs:", num_gpus)
    num_gpus = num_gpus or 1  # Would be ignored anyway.
    interactive = is_tty()
    try:
        print(
            "Population of %i individuals (hyper param setting instances), running for %i evaluation iterations." % (
                self.num_individuals, self.num_iterations), file=log.v2)
        for cur_iteration_idx in range(1, self.num_iterations + 1):
            print("Starting iteration %i." % cur_iteration_idx, file=log.v2)
            if cur_iteration_idx == 1:
                # The very first population always contains the default and the initial hyper param values.
                population.append(
                    Individual(
                        {p: p.get_default_value() for p in self.hyper_params},
                        name="default"))
                population.append(
                    Individual(
                        {p: p.get_initial_value() for p in self.hyper_params},
                        name="canonical"))
            # Fill up the population to num_individuals.
            population.extend(
                self.get_population(
                    iteration_idx=cur_iteration_idx,
                    num_individuals=self.num_individuals - len(population)))
            if cur_iteration_idx > 1:
                self.cross_over(population=population, iteration_idx=cur_iteration_idx)
            if cur_iteration_idx == 1 and self.dry_run_first_individual:
                # Train first directly for testing and to see log output.
                # Later we will strip away all log output.
                print("Very first try with log output:", file=log.v2)
                _IndividualTrainer(optim=self, individual=population[0], gpu_ids={0}).run()
            print("Starting training with thread pool of %i threads." % self.num_threads)
            iteration_start_time = time.time()
            with wrap_log_streams(StreamDummy(), also_sys_stdout=True, tf_log_verbosity="WARN"):
                Outstanding.exit = False
                Outstanding.population = list(population)
                # The threads start right away (see WorkerThread.__init__). GPUs are assigned round-robin.
                Outstanding.threads = [
                    WorkerThread(gpu_ids={i % num_gpus})
                    for i in range(self.num_threads)]
                try:
                    while True:
                        with Outstanding.cond:
                            if all([thread.finished for thread in Outstanding.threads]) or Outstanding.exception:
                                break
                            # Count of individuals which are neither pending nor (as one per thread) in progress,
                            # plus the fractional progress which each thread reports.
                            complete_frac = max(
                                len(population) - len(Outstanding.population) - len(Outstanding.threads), 0)
                            complete_frac += sum([thread.get_complete_frac() for thread in Outstanding.threads])
                            complete_frac /= float(len(population))
                            remaining_str = ""
                            if complete_frac > 0:
                                start_elapsed = time.time() - iteration_start_time
                                total_time_estimated = start_elapsed / complete_frac
                                remaining_estimated = total_time_estimated - start_elapsed
                                remaining_str = hms(remaining_estimated)
                            # Write to the original stdout, as sys.stdout is wrapped inside this block.
                            if interactive:
                                progress_bar(complete_frac, prefix=remaining_str, file=sys.__stdout__)
                            else:
                                print(
                                    "Progress: %.02f%%" % (complete_frac * 100),
                                    "remaining:", remaining_str or "unknown",
                                    file=sys.__stdout__)
                                sys.__stdout__.flush()
                            # Wakes up early when a worker calls notify_all().
                            Outstanding.cond.wait(1 if interactive else 10)
                    for thread in Outstanding.threads:
                        thread.join()
                finally:
                    Outstanding.exit = True
                    for thread in Outstanding.threads:
                        thread.cancel(join=True)
            Outstanding.threads = []
            print("Training iteration elapsed time:", hms(time.time() - iteration_start_time))
            if Outstanding.exception:
                raise Outstanding.exception
            assert not Outstanding.population
            print("Training iteration finished.")
            population.sort(key=lambda p: p.cost)
            # NOTE(review): with num_kill_individuals == 0, the slice [-0:] covers the whole list,
            # i.e. the whole population would be deleted - confirm that 0 is not a valid setting.
            del population[-self.num_kill_individuals:]
            best_individuals.extend(population)
            best_individuals.sort(key=lambda p: p.cost)
            del best_individuals[self.num_best:]
            population = best_individuals[:self.num_kill_individuals // 4] + population
            print(
                "Current best setting, individual %s" % best_individuals[0].name,
                "cost:", best_individuals[0].cost)
            for p in self.hyper_params:
                print(" %s -> %s" % (p.description(), best_individuals[0].hyper_param_mapping[p]))
    except KeyboardInterrupt:
        print("KeyboardInterrupt, canceled search.")

    print("Best %i settings:" % len(best_individuals))
    for individual in best_individuals:
        print("Individual %s" % individual.name, "cost:", individual.cost)
        for p in self.hyper_params:
            print(" %s -> %s" % (p.description(), individual.hyper_param_mapping[p]))
def train_epoch(self):
    """
    Train one epoch (``self.epoch``), save the model and the learning rate control state,
    and evaluate the model afterwards.

    Exits the process with exit code 1 if the training task did not finalize.

    :raises AssertionError: if any of the final train scores is inf or nan
    """
    print("start", self.get_epoch_str(), "with learning rate", self.learning_rate, "...", file=log.v4)

    if self.epoch == 1 and self.save_epoch1_initial_model:
        epoch0_model_filename = self.epoch_model_filename(self.model_filename, 0, self.is_pretrain_epoch())
        print("save initial epoch1 model", epoch0_model_filename, file=log.v4)
        self.save_model(epoch0_model_filename, 0)

    if self.is_pretrain_epoch():
        self.print_network_info()

    training_devices = self.devices
    # Only reuse the batches from the previous epoch if the dataset caches them for the whole epoch.
    if 'train' not in self.dataset_batches or not self.train_data.batch_set_generator_cache_whole_epoch():
        self.dataset_batches['train'] = self.train_data.generate_batches(
            recurrent_net=self.network.recurrent,
            batch_size=self.batch_size,
            pruning=self.batch_pruning,
            max_seqs=self.max_seqs,
            max_seq_length=int(self.max_seq_length),
            seq_drop=self.seq_drop,
            shuffle_batches=self.shuffle_batches,
            used_data_keys=self.network.get_used_data_keys())
    else:
        self.dataset_batches['train'].reset()
    train_batches = self.dataset_batches['train']

    # Only the very first epoch of this run can start in the middle.
    start_batch = self.start_batch if self.epoch == self.start_epoch else 0
    trainer = TrainTaskThread(
        self.network, training_devices,
        data=self.train_data, batches=train_batches,
        learning_rate=self.learning_rate, updater=self.updater,
        eval_batch_size=self.update_batch_size,
        start_batch=start_batch, share_batches=self.share_batches,
        reduction_rate=self.reduction_rate,
        exclude=self.exclude,
        seq_train_parallel=self.seq_train_parallel,
        report_prefix=("pre" if self.is_pretrain_epoch() else "") + "train epoch %s" % self.epoch,
        epoch=self.epoch)
    trainer.join()
    if not trainer.finalized:
        if trainer.device_crash_batch is not None:  # Otherwise we got an unexpected exception - a bug in our code.
            if self.model_filename:
                self.save_model(
                    self.get_epoch_model_filename() + ".crash_%i" % trainer.device_crash_batch, self.epoch - 1)
        sys.exit(1)

    # Every score must be finite. The former check,
    # `not any(isinf) or any(isnan)`, let through any score dict which contained a nan,
    # because `not` binds stronger than `or`.
    final_scores = list(trainer.score.values())
    assert numpy.all(numpy.isfinite(final_scores)), (
        "Model is broken, got inf or nan final score: %s" % trainer.score)

    if self.model_filename and (self.epoch % self.save_model_epoch_interval == 0):
        self.save_model(self.get_epoch_model_filename(), self.epoch)
    self.learning_rate_control.set_epoch_error(self.epoch, {"train_score": trainer.score})
    self.learning_rate_control.save()
    if self.ctc_prior_file is not None:
        trainer.save_ctc_priors(self.ctc_prior_file, self.get_epoch_str())

    # No newline here (end=' ').
    print(
        self.get_epoch_str(), "score:", self.format_score(trainer.score), "elapsed:", hms(trainer.elapsed),
        end=' ', file=log.v1)
    self.eval_model()
def analyze_dataset(options):
    """
    Iterate over the batches of the (module-level) ``dataset``, as configured by ``config``,
    and report statistics about the number of seqs and (used) frames per batch.

    :param options: argparse.Namespace.
      A negative ``endseq`` is replaced in-place by ``float("inf")``.
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    terminal_width, _ = util.terminal_size()
    show_interactive_process_bar = (
        log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)

    start_time = time.time()
    num_seqs_stats = Stats()
    if options.endseq < 0:
        options.endseq = float("inf")

    used_data_keys = dataset.get_data_keys()
    batch_size = config.typed_value('batch_size', 1)
    max_seqs = config.int('max_seqs', -1)
    seq_drop = config.float('seq_drop', 0.0)
    max_seq_length = config.typed_value('max_seq_length', None) or config.float('max_seq_length', 0)
    max_pad_size = config.typed_value("max_pad_size", None)

    batches = dataset.generate_batches(
        recurrent_net=True,
        batch_size=batch_size,
        max_seqs=max_seqs,
        max_seq_length=max_seq_length,
        max_pad_size=max_pad_size,
        seq_drop=seq_drop,
        used_data_keys=used_data_keys)

    def _num_seqs_str():
        """
        :return: the number of seqs of the dataset as string, "~N" if only estimated, "?" if unknown
        :rtype: str
        """
        try:
            return str(dataset.num_seqs)
        except NotImplementedError:
            try:
                return "~%i" % dataset.estimated_num_seqs
            except TypeError:  # a number is required, not NoneType
                return "?"

    step = 0
    total_num_seqs = 0
    total_num_frames = NumbersDict()
    total_num_used_frames = NumbersDict()

    try:
        while batches.has_more():
            # See FeedDictDataProvider.
            (batch,) = batches.peek_next_n(1)
            assert isinstance(batch, Batch)
            if batch.start_seq > options.endseq:
                break
            dataset.load_seqs(batch.start_seq, batch.end_seq)
            complete_frac = batches.completed_frac()
            start_elapsed = time.time() - start_time

            progress_prefix = "%i/%s" % (batch.start_seq, _num_seqs_str())
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            if complete_frac > 0:
                remaining_estimated = start_elapsed / complete_frac - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)

            num_batch_seqs = len(batch.seqs)
            frame_lengths = [seq.frame_length for seq in batch.seqs]
            # Frames of the batch including padding: max seq length times number of seqs.
            batch_max_time = NumbersDict.max(frame_lengths) * num_batch_seqs
            batch_num_used_frames = sum(frame_lengths, NumbersDict())

            total_num_seqs += num_batch_seqs
            num_seqs_stats.collect(numpy.array([num_batch_seqs]))
            total_num_frames += batch_max_time
            total_num_used_frames += batch_num_used_frames

            print(
                "%s, batch %i, num seqs %i, frames %s, used %s (%s)" % (
                    progress, step, num_batch_seqs,
                    batch_max_time, batch_num_used_frames, batch_num_used_frames / batch_max_time),
                file=log.v5)
            if show_interactive_process_bar:
                util.progress_bar_with_time(complete_frac, prefix=progress_prefix)

            step += 1
            batches.advance(1)

    finally:
        # The summary is reported in any case, also when the iteration got interrupted.
        print(
            "Done. Total time %s. More seqs which we did not dumped: %s" % (
                hms(time.time() - start_time), batches.has_more()),
            file=log.v2)
        print("Dataset epoch %i, order %r." % (dataset.epoch, dataset.seq_ordering))
        print("Num batches (steps): %i" % step, file=log.v1)
        print("Num seqs: %i" % total_num_seqs, file=log.v1)
        num_seqs_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
        for data_key in used_data_keys:
            print("Data key %r:" % data_key, file=log.v1)
            print(" Num frames: %s" % total_num_frames[data_key], file=log.v1)
            print(" Num used frames: %s" % total_num_used_frames[data_key], file=log.v1)
            print(
                " Fraction used frames: %s" % (total_num_used_frames / total_num_frames)[data_key],
                file=log.v1)
        dataset.finish_epoch()
def collect_stats(options, iter_corpus):
    """
    Collect the set of orthography symbols (and some length statistics) over a corpus,
    and optionally write the symbols to ``options.output``, one symbol per line.

    :param options: argparse.Namespace
    :param iter_corpus: function which gets called with a callback ``cb(frame_len, orth)``.
      The callback accumulates the statistics for one corpus entry.
    """
    orth_symbols_filename = options.output
    if orth_symbols_filename:
        assert not os.path.exists(orth_symbols_filename)

    class Stats:
        """
        Accumulated statistics. Only class attributes are used.
        """
        count = 0
        process_last_time = time.time()
        total_frame_len = 0
        total_orth_len = 0
        orth_syms_set = set()

    if options.add_numbers:
        Stats.orth_syms_set.update(map(chr, range(ord("0"), ord("9") + 1)))
    if options.add_lower_alphabet:
        Stats.orth_syms_set.update(map(chr, range(ord("a"), ord("z") + 1)))
    if options.add_upper_alphabet:
        Stats.orth_syms_set.update(map(chr, range(ord("A"), ord("Z") + 1)))

    def cb(frame_len, orth):
        """
        Accumulates the statistics for one corpus entry.
        Entries which reach the frame length or orth length limit are skipped.
        """
        if frame_len >= options.max_seq_frame_len:
            return
        orth_syms = parse_orthography(orth)
        if len(orth_syms) >= options.max_seq_orth_len:
            return

        Stats.count += 1
        Stats.total_frame_len += frame_len

        if options.dump_orth_syms:
            print("Orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_sym:
            if options.filter_orth_sym in orth_syms:
                print("Found orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_syms_seq:
            filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
            if found_sub_seq(filter_seq, orth_syms):
                print("Found orth:", "".join(orth_syms), file=log.v3)

        Stats.orth_syms_set.update(orth_syms)
        Stats.total_orth_len += len(orth_syms)

        # Show some progress if it takes long.
        if time.time() - Stats.process_last_time > 2:
            Stats.process_last_time = time.time()
            if options.collect_time:
                print(
                    "Collect process, total frame len so far:",
                    hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
            else:
                print(
                    "Collect process, total orth len so far:",
                    human_size(Stats.total_orth_len), file=log.v3)

    iter_corpus(cb)

    if options.remove_symbols:
        filter_syms = parse_orthography_into_symbols(options.remove_symbols)
        Stats.orth_syms_set -= set(filter_syms)

    if options.collect_time:
        print(
            "Total frame len:", Stats.total_frame_len,
            "time:", hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
    else:
        print("No time stats (--collect_time False).", file=log.v3)
    print(
        "Total orth len:", Stats.total_orth_len, "(%s)" % human_size(Stats.total_orth_len),
        end=' ', file=log.v3)
    # The ratios are reported as nan instead of failing with ZeroDivisionError
    # when no corpus entry passed the filters. The symbols can still be written then.
    if options.collect_time:
        if Stats.total_frame_len:
            orth_frame_fraction = float(Stats.total_orth_len) / Stats.total_frame_len
        else:
            orth_frame_fraction = float("nan")
        print("fraction:", orth_frame_fraction, file=log.v3)
    else:
        print("", file=log.v3)
    if Stats.count:
        average_orth_len = float(Stats.total_orth_len) / Stats.count
    else:
        average_orth_len = float("nan")
    print("Average orth len:", average_orth_len, file=log.v3)
    print("Num symbols:", len(Stats.orth_syms_set), file=log.v3)

    if orth_symbols_filename:
        # NOTE(review): `unicode` is not a builtin in Python 3,
        # i.e. it must be provided by the surrounding module - confirm.
        with open(orth_symbols_filename, "wb") as orth_syms_file:
            for orth_sym in sorted(Stats.orth_syms_set):
                orth_syms_file.write(b"%s\n" % unicode(orth_sym).encode("utf8"))
        print("Wrote orthography symbols to", orth_symbols_filename, file=log.v3)
    else:
        print("Provide --output to save the symbols.", file=log.v3)
def test_DeviceBatchRun_outputs_format():
    """
    Simulates one train batch and one eval batch via ``DummyDeviceBatchRun``
    and checks the format of the results.

    Currently disabled: returns right away, everything below the ``return`` is not executed.
    """
    # TODO: This is broken...
    return

    # --- Train. ---
    train_run = DummyDeviceBatchRun(task="train")
    assert len(train_run.alloc_devices) == 1

    # Simulate epoch start.
    trainer = train_run.parent
    train_run.alloc_devices[0].start_epoch_stats()
    trainer.initialize()

    # Simulate one batch.
    train_run.allocate()
    train_run.device_run()
    train_run.set_dummy_dev_output(outputs_format=["cost:foo"], output=[1.42])
    train_run.finish()

    assert_is_instance(train_run.result, dict)
    assert_in("results", train_run.result)
    all_outputs = train_run.result["results"]
    assert_is_instance(all_outputs, list)
    assert_equal(len(all_outputs), len(train_run.alloc_devices))
    first_outputs = all_outputs[0]
    assert_is_instance(first_outputs, list)
    outputs_format = train_run.result["result_format"]
    assert_is_instance(outputs_format, list)
    result_dict = Device.make_result_dict(first_outputs, outputs_format)
    assert_is_instance(result_dict, dict)
    pprint(result_dict)

    # Simulate epoch end.
    print("train epoch score:", trainer.score, "elapsed:", hms(trainer.elapsed))
    trainer.finalize()
    train_run.alloc_devices[0].finish_epoch_stats()

    # --- Now simulate the eval. ---
    eval_run = DummyDeviceBatchRun(task="eval")
    assert len(eval_run.alloc_devices) == 1

    # Simulate epoch start.
    tester = eval_run.parent
    eval_run.alloc_devices[0].start_epoch_stats()
    tester.initialize()

    # Simulate one batch.
    eval_run.allocate()
    eval_run.device_run()
    eval_run.set_dummy_dev_output(outputs_format=["cost:foo", "error:foo"], output=[1.42, 2.34])
    eval_run.finish()

    # Simulate epoch end.
    print("eval epoch elapsed:", hms(tester.elapsed))
    tester.finalize()
    eval_run.alloc_devices[0].finish_epoch_stats()

    print("eval results:", tester.score, tester.error)

    assert_is_instance(eval_run.result, dict)
    assert_in("results", eval_run.result)
    all_outputs = eval_run.result["results"]
    assert_is_instance(all_outputs, list)
    assert_equal(len(all_outputs), len(eval_run.alloc_devices))
    first_outputs = all_outputs[0]
    assert_is_instance(first_outputs, list)
    outputs_format = eval_run.result["result_format"]
    assert_is_instance(outputs_format, list)
    result_dict = Device.make_result_dict(first_outputs, outputs_format)
    assert_is_instance(result_dict, dict)
    pprint(result_dict)

    assert_greater(tester.score, 0)
    assert_greater(tester.error, 0)