def print_process(self):
  if not self.parent.interactive and not log.v[5]:
    return
  start_elapsed = time.time() - self.parent.start_time
  complete = self.parent.batches.completed_frac()
  assert complete > 0
  total_time_estimated = start_elapsed / complete
  remaining_estimated = total_time_estimated - start_elapsed
  if log.verbose[5]:
    mem_usage = self.device_mem_usage_str(self.alloc_devices)
    info = [
      self.parent.report_prefix,
      "batch %i" % self.run_start_batch_idx]
    if self.eval_info:  # Such as score.
      info += ["%s %s" % item for item in sorted(self.eval_info.items())]
    info += [
      "elapsed %s" % hms(start_elapsed),
      "exp. remaining %s" % hms(remaining_estimated),
      "complete %.02f%%" % (complete * 100)]
    if mem_usage:
      info += ["memory %s" % mem_usage]
    print(", ".join(filter(None, info)), file=log.v5)
  if self.parent.interactive:
    progress_bar(complete, hms(remaining_estimated))
def get_raw_strings(dataset, options):
  """
  :param Dataset dataset:
  :param options: argparse.Namespace
  :return: list of (seq tag, string)
  :rtype: list[(str,str)]
  """
  refs = []
  start_time = time.time()
  seq_len_stats = Stats()
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")
  interactive = Util.is_tty() and not log.verbose[5]
  print("Iterating over %r." % dataset, file=log.v2)
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    start_elapsed = time.time() - start_time
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
    if complete_frac > 0:
      total_time_estimated = start_elapsed / complete_frac
      remaining_estimated = total_time_estimated - start_elapsed
      progress += " (%s)" % hms(remaining_estimated)
    seq_tag = dataset.get_tag(seq_idx)
    assert isinstance(seq_tag, str)
    ref = dataset.get_data(seq_idx, options.key)
    if isinstance(ref, numpy.ndarray):
      assert ref.shape == () or (ref.ndim == 1 and ref.dtype == numpy.uint8)
      if ref.shape == ():
        ref = ref.flatten()[0]  # get the entry itself (str or bytes)
      else:
        ref = ref.tobytes()
    if isinstance(ref, bytes):
      ref = ref.decode("utf8")
    assert isinstance(ref, str)
    seq_len_stats.collect([len(ref)])
    refs.append((seq_tag, ref))
    if interactive:
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    elif log.verbose[5]:
      print(progress_prefix, "seq tag %r, ref len %i chars" % (seq_tag, len(ref)))
    seq_idx += 1
  print("Done. Num seqs %i. Total time %s." % (
    seq_idx, hms(time.time() - start_time)), file=log.v1)
  print("More seqs which we did not dump: %s." % (
    dataset.is_less_than_num_seqs(seq_idx),), file=log.v1)
  seq_len_stats.dump(stream_prefix="Seq-length %r " % (options.key,), stream=log.v2)
  return refs
def _cleanup_old(self):
  mod_path = self._mod_path  # .../base_name/hash
  base_mod_path = os.path.dirname(mod_path)  # .../base_name
  my_mod_path_name = os.path.basename(mod_path)
  if not os.path.exists(base_mod_path):
    return
  import time
  from Util import hms
  cleanup_time_limit_secs = self._cleanup_time_limit_days * 24 * 60 * 60
  for p in os.listdir(base_mod_path):
    if p == my_mod_path_name:
      continue
    full_dir_path = "%s/%s" % (base_mod_path, p)
    if not os.path.isdir(full_dir_path):
      continue  # ignore for now
    info_path = "%s/info.py" % full_dir_path
    if not os.path.exists(info_path):
      self._cleanup_old_path(full_dir_path, reason="corrupt dir, missing info.py")
      continue
    so_path = "%s/%s.so" % (full_dir_path, self.base_name)
    if not os.path.exists(so_path):
      self._cleanup_old_path(full_dir_path, reason="corrupt dir, missing so")
      continue
    dt = time.time() - os.path.getmtime(so_path)
    if dt > cleanup_time_limit_secs:
      self._cleanup_old_path(full_dir_path, reason="%s old" % hms(dt))
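# For reference, `hms` (imported from Util above and used throughout these functions)
# formats a duration in seconds as hours:minutes:seconds. A minimal standalone sketch,
# assuming that behavior; the real Util.hms may differ in details such as rounding:
def hms_sketch(seconds):
  """Format seconds as 'h:mm:ss', e.g. hms_sketch(3723.4) -> '1:02:03'."""
  seconds = int(seconds)
  hours, rest = divmod(seconds, 3600)
  minutes, secs = divmod(rest, 60)
  return "%i:%02i:%02i" % (hours, minutes, secs)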
def cb(frame_len, orth):
  if frame_len >= options.max_seq_frame_len:
    return
  orth_syms = parse_orthography(orth)
  if len(orth_syms) >= options.max_seq_orth_len:
    return
  Stats.count += 1
  Stats.total_frame_len += frame_len

  if options.dump_orth_syms:
    print >> log.v3, "Orth:", "".join(orth_syms)
  if options.filter_orth_sym:
    if options.filter_orth_sym in orth_syms:
      print >> log.v3, "Found orth:", "".join(orth_syms)
  if options.filter_orth_syms_seq:
    filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
    if found_sub_seq(filter_seq, orth_syms):
      print >> log.v3, "Found orth:", "".join(orth_syms)

  Stats.orth_syms_set.update(orth_syms)
  Stats.total_orth_len += len(orth_syms)

  # Show some progress if it takes long.
  if time.time() - Stats.process_last_time > 2:
    Stats.process_last_time = time.time()
    if options.collect_time:
      print >> log.v3, "Collect process, total frame len so far:", hms(
        Stats.total_frame_len * (options.frame_time / 1000.0))
    else:
      print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)
def cb(frame_len, orth):
  if frame_len >= options.max_seq_frame_len:
    return
  orth_syms = parse_orthography(orth)
  if len(orth_syms) >= options.max_seq_orth_len:
    return
  Stats.count += 1
  Stats.total_frame_len += frame_len

  if options.dump_orth_syms:
    print("Orth:", "".join(orth_syms), file=log.v3)
  if options.filter_orth_sym:
    if options.filter_orth_sym in orth_syms:
      print("Found orth:", "".join(orth_syms), file=log.v3)
  if options.filter_orth_syms_seq:
    filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
    if found_sub_seq(filter_seq, orth_syms):
      print("Found orth:", "".join(orth_syms), file=log.v3)

  Stats.orth_syms_set.update(orth_syms)
  Stats.total_orth_len += len(orth_syms)

  # Show some progress if it takes long.
  if time.time() - Stats.process_last_time > 2:
    Stats.process_last_time = time.time()
    if options.collect_time:
      print(
        "Collect process, total frame len so far:",
        hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
    else:
      print("Collect process, total orth len so far:", human_size(Stats.total_orth_len), file=log.v3)
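# The callbacks above use `found_sub_seq` to test whether the parsed filter sequence
# occurs as a contiguous sub-sequence of the orthography symbols. A minimal standalone
# sketch of that check, assuming plain list semantics; the actual Util.found_sub_seq
# may be implemented differently:
def found_sub_seq_sketch(sub_seq, seq):
  """Return True if sub_seq occurs as a contiguous sub-sequence of seq."""
  n = len(sub_seq)
  if n == 0:
    return True
  return any(seq[i:i + n] == sub_seq for i in range(len(seq) - n + 1))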
def analyze(self, data, statistics):
  """
  :param Dataset.Dataset data:
  :param list[str]|None statistics: ignored at the moment
  :return: nothing, will print everything to log.v1
  """
  print("Analyze with network on %r." % data, file=log.v1)

  if "analyze" not in self.network.layers:
    from TFNetworkLayer import FramewiseStatisticsLayer
    assert self.config.has("sil_label_idx")
    self.network.add_layer(
      name="analyze", layer_class=FramewiseStatisticsLayer,
      sil_label_idx=self.config.int("sil_label_idx", 0),
      sources=self.network.get_output_layers())

  # It's constructed lazily and it will set used_data_keys, so make sure that we have it now.
  self.network.get_all_errors()

  batch_size = self.config.int('batch_size', 1)
  max_seqs = self.config.int('max_seqs', -1)
  max_seq_length = self.config.float('max_seq_length', 0)
  if max_seq_length <= 0:
    max_seq_length = sys.maxsize

  batches = data.generate_batches(
    recurrent_net=self.network.recurrent,
    batch_size=batch_size,
    max_seqs=max_seqs,
    max_seq_length=max_seq_length,
    used_data_keys=self.network.used_data_keys)
  analyzer = Runner(engine=self, dataset=data, batches=batches, train=False)
  analyzer.run(report_prefix=self.get_epoch_str() + " analyze")

  print("Finished analyzing dataset %r." % data, file=log.v1)
  print("elapsed:", hms(analyzer.elapsed), file=log.v1)
  print("num mini-batches:", analyzer.num_steps, file=log.v1)
  print("total num_frames:", analyzer.data_provider.get_num_frames(), file=log.v1)
  print("score:", self.format_score(analyzer.score), file=log.v1)
  print("error:", self.format_score(analyzer.error), file=log.v1)
  for k, v in sorted(analyzer.stats.items()):
    if k.startswith("stats:"):
      print("%s:" % k, v, file=log.v1)
  print("Those are all the collected stats.", file=log.v1)

  if not analyzer.finalized:
    print("WARNING: Did not finish the whole epoch.", file=log.v1)
    sys.exit(1)
def _print_process(self, report_prefix, step, step_duration, eval_info):
  if not self._show_interactive_process_bar and not log.v[5]:
    return
  start_elapsed = time.time() - self.start_time
  complete = self.data_provider.batches.completed_frac()
  assert complete > 0
  total_time_estimated = start_elapsed / complete
  remaining_estimated = total_time_estimated - start_elapsed
  if log.verbose[5]:
    info = [report_prefix, "step %i" % step]
    if eval_info:  # Such as score.
      info += ["%s %s" % item for item in sorted(eval_info.items())]
    info += [
      "%.3f sec/step" % step_duration,
      "elapsed %s" % hms(start_elapsed),
      "exp. remaining %s" % hms(remaining_estimated),
      "complete %.02f%%" % (complete * 100)]
    print(", ".join(filter(None, info)), file=log.v5)
  elif self._show_interactive_process_bar:
    from Util import progress_bar
    progress_bar(complete, hms(remaining_estimated))
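# All of these progress printers use the same linear extrapolation: if a fraction
# `complete_frac` of the work took `start_elapsed` seconds, the total is estimated as
# start_elapsed / complete_frac and the remainder as the difference. A self-contained
# sketch of that estimate (names here are local assumptions, not the Util API):
import time

def estimate_remaining(start_time, complete_frac):
  """Linear ETA: elapsed time divided by completed fraction, minus elapsed time."""
  assert complete_frac > 0
  start_elapsed = time.time() - start_time
  total_time_estimated = start_elapsed / complete_frac
  return total_time_estimated - start_elapsed

# Example: after 30 seconds at 25% complete, roughly 90 seconds remain.
# remaining = estimate_remaining(start_time=time.time() - 30, complete_frac=0.25)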
def print_process(self):
  if not self.parent.interactive and not log.v[5]:
    return
  start_elapsed = time.time() - self.parent.start_time
  complete = self.parent.batches.completed_frac()
  assert complete > 0
  total_time_estimated = start_elapsed / complete
  remaining_estimated = total_time_estimated - start_elapsed
  if log.verbose[5]:
    mem_usage = self.device_mem_usage_str(self.alloc_devices)
    info = [
      self.parent.report_prefix,
      "batch %i" % self.run_start_batch_idx]
    if self.eval_info:  # Such as score.
      info += ["%s %s" % item for item in sorted(self.eval_info.items())]
    info += [
      "elapsed %s" % hms(start_elapsed),
      "exp. remaining %s" % hms(remaining_estimated),
      "complete %.02f%%" % (complete * 100)]
    if mem_usage:
      info += ["memory %s" % mem_usage]
    print >> log.v5, ", ".join(filter(None, info))
  if self.parent.interactive:
    progress_bar(complete, hms(remaining_estimated))
def test_DeviceBatchRun_outputs_format():
  # TODO: This is broken...
  return

  dev_run = DummyDeviceBatchRun(task="train")
  assert len(dev_run.alloc_devices) == 1

  # Simulate epoch start.
  trainer = dev_run.parent
  dev_run.alloc_devices[0].start_epoch_stats()
  trainer.initialize()

  # Simulate one batch.
  dev_run.allocate()
  dev_run.device_run()
  dev_run.set_dummy_dev_output(outputs_format=["cost:foo"], output=[1.42])
  dev_run.finish()

  assert_is_instance(dev_run.result, dict)
  assert_in("results", dev_run.result)
  res_outputss = dev_run.result["results"]
  assert_is_instance(res_outputss, list)
  assert_equal(len(res_outputss), len(dev_run.alloc_devices))
  res_outputs = res_outputss[0]
  assert_is_instance(res_outputs, list)
  res_outputs_format = dev_run.result["result_format"]
  assert_is_instance(res_outputs_format, list)
  res = Device.make_result_dict(res_outputs, res_outputs_format)
  assert_is_instance(res, dict)
  pprint(res)

  # Simulate epoch end.
  print("train epoch score:", trainer.score, "elapsed:", hms(trainer.elapsed))
  trainer.finalize()
  dev_run.alloc_devices[0].finish_epoch_stats()

  # Now simulate the eval.
  dev_run = DummyDeviceBatchRun(task="eval")
  assert len(dev_run.alloc_devices) == 1

  # Simulate epoch start.
  tester = dev_run.parent
  dev_run.alloc_devices[0].start_epoch_stats()
  tester.initialize()

  # Simulate one batch.
  dev_run.allocate()
  dev_run.device_run()
  dev_run.set_dummy_dev_output(outputs_format=["cost:foo", "error:foo"], output=[1.42, 2.34])
  dev_run.finish()

  # Simulate epoch end.
  print("eval epoch elapsed:", hms(tester.elapsed))
  tester.finalize()
  dev_run.alloc_devices[0].finish_epoch_stats()
  print("eval results:", tester.score, tester.error)

  assert_is_instance(dev_run.result, dict)
  assert_in("results", dev_run.result)
  res_outputss = dev_run.result["results"]
  assert_is_instance(res_outputss, list)
  assert_equal(len(res_outputss), len(dev_run.alloc_devices))
  res_outputs = res_outputss[0]
  assert_is_instance(res_outputs, list)
  res_outputs_format = dev_run.result["result_format"]
  assert_is_instance(res_outputs_format, list)
  res = Device.make_result_dict(res_outputs, res_outputs_format)
  assert_is_instance(res, dict)
  pprint(res)

  assert_greater(tester.score, 0)
  assert_greater(tester.error, 0)
def analyze_dataset(options):
  """
  :param options: argparse.Namespace
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  terminal_width, _ = Util.terminal_size()
  show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5]) and terminal_width >= 0)

  start_time = time.time()
  num_seqs_stats = Stats()
  if options.endseq < 0:
    options.endseq = float("inf")

  recurrent = True
  used_data_keys = dataset.get_data_keys()
  batch_size = config.typed_value('batch_size', 1)
  max_seqs = config.int('max_seqs', -1)
  seq_drop = config.float('seq_drop', 0.0)
  max_seq_length = config.typed_value('max_seq_length', None) or config.float('max_seq_length', 0)
  max_pad_size = config.typed_value("max_pad_size", None)

  batches = dataset.generate_batches(
    recurrent_net=recurrent,
    batch_size=batch_size,
    max_seqs=max_seqs,
    max_seq_length=max_seq_length,
    max_pad_size=max_pad_size,
    seq_drop=seq_drop,
    used_data_keys=used_data_keys)

  step = 0
  total_num_seqs = 0
  total_num_frames = NumbersDict()
  total_num_used_frames = NumbersDict()
  try:
    while batches.has_more():
      # See FeedDictDataProvider.
      batch, = batches.peek_next_n(1)
      assert isinstance(batch, Batch)
      if batch.start_seq > options.endseq:
        break
      dataset.load_seqs(batch.start_seq, batch.end_seq)
      complete_frac = batches.completed_frac()
      start_elapsed = time.time() - start_time
      try:
        num_seqs_s = str(dataset.num_seqs)
      except NotImplementedError:
        try:
          num_seqs_s = "~%i" % dataset.estimated_num_seqs
        except TypeError:  # a number is required, not NoneType
          num_seqs_s = "?"
      progress_prefix = "%i/%s" % (batch.start_seq, num_seqs_s)
      progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
      if complete_frac > 0:
        total_time_estimated = start_elapsed / complete_frac
        remaining_estimated = total_time_estimated - start_elapsed
        progress += " (%s)" % hms(remaining_estimated)
      batch_max_time = NumbersDict.max([seq.frame_length for seq in batch.seqs]) * len(batch.seqs)
      batch_num_used_frames = sum([seq.frame_length for seq in batch.seqs], NumbersDict())
      total_num_seqs += len(batch.seqs)
      num_seqs_stats.collect(numpy.array([len(batch.seqs)]))
      total_num_frames += batch_max_time
      total_num_used_frames += batch_num_used_frames
      print(
        "%s, batch %i, num seqs %i, frames %s, used %s (%s)" % (
          progress, step, len(batch.seqs),
          batch_max_time, batch_num_used_frames, batch_num_used_frames / batch_max_time),
        file=log.v5)
      if show_interactive_process_bar:
        Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
      step += 1
      batches.advance(1)
  finally:
    print("Done. Total time %s. More seqs which we did not dump: %s" % (
      hms(time.time() - start_time), batches.has_more()), file=log.v2)
    print("Dataset epoch %i, order %r." % (dataset.epoch, dataset.seq_ordering))
    print("Num batches (steps): %i" % step, file=log.v1)
    print("Num seqs: %i" % total_num_seqs, file=log.v1)
    num_seqs_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
    for key in used_data_keys:
      print("Data key %r:" % key, file=log.v1)
      print("  Num frames: %s" % total_num_frames[key], file=log.v1)
      print("  Num used frames: %s" % total_num_used_frames[key], file=log.v1)
      print("  Fraction used frames: %s" % (total_num_used_frames / total_num_frames)[key], file=log.v1)
    dataset.finish_epoch()
def work(self):
  print("Starting hyper param search. Using %i threads." % self.num_threads, file=log.v1)
  from TFUtil import get_available_gpu_devices
  from Log import wrap_log_streams, StreamDummy
  from threading import Thread, Condition
  from Util import progress_bar, hms, is_tty

  class Outstanding:
    cond = Condition()
    threads = []  # type: list[WorkerThread]
    population = []
    exit = False
    exception = None

  class WorkerThread(Thread):
    def __init__(self, gpu_ids):
      """
      :param set[int] gpu_ids:
      """
      super(WorkerThread, self).__init__(name="Hyper param tune train thread")
      self.gpu_ids = gpu_ids
      self.trainer = None  # type: _IndividualTrainer
      self.finished = False
      self.start()

    def cancel(self, join=False):
      with Outstanding.cond:
        if self.trainer:
          self.trainer.cancel_flag = True
          if self.trainer.runner:
            self.trainer.runner.cancel_flag = True
      if join:
        self.join()

    def get_complete_frac(self):
      with Outstanding.cond:
        if self.trainer and self.trainer.runner:
          return self.trainer.runner.data_provider.get_complete_frac()
      return 0.0

    def run(self_thread):
      try:
        while True:
          with Outstanding.cond:
            if Outstanding.exit or Outstanding.exception:
              return
            if not Outstanding.population:
              self_thread.finished = True
              Outstanding.cond.notify_all()
              return
            individual = Outstanding.population.pop(0)
            self_thread.trainer = _IndividualTrainer(optim=self, individual=individual, gpu_ids=self_thread.gpu_ids)
          self_thread.name = "Hyper param tune train thread on %r" % individual.name
          self_thread.trainer.run()
      except Exception as exc:
        with Outstanding.cond:
          if not Outstanding.exception:
            Outstanding.exception = exc or True
          Outstanding.cond.notify_all()
        for thread in Outstanding.threads:
          if thread is not self_thread:
            thread.cancel()
        if not isinstance(exc, CancelTrainingException):
          with Outstanding.cond:  # So that we don't mix up multiple on sys.stderr.
            # This would normally dump it on sys.stderr so it's fine.
            sys.excepthook(*sys.exc_info())

  best_individuals = []
  population = []
  canceled = False
  num_gpus = len(get_available_gpu_devices())
  print("Num available GPUs:", num_gpus)
  num_gpus = num_gpus or 1  # Would be ignored anyway.
  interactive = is_tty()
  try:
    print(
      "Population of %i individuals (hyper param setting instances), running for %i evaluation iterations." % (
        self.num_individuals, self.num_iterations), file=log.v2)
    for cur_iteration_idx in range(1, self.num_iterations + 1):
      print("Starting iteration %i." % cur_iteration_idx, file=log.v2)
      if cur_iteration_idx == 1:
        population.append(Individual(
          {p: p.get_default_value() for p in self.hyper_params}, name="default"))
        population.append(Individual(
          {p: p.get_initial_value() for p in self.hyper_params}, name="canonical"))
      population.extend(self.get_population(
        iteration_idx=cur_iteration_idx, num_individuals=self.num_individuals - len(population)))
      if cur_iteration_idx > 1:
        self.cross_over(population=population, iteration_idx=cur_iteration_idx)
      if cur_iteration_idx == 1 and self.dry_run_first_individual:
        # Train first directly for testing and to see log output.
        # Later we will strip away all log output.
        print("Very first try with log output:", file=log.v2)
        _IndividualTrainer(optim=self, individual=population[0], gpu_ids={0}).run()

      print("Starting training with thread pool of %i threads." % self.num_threads)
      iteration_start_time = time.time()
      with wrap_log_streams(StreamDummy(), also_sys_stdout=True, tf_log_verbosity="WARN"):
        Outstanding.exit = False
        Outstanding.population = list(population)
        Outstanding.threads = [WorkerThread(gpu_ids={i % num_gpus}) for i in range(self.num_threads)]
        try:
          while True:
            with Outstanding.cond:
              if all([thread.finished for thread in Outstanding.threads]) or Outstanding.exception:
                break
              complete_frac = max(len(population) - len(Outstanding.population) - len(Outstanding.threads), 0)
              complete_frac += sum([thread.get_complete_frac() for thread in Outstanding.threads])
              complete_frac /= float(len(population))
              remaining_str = ""
              if complete_frac > 0:
                start_elapsed = time.time() - iteration_start_time
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                remaining_str = hms(remaining_estimated)
              if interactive:
                progress_bar(complete_frac, prefix=remaining_str, file=sys.__stdout__)
              else:
                print(
                  "Progress: %.02f%%" % (complete_frac * 100),
                  "remaining:", remaining_str or "unknown",
                  file=sys.__stdout__)
                sys.__stdout__.flush()
              Outstanding.cond.wait(1 if interactive else 10)
          for thread in Outstanding.threads:
            thread.join()
        finally:
          Outstanding.exit = True
          for thread in Outstanding.threads:
            thread.cancel(join=True)
      Outstanding.threads = []
      print("Training iteration elapsed time:", hms(time.time() - iteration_start_time))
      if Outstanding.exception:
        raise Outstanding.exception
      assert not Outstanding.population
      print("Training iteration finished.")
      population.sort(key=lambda p: p.cost)
      del population[-self.num_kill_individuals:]
      best_individuals.extend(population)
      best_individuals.sort(key=lambda p: p.cost)
      del best_individuals[self.num_best:]
      population = best_individuals[:self.num_kill_individuals // 4] + population
      print(
        "Current best setting, individual %s" % best_individuals[0].name,
        "cost:", best_individuals[0].cost)
      for p in self.hyper_params:
        print("  %s -> %s" % (p.description(), best_individuals[0].hyper_param_mapping[p]))
  except KeyboardInterrupt:
    print("KeyboardInterrupt, canceled search.")
    canceled = True

  print("Best %i settings:" % len(best_individuals))
  for individual in best_individuals:
    print("Individual %s" % individual.name, "cost:", individual.cost)
    for p in self.hyper_params:
      print("  %s -> %s" % (p.description(), individual.hyper_param_mapping[p]))
def collect_stats(options, iter_corpus):
  """
  :param options: argparse.Namespace
  """
  orth_symbols_filename = options.output
  if orth_symbols_filename:
    assert not os.path.exists(orth_symbols_filename)

  class Stats:
    count = 0
    process_last_time = time.time()
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

  if options.add_numbers:
    Stats.orth_syms_set.update(map(chr, list(range(ord("0"), ord("9") + 1))))
  if options.add_lower_alphabet:
    Stats.orth_syms_set.update(map(chr, list(range(ord("a"), ord("z") + 1))))
  if options.add_upper_alphabet:
    Stats.orth_syms_set.update(map(chr, list(range(ord("A"), ord("Z") + 1))))

  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return
    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print("Orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print("Found orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print("Found orth:", "".join(orth_syms), file=log.v3)

    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print(
          "Collect process, total frame len so far:",
          hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
      else:
        print("Collect process, total orth len so far:", human_size(Stats.total_orth_len), file=log.v3)

  iter_corpus(cb)

  if options.remove_symbols:
    filter_syms = parse_orthography_into_symbols(options.remove_symbols)
    Stats.orth_syms_set -= set(filter_syms)

  if options.collect_time:
    print(
      "Total frame len:", Stats.total_frame_len, "time:",
      hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
  else:
    print("No time stats (--collect_time False).", file=log.v3)
  print(
    "Total orth len:", Stats.total_orth_len, "(%s)" % human_size(Stats.total_orth_len),
    end=' ', file=log.v3)
  if options.collect_time:
    print("fraction:", float(Stats.total_orth_len) / Stats.total_frame_len, file=log.v3)
  else:
    print("", file=log.v3)
  print("Average orth len:", float(Stats.total_orth_len) / Stats.count, file=log.v3)
  print("Num symbols:", len(Stats.orth_syms_set), file=log.v3)

  if orth_symbols_filename:
    orth_syms_file = open(orth_symbols_filename, "wb")
    for orth_sym in sorted(Stats.orth_syms_set):
      orth_syms_file.write(("%s\n" % orth_sym).encode("utf8"))
    orth_syms_file.close()
    print("Wrote orthography symbols to", orth_symbols_filename, file=log.v3)
  else:
    print("Provide --output to save the symbols.", file=log.v3)
def test_DeviceBatchRun_outputs_format():
  dev_run = DummyDeviceBatchRun(task="train")
  assert len(dev_run.alloc_devices) == 1

  # Simulate epoch start.
  trainer = dev_run.parent
  dev_run.alloc_devices[0].start_epoch_stats()
  trainer.initialize()

  # Simulate one batch.
  dev_run.allocate()
  dev_run.device_run()
  dev_run.set_dummy_dev_output(outputs_format=["cost:foo"], output=[1.42])
  dev_run.finish()

  assert_is_instance(dev_run.result, dict)
  assert_in("results", dev_run.result)
  res_outputss = dev_run.result["results"]
  assert_is_instance(res_outputss, list)
  assert_equal(len(res_outputss), len(dev_run.alloc_devices))
  res_outputs = res_outputss[0]
  assert_is_instance(res_outputs, list)
  res_outputs_format = dev_run.result["result_format"]
  assert_is_instance(res_outputs_format, list)
  res = Device.make_result_dict(res_outputs, res_outputs_format)
  assert_is_instance(res, dict)
  pprint(res)

  # Simulate epoch end.
  print "train epoch score:", trainer.score, "elapsed:", hms(trainer.elapsed)
  trainer.finalize()
  dev_run.alloc_devices[0].finish_epoch_stats()

  # Now simulate the eval.
  dev_run = DummyDeviceBatchRun(task="eval")
  assert len(dev_run.alloc_devices) == 1

  # Simulate epoch start.
  tester = dev_run.parent
  dev_run.alloc_devices[0].start_epoch_stats()
  tester.initialize()

  # Simulate one batch.
  dev_run.allocate()
  dev_run.device_run()
  dev_run.set_dummy_dev_output(outputs_format=["cost:foo", "error:foo"], output=[1.42, 2.34])
  dev_run.finish()

  # Simulate epoch end.
  print "eval epoch elapsed:", hms(tester.elapsed)
  tester.finalize()
  dev_run.alloc_devices[0].finish_epoch_stats()
  print "eval results:", tester.score, tester.error

  assert_is_instance(dev_run.result, dict)
  assert_in("results", dev_run.result)
  res_outputss = dev_run.result["results"]
  assert_is_instance(res_outputss, list)
  assert_equal(len(res_outputss), len(dev_run.alloc_devices))
  res_outputs = res_outputss[0]
  assert_is_instance(res_outputs, list)
  res_outputs_format = dev_run.result["result_format"]
  assert_is_instance(res_outputs_format, list)
  res = Device.make_result_dict(res_outputs, res_outputs_format)
  assert_is_instance(res, dict)
  pprint(res)

  assert_greater(tester.score, 0)
  assert_greater(tester.error, 0)
def collect_stats(options, iter_corpus):
  """
  :param options: argparse.Namespace
  """
  orth_symbols_filename = options.output
  if orth_symbols_filename:
    assert not os.path.exists(orth_symbols_filename)

  class Stats:
    count = 0
    process_last_time = time.time()
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

  if options.add_numbers:
    Stats.orth_syms_set.update(map(chr, range(ord("0"), ord("9") + 1)))
  if options.add_lower_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("a"), ord("z") + 1)))
  if options.add_upper_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("A"), ord("Z") + 1)))

  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return
    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print >> log.v3, "Orth:", "".join(orth_syms)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print >> log.v3, "Found orth:", "".join(orth_syms)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print >> log.v3, "Found orth:", "".join(orth_syms)

    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print >> log.v3, "Collect process, total frame len so far:", hms(
          Stats.total_frame_len * (options.frame_time / 1000.0))
      else:
        print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)

  iter_corpus(cb)

  if options.remove_symbols:
    filter_syms = parse_orthography_into_symbols(options.remove_symbols)
    Stats.orth_syms_set -= set(filter_syms)

  if options.collect_time:
    print >> log.v3, "Total frame len:", Stats.total_frame_len, "time:", hms(
      Stats.total_frame_len * (options.frame_time / 1000.0))
  else:
    print >> log.v3, "No time stats (--collect_time False)."
  print >> log.v3, "Total orth len:", Stats.total_orth_len, "(%s)" % human_size(Stats.total_orth_len),
  if options.collect_time:
    print >> log.v3, "fraction:", float(Stats.total_orth_len) / Stats.total_frame_len
  else:
    print >> log.v3, ""
  print >> log.v3, "Average orth len:", float(Stats.total_orth_len) / Stats.count
  print >> log.v3, "Num symbols:", len(Stats.orth_syms_set)

  if orth_symbols_filename:
    orth_syms_file = open(orth_symbols_filename, "wb")
    for orth_sym in sorted(Stats.orth_syms_set):
      orth_syms_file.write("%s\n" % unicode(orth_sym).encode("utf8"))
    orth_syms_file.close()
    print >> log.v3, "Wrote orthography symbols to", orth_symbols_filename
  else:
    print >> log.v3, "Provide --output to save the symbols."
def dump_dataset(dataset, options):
  """
  :type dataset: Dataset.Dataset
  :param options: argparse.Namespace
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  dataset.init_seq_order(epoch=options.epoch)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  if options.get_num_seqs:
    print("Get num seqs.")
    print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
    try:
      print("num_seqs: %r" % dataset.num_seqs)
    except Exception as exc:
      print("num_seqs exception %r, which is valid, so we count." % exc)
      seq_idx = 0
      if dataset.get_target_list():
        default_target = dataset.get_target_list()[0]
      else:
        default_target = None
      while dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        if seq_idx % 10000 == 0:
          if default_target:
            targets = dataset.get_targets(default_target, seq_idx)
            postfix = " (targets = %r...)" % (targets[:10],)
          else:
            postfix = ""
          print("%i ...%s" % (seq_idx, postfix))
        seq_idx += 1
      print("accumulated num seqs: %i" % seq_idx)
    print("Done.")
    return

  dump_file = None
  if options.type == "numpy":
    print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
  elif options.type == "stdout":
    print("Dump to stdout", file=log.v3)
    if options.stdout_limit is not None:
      Util.set_pretty_print_default_limit(options.stdout_limit)
      numpy.set_printoptions(
        threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
    if options.stdout_as_bytes:
      Util.set_pretty_print_as_bytes(options.stdout_as_bytes)
  elif options.type == "print_tag":
    print("Dump seq tag to stdout", file=log.v3)
  elif options.type == "dump_tag":
    dump_file = open("%sseq-tags.txt" % options.dump_prefix, "w")
    print("Dump seq tag to file: %s" % (dump_file.name,), file=log.v3)
  elif options.type == "dump_seq_len":
    dump_file = open("%sseq-lens.txt" % options.dump_prefix, "w")
    print("Dump seq lens to file: %s" % (dump_file.name,), file=log.v3)
    dump_file.write("{\n")
  elif options.type == "print_shape":
    print("Dump shape to stdout", file=log.v3)
  elif options.type == "plot":
    print("Plot.", file=log.v3)
  elif options.type == "interactive":
    print("Interactive debug shell.", file=log.v3)
  elif options.type == "null":
    print("No dump.")
  else:
    raise Exception("unknown dump option type %r" % options.type)

  start_time = time.time()
  stats = Stats() if (options.stats or options.dump_stats) else None
  seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    start_elapsed = time.time() - start_time
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
    if complete_frac > 0:
      total_time_estimated = start_elapsed / complete_frac
      remaining_estimated = total_time_estimated - start_elapsed
      progress += " (%s)" % hms(remaining_estimated)
    if options.type == "print_tag":
      print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
    elif options.type == "dump_tag":
      print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
      dump_file.write("%s\n" % dataset.get_tag(seq_idx))
    elif options.type == "dump_seq_len":
      seq_len = dataset.get_seq_length(seq_idx)[options.key]
      print(
        "seq %s tag:" % (progress if log.verbose[2] else progress_prefix),
        dataset.get_tag(seq_idx), "%r len:" % options.key, seq_len)
      dump_file.write("%r: %r,\n" % (dataset.get_tag(seq_idx), seq_len))
    else:
      data = dataset.get_data(seq_idx, options.key)
      if options.type == "numpy":
        numpy.savetxt("%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
      elif options.type == "stdout":
        print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
        print("seq %s data:" % progress, pretty_print(data))
      elif options.type == "print_shape":
        print("seq %s data shape:" % progress, data.shape)
      elif options.type == "plot":
        plot(data)
      for target in dataset.get_target_list():
        targets = dataset.get_targets(target, seq_idx)
        if options.type == "numpy":
          numpy.savetxt(
            "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
            targets, fmt='%i')
        elif options.type == "stdout":
          extra = ""
          if target in dataset.labels and len(dataset.labels[target]) > 1:
            assert dataset.can_serialize_data(target)
            extra += " (%r)" % dataset.serialize_data(key=target, data=targets)
          print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
        elif options.type == "print_shape":
          print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
      if options.type == "interactive":
        from Debug import debug_shell
        debug_shell(locals())
    seq_len = dataset.get_seq_length(seq_idx)
    for key in dataset.get_data_keys():
      seq_len_stats[key].collect([seq_len[key]])
    if stats:
      stats.collect(data)
    if options.type == "null":
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    seq_idx += 1

  print("Done. Total time %s. More seqs which we did not dump: %s" % (
    hms_fraction(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v2)
  for key in dataset.get_data_keys():
    seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
  if stats:
    stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)
  if options.type == "dump_seq_len":
    dump_file.write("}\n")
  if dump_file:
    print("Dumped to file:", dump_file.name, file=log.v2)
    dump_file.close()
def train_epoch(self):
  print("start", self.get_epoch_str(), "with learning rate", self.learning_rate, "...", file=log.v4)

  if self.epoch == 1 and self.save_epoch1_initial_model:
    epoch0_model_filename = self.epoch_model_filename(self.model_filename, 0, self.is_pretrain_epoch())
    print("save initial epoch1 model", epoch0_model_filename, file=log.v4)
    self.save_model(epoch0_model_filename)

  if 'train' not in self.dataset_batches or not self.train_data.batch_set_generator_cache_whole_epoch():
    self.dataset_batches['train'] = self.train_data.generate_batches(
      recurrent_net=self.network.recurrent,
      batch_size=self.batch_size,
      max_seqs=self.max_seqs,
      max_seq_length=int(self.max_seq_length),
      seq_drop=self.seq_drop,
      shuffle_batches=self.shuffle_batches,
      used_data_keys=self.network.used_data_keys)
  else:
    self.dataset_batches['train'].reset()
  train_batches = self.dataset_batches['train']

  self.updater.set_learning_rate(self.learning_rate)
  trainer = Runner(engine=self, dataset=self.train_data, batches=train_batches, train=True)
  trainer.run(report_prefix=("pre" if self.is_pretrain_epoch() else "") + "train epoch %s" % self.epoch)

  if not trainer.finalized:
    if trainer.device_crash_batch is not None:  # Otherwise we got an unexpected exception - a bug in our code.
      self.save_model(self.get_epoch_model_filename() + ".crash_%i" % trainer.device_crash_batch)
    sys.exit(1)

  assert not any(numpy.isinf(list(trainer.score.values()))) and not any(numpy.isnan(list(trainer.score.values()))), \
    "Model is broken, got inf or nan final score: %s" % trainer.score

  if self.model_filename and (self.epoch % self.save_model_epoch_interval == 0):
    self.save_model(self.get_epoch_model_filename())
  self.learning_rate_control.setEpochError(self.epoch, {"train_score": trainer.score})
  self.learning_rate_control.save()

  print(
    self.get_epoch_str(), "score:", self.format_score(trainer.score),
    "elapsed:", hms(trainer.elapsed), end=" ", file=log.v1)
  self.eval_model()
def calc_wer_on_dataset(dataset, refs, options, hyps):
  """
  :param Dataset|None dataset:
  :param dict[str,str]|None refs: seq tag -> ref string (words delimited by space)
  :param options: argparse.Namespace
  :param dict[str,str] hyps: seq tag -> hyp string (words delimited by space)
  :return: WER
  :rtype: float
  """
  assert dataset or refs
  start_time = time.time()
  seq_len_stats = {"refs": Stats(), "hyps": Stats()}
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")
  wer = 1.0
  remaining_hyp_seq_tags = set(hyps.keys())
  interactive = Util.is_tty() and not log.verbose[5]
  collected = {"hyps": [], "refs": []}
  max_num_collected = 1
  if dataset:
    dataset.init_seq_order(epoch=1)
  else:
    refs = sorted(refs.items(), key=lambda item: len(item[1]))
  while True:
    if seq_idx > options.endseq:
      break
    if dataset:
      if not dataset.is_less_than_num_seqs(seq_idx):
        break
      dataset.load_seqs(seq_idx, seq_idx + 1)
      complete_frac = dataset.get_complete_frac(seq_idx)
      seq_tag = dataset.get_tag(seq_idx)
      assert isinstance(seq_tag, str)
      ref = dataset.get_data(seq_idx, options.key)
      if isinstance(ref, numpy.ndarray):
        assert ref.shape == ()
        ref = ref.flatten()[0]  # get the entry itself (str or bytes)
      if isinstance(ref, bytes):
        ref = ref.decode("utf8")
      assert isinstance(ref, str)
      try:
        num_seqs_s = str(dataset.num_seqs)
      except NotImplementedError:
        try:
          num_seqs_s = "~%i" % dataset.estimated_num_seqs
        except TypeError:  # a number is required, not NoneType
          num_seqs_s = "?"
    else:
      if seq_idx >= len(refs):
        break
      complete_frac = (seq_idx + 1) / float(len(refs))
      seq_tag, ref = refs[seq_idx]
      assert isinstance(seq_tag, str)
      assert isinstance(ref, str)
      num_seqs_s = str(len(refs))

    start_elapsed = time.time() - start_time
    progress_prefix = "%i/%s (WER %.02f%%)" % (seq_idx, num_seqs_s, wer * 100)
    progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
    if complete_frac > 0:
      total_time_estimated = start_elapsed / complete_frac
      remaining_estimated = total_time_estimated - start_elapsed
      progress += " (%s)" % hms(remaining_estimated)

    remaining_hyp_seq_tags.remove(seq_tag)
    hyp = hyps[seq_tag]
    seq_len_stats["hyps"].collect([len(hyp)])
    seq_len_stats["refs"].collect([len(ref)])
    collected["hyps"].append(hyp)
    collected["refs"].append(ref)
    if len(collected["hyps"]) >= max_num_collected:
      wer = wer_compute.step(session, **collected)
      del collected["hyps"][:]
      del collected["refs"][:]

    if interactive:
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    elif log.verbose[5]:
      print(progress_prefix, "seq tag %r, ref/hyp len %i/%i chars" % (seq_tag, len(ref), len(hyp)))
    seq_idx += 1

  if len(collected["hyps"]) > 0:
    wer = wer_compute.step(session, **collected)

  print("Done. Num seqs %i. Total time %s." % (seq_idx, hms(time.time() - start_time)), file=log.v1)
  print("Remaining num hyp seqs %i." % (len(remaining_hyp_seq_tags),), file=log.v1)
  if dataset:
    print("More seqs which we did not dump: %s." % dataset.is_less_than_num_seqs(seq_idx), file=log.v1)
  for key in ["hyps", "refs"]:
    seq_len_stats[key].dump(stream_prefix="Seq-length %r %r " % (key, options.key), stream=log.v2)
  if options.expect_full:
    assert not remaining_hyp_seq_tags, "There are still remaining hypotheses."
  return wer
def train_epoch(self):
  print >> log.v4, "start", self.get_epoch_str(), "with learning rate", self.learning_rate, "..."

  if self.epoch == 1 and self.save_epoch1_initial_model:
    epoch0_model_filename = self.epoch_model_filename(self.model_filename, 0, self.is_pretrain_epoch())
    print >> log.v4, "save initial epoch1 model", epoch0_model_filename
    self.save_model(epoch0_model_filename, 0)

  if self.is_pretrain_epoch():
    self.print_network_info()

  training_devices = self.devices
  if 'train' not in self.dataset_batches or not self.train_data.batch_set_generator_cache_whole_epoch():
    self.dataset_batches['train'] = self.train_data.generate_batches(
      recurrent_net=self.network.recurrent,
      batch_size=self.batch_size,
      max_seqs=self.max_seqs,
      max_seq_length=int(self.max_seq_length),
      seq_drop=self.seq_drop,
      shuffle_batches=self.shuffle_batches,
      used_data_keys=self.network.get_used_data_keys())
  else:
    self.dataset_batches['train'].reset()
  train_batches = self.dataset_batches['train']
  start_batch = self.start_batch if self.epoch == self.start_epoch else 0

  trainer = TrainTaskThread(
    self.network, training_devices, data=self.train_data, batches=train_batches,
    learning_rate=self.learning_rate, updater=self.updater,
    eval_batch_size=self.update_batch_size,
    start_batch=start_batch, share_batches=self.share_batches,
    exclude=self.exclude,
    seq_train_parallel=self.seq_train_parallel,
    report_prefix=("pre" if self.is_pretrain_epoch() else "") + "train epoch %s" % self.epoch,
    epoch=self.epoch)
  trainer.join()

  if not trainer.finalized:
    if trainer.device_crash_batch is not None:  # Otherwise we got an unexpected exception - a bug in our code.
      self.save_model(self.get_epoch_model_filename() + ".crash_%i" % trainer.device_crash_batch, self.epoch - 1)
    sys.exit(1)

  assert not any(numpy.isinf(trainer.score.values())) and not any(numpy.isnan(trainer.score.values())), \
    "Model is broken, got inf or nan final score: %s" % trainer.score

  if self.model_filename and (self.epoch % self.save_model_epoch_interval == 0):
    self.save_model(self.get_epoch_model_filename(), self.epoch)
  self.learning_rate_control.setEpochError(self.epoch, {"train_score": trainer.score})
  self.learning_rate_control.save()

  if self.ctc_prior_file is not None:
    trainer.save_ctc_priors(self.ctc_prior_file, self.get_epoch_str())

  print >> log.v1, self.get_epoch_str(), "score:", self.format_score(trainer.score), \
    "elapsed:", hms(trainer.elapsed),
  self.eval_model()
def dump_dataset(dataset, options):
  """
  :type dataset: Dataset.Dataset
  :param options: argparse.Namespace
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  dataset.init_seq_order(epoch=options.epoch)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  if options.get_num_seqs:
    print("Get num seqs.")
    print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
    try:
      print("num_seqs: %r" % dataset.num_seqs)
    except Exception as exc:
      print("num_seqs exception %r, which is valid, so we count." % exc)
      seq_idx = 0
      if dataset.get_target_list():
        default_target = dataset.get_target_list()[0]
      else:
        default_target = None
      while dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        if seq_idx % 10000 == 0:
          if default_target:
            targets = dataset.get_targets(default_target, seq_idx)
            postfix = " (targets = %r...)" % (targets[:10],)
          else:
            postfix = ""
          print("%i ...%s" % (seq_idx, postfix))
        seq_idx += 1
      print("accumulated num seqs: %i" % seq_idx)
    print("Done.")
    return

  if options.type == "numpy":
    print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
  elif options.type == "stdout":
    print("Dump to stdout", file=log.v3)
  elif options.type == "print_shape":
    print("Dump shape to stdout", file=log.v3)
  elif options.type == "plot":
    print("Plot.", file=log.v3)
  elif options.type == "null":
    print("No dump.")
  else:
    raise Exception("unknown dump option type %r" % options.type)

  start_time = time.time()
  stats = Stats() if (options.stats or options.dump_stats) else None
  seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    start_elapsed = time.time() - start_time
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
    if complete_frac > 0:
      total_time_estimated = start_elapsed / complete_frac
      remaining_estimated = total_time_estimated - start_elapsed
      progress += " (%s)" % hms(remaining_estimated)
    data = dataset.get_data(seq_idx, options.key)
    if options.type == "numpy":
      numpy.savetxt("%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
    elif options.type == "stdout":
      print("seq %s data:" % progress, pretty_print(data))
    elif options.type == "print_shape":
      print("seq %s data shape:" % progress, data.shape)
    elif options.type == "plot":
      plot(data)
    for target in dataset.get_target_list():
      targets = dataset.get_targets(target, seq_idx)
      if options.type == "numpy":
        numpy.savetxt(
          "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
          targets, fmt='%i')
      elif options.type == "stdout":
        print("seq %i target %r:" % (seq_idx, target), pretty_print(targets))
      elif options.type == "print_shape":
        print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
    seq_len = dataset.get_seq_length(seq_idx)
    for key in dataset.get_data_keys():
      seq_len_stats[key].collect([seq_len[key]])
    if stats:
      stats.collect(data)
    if options.type == "null":
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    seq_idx += 1

  print("Done. Total time %s. More seqs which we did not dump: %s" % (
    hms(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v1)
  for key in dataset.get_data_keys():
    seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
  if stats:
    stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v2)
def train_epoch(self):
  print >> log.v4, "start", self.get_epoch_str(), "with learning rate", self.learning_rate, "..."

  if self.epoch == 1 and self.save_epoch1_initial_model:
    epoch0_model_filename = self.epoch_model_filename(self.model_filename, 0, self.is_pretrain_epoch())
    print >> log.v4, "save initial epoch1 model", epoch0_model_filename
    self.save_model(epoch0_model_filename, 0)

  if self.is_pretrain_epoch():
    self.print_network_info()

  training_devices = self.devices
  if 'train' not in self.dataset_batches:
    self.dataset_batches['train'] = self.train_data.generate_batches(
      recurrent_net=self.network.recurrent,
      batch_size=self.batch_size,
      max_seqs=self.max_seqs,
      max_seq_length=int(self.max_seq_length),
      batch_variance=self.batch_variance,
      shuffle_batches=self.shuffle_batches)
  else:
    self.dataset_batches['train'].reset()
  train_batches = self.dataset_batches['train']
  start_batch = self.start_batch if self.epoch == self.start_epoch else 0

  trainer = TrainTaskThread(
    self.network, training_devices, data=self.train_data, batches=train_batches,
    learning_rate=self.learning_rate, updater=self.updater,
    eval_batch_size=self.update_batch_size,
    start_batch=start_batch, share_batches=self.share_batches,
    exclude=self.exclude,
    report_prefix=("pre" if self.is_pretrain_epoch() else "") + "train epoch %s" % self.epoch,
    epoch=self.epoch)
  trainer.join()

  if not trainer.finalized:
    if trainer.device_crash_batch is not None:  # Otherwise we got an unexpected exception - a bug in our code.
      self.save_model(self.get_epoch_model_filename() + ".crash_%i" % trainer.device_crash_batch, self.epoch - 1)
    sys.exit(1)

  assert not any(numpy.isinf(trainer.score.values())) and not any(numpy.isnan(trainer.score.values())), \
    "Model is broken, got inf or nan final score: %s" % trainer.score

  if self.model_filename and (self.epoch % self.save_model_epoch_interval == 0):
    self.save_model(self.get_epoch_model_filename(), self.epoch)
  self.learning_rate_control.setEpochError(self.epoch, {"train_score": trainer.score})
  self.learning_rate_control.save()

  if self.ctc_prior_file is not None:
    trainer.save_ctc_priors(self.ctc_prior_file, self.get_epoch_str())

  print >> log.v1, self.get_epoch_str(), "score:", self.format_score(trainer.score), \
    "elapsed:", hms(trainer.elapsed),
  self.eval_model()
  print >> log.v1, ""
def dump_dataset(dataset, options):
  """
  :type dataset: Dataset.Dataset
  :param options: argparse.Namespace
  """
  print("Epoch: %i" % options.epoch, file=log.v3)
  dataset.init_seq_order(epoch=options.epoch)
  print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
  print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
  assert options.key in dataset.get_data_keys()

  if options.get_num_seqs:
    print("Get num seqs.")
    print("estimated_num_seqs: %r" % dataset.estimated_num_seqs)
    try:
      print("num_seqs: %r" % dataset.num_seqs)
    except Exception as exc:
      print("num_seqs exception %r, which is valid, so we count." % exc)
      seq_idx = 0
      if dataset.get_target_list():
        default_target = dataset.get_target_list()[0]
      else:
        default_target = None
      while dataset.is_less_than_num_seqs(seq_idx):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        if seq_idx % 10000 == 0:
          if default_target:
            targets = dataset.get_targets(default_target, seq_idx)
            postfix = " (targets = %r...)" % (targets[:10],)
          else:
            postfix = ""
          print("%i ...%s" % (seq_idx, postfix))
        seq_idx += 1
      print("accumulated num seqs: %i" % seq_idx)
    print("Done.")
    return

  if options.type == "numpy":
    print("Dump files: %r*%r" % (options.dump_prefix, options.dump_postfix), file=log.v3)
  elif options.type == "stdout":
    print("Dump to stdout", file=log.v3)
    if options.stdout_limit is not None:
      Util.set_pretty_print_default_limit(options.stdout_limit)
      numpy.set_printoptions(
        threshold=sys.maxsize if options.stdout_limit == float("inf") else int(options.stdout_limit))
    if options.stdout_as_bytes:
      Util.set_pretty_print_as_bytes(options.stdout_as_bytes)
  elif options.type == "print_tag":
    print("Dump seq tag to stdout", file=log.v3)
  elif options.type == "print_shape":
    print("Dump shape to stdout", file=log.v3)
  elif options.type == "plot":
    print("Plot.", file=log.v3)
  elif options.type == "null":
    print("No dump.")
  else:
    raise Exception("unknown dump option type %r" % options.type)

  start_time = time.time()
  stats = Stats() if (options.stats or options.dump_stats) else None
  seq_len_stats = {key: Stats() for key in dataset.get_data_keys()}
  seq_idx = options.startseq
  if options.endseq < 0:
    options.endseq = float("inf")
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= options.endseq:
    dataset.load_seqs(seq_idx, seq_idx + 1)
    complete_frac = dataset.get_complete_frac(seq_idx)
    start_elapsed = time.time() - start_time
    try:
      num_seqs_s = str(dataset.num_seqs)
    except NotImplementedError:
      try:
        num_seqs_s = "~%i" % dataset.estimated_num_seqs
      except TypeError:  # a number is required, not NoneType
        num_seqs_s = "?"
    progress_prefix = "%i/%s" % (seq_idx, num_seqs_s)
    progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
    if complete_frac > 0:
      total_time_estimated = start_elapsed / complete_frac
      remaining_estimated = total_time_estimated - start_elapsed
      progress += " (%s)" % hms(remaining_estimated)
    if options.type == "print_tag":
      print("seq %s tag:" % (progress if log.verbose[2] else progress_prefix), dataset.get_tag(seq_idx))
    else:
      data = dataset.get_data(seq_idx, options.key)
      if options.type == "numpy":
        numpy.savetxt("%s%i.data%s" % (options.dump_prefix, seq_idx, options.dump_postfix), data)
      elif options.type == "stdout":
        print("seq %s tag:" % progress, dataset.get_tag(seq_idx))
        print("seq %s data:" % progress, pretty_print(data))
      elif options.type == "print_shape":
        print("seq %s data shape:" % progress, data.shape)
      elif options.type == "plot":
        plot(data)
      for target in dataset.get_target_list():
        targets = dataset.get_targets(target, seq_idx)
        if options.type == "numpy":
          numpy.savetxt(
            "%s%i.targets.%s%s" % (options.dump_prefix, seq_idx, target, options.dump_postfix),
            targets, fmt='%i')
        elif options.type == "stdout":
          extra = ""
          if target in dataset.labels and len(dataset.labels[target]) > 1:
            labels = dataset.labels[target]
            if len(labels) < 1000 and all([len(l) == 1 for l in labels]):
              join_str = ""
            else:
              join_str = " "
            extra += " (%r)" % join_str.join(map(dataset.labels[target].__getitem__, targets))
          print("seq %i target %r: %s%s" % (seq_idx, target, pretty_print(targets), extra))
        elif options.type == "print_shape":
          print("seq %i target %r shape:" % (seq_idx, target), targets.shape)
    seq_len = dataset.get_seq_length(seq_idx)
    for key in dataset.get_data_keys():
      seq_len_stats[key].collect([seq_len[key]])
    if stats:
      stats.collect(data)
    if options.type == "null":
      Util.progress_bar_with_time(complete_frac, prefix=progress_prefix)
    seq_idx += 1

  print("Done. Total time %s. More seqs which we did not dump: %s" % (
    hms(time.time() - start_time), dataset.is_less_than_num_seqs(seq_idx)), file=log.v2)
  for key in dataset.get_data_keys():
    seq_len_stats[key].dump(stream_prefix="Seq-length %r " % key, stream=log.v2)
  if stats:
    stats.dump(output_file_prefix=options.dump_stats, stream_prefix="Data %r " % options.key, stream=log.v1)