def __call__(self, trainer):
    """Trainer extension: evaluate the loss on the attached dataset.

    Reports the loss under ``self.observation_name`` and, whenever it
    improves on the best value seen so far, optionally saves the model to
    ``self.save_best_model_to`` together with a companion ``.config``
    description file.
    """
    model = trainer.updater.get_optimizer("main").target
    log.info("computing %s" % self.observation_name)
    current_loss = compute_loss_all(model, self.data, self.eos_idx, self.mb_size,
                                    gpu=self.gpu,
                                    reverse_src=self.reverse_src,
                                    reverse_tgt=self.reverse_tgt,
                                    use_chainerx=self.use_chainerx)
    log.info("%s: %f (current best: %r)" % (self.observation_name, current_loss, self.best_loss))
    chainer.reporter.report({self.observation_name: current_loss})

    # Guard clauses: leave as soon as there is nothing more to do.
    if self.best_loss is not None and current_loss >= self.best_loss:
        return
    log.info("loss (%s) improvement: %r -> %r" % (self.observation_name, self.best_loss, current_loss))
    self.best_loss = current_loss

    if self.save_best_model_to is None:
        return
    log.info("saving best loss (%s) model to %s" % (self.observation_name, self.save_best_model_to,))
    serializers.save_npz(self.save_best_model_to, model)

    if self.config_training is None:
        return
    # Describe the freshly saved model in a sibling ``.config`` file.
    session = self.config_training.copy(readonly=False)
    session.add_section("model_parameters", keep_at_bottom="metadata")
    session["model_parameters"]["filename"] = self.save_best_model_to
    session["model_parameters"]["type"] = "model"
    session["model_parameters"]["description"] = "best_loss"
    session["model_parameters"]["infos"] = argument_parsing_tools.OrderedNamespace()
    session["model_parameters"]["infos"]["loss"] = float(current_loss)
    session["model_parameters"]["infos"]["iteration"] = trainer.updater.iteration
    session.set_metadata_modified_time()
    session.save_to(self.save_best_model_to + ".config")
def __call__(self, trainer):
    """Trainer extension: translate the attached source data and score BLEU.

    Reports the BLEU score (and its detailed ``repr``) under
    ``self.observation_name``; when the score beats the best one so far,
    optionally saves the model to ``self.save_best_model_to`` together
    with a companion ``.config`` description file.
    """
    model = trainer.updater.get_optimizer("main").target
    bleu_stats = translate_to_file(model, self.eos_idx, self.src_data, self.mb_size,
                                   self.tgt_indexer, self.translations_fn,
                                   test_references=self.references,
                                   control_src_fn=self.control_src_fn,
                                   src_indexer=self.src_indexer,
                                   gpu=self.gpu,
                                   nb_steps=50,
                                   reverse_src=self.reverse_src,
                                   reverse_tgt=self.reverse_tgt,
                                   s_unk_tag=self.s_unk_tag,
                                   t_unk_tag=self.t_unk_tag)

    bleu = bleu_stats.bleu()
    observations = {self.observation_name: bleu,
                    self.observation_name + "_details": repr(bleu)}
    chainer.reporter.report(observations)

    # No improvement: just log and stop.
    if self.best_bleu is not None and self.best_bleu >= bleu:
        log.info("no bleu (%s) improvement: %f >= %f" %
                 (self.observation_name, self.best_bleu, bleu))
        return

    log.info("%s improvement: %r -> %r" % (self.observation_name, self.best_bleu, bleu))
    self.best_bleu = bleu

    if self.save_best_model_to is None:
        return
    log.info("saving best bleu (%s) model to %s" % (self.observation_name, self.save_best_model_to,))
    serializers.save_npz(self.save_best_model_to, model)

    if self.config_training is None:
        return
    # Describe the freshly saved model in a sibling ``.config`` file.
    session = self.config_training.copy(readonly=False)
    session.add_section("model_parameters", keep_at_bottom="metadata")
    session["model_parameters"]["filename"] = self.save_best_model_to
    session["model_parameters"]["type"] = "model"
    session["model_parameters"]["description"] = "best_bleu"
    session["model_parameters"]["infos"] = argument_parsing_tools.OrderedNamespace()
    session["model_parameters"]["infos"]["bleu_stats"] = str(bleu_stats)
    session["model_parameters"]["infos"]["iteration"] = trainer.updater.iteration
    session.set_metadata_modified_time()
    session.save_to(self.save_best_model_to + ".config")
def load_config_train(filename, readonly=True, no_error=False):
    """Load a training configuration file, upgrading legacy formats.

    Files without a "metadata" section (pre-1.0) are rebuilt into the
    current layout by re-parsing their recorded command line; files whose
    ``config_version_num`` is not 1.0 are rejected.  Intermediate-format
    files that still carry top-level "data_prefix"/"train_prefix" keys
    have those moved under "training_management".

    Returns the (optionally read-only) OrderedNamespace configuration.
    Raises ValueError when the config version is unsupported.
    """
    config = argument_parsing_tools.OrderedNamespace.load_from(filename)

    if "metadata" not in config:  # older config file
        # Legacy file: re-parse the stored command line into the modern layout.
        orderer = get_parse_option_orderer()
        upgraded = orderer.convert_args_to_ordered_dict(
            config["command_line"], args_is_namespace=False)
        convert_cell_string(upgraded, no_error=no_error)

        assert "data" not in upgraded
        upgraded["data"] = argument_parsing_tools.OrderedNamespace()
        upgraded["data"]["data_fn"] = config["data"]
        for voc_key in ("Vi", "Vo", "voc"):
            upgraded["data"][voc_key] = config[voc_key]

        assert "metadata" not in upgraded
        upgraded["metadata"] = argument_parsing_tools.OrderedNamespace()
        upgraded["metadata"]["config_version_num"] = 0.9
        upgraded["metadata"]["command_line"] = None
        upgraded["metadata"]["knmt_version"] = None
        config = upgraded
    elif config["metadata"]["config_version_num"] != 1.0:
        raise ValueError(
            "The config version of %s is not supported by this version of the program" % filename)

    # Compatibility with intermediate versions of the config file format:
    # hoist leftover top-level prefixes into "training_management".
    for legacy_key in ("data_prefix", "train_prefix"):
        if legacy_key in config and legacy_key not in config["training_management"]:
            config["training_management"][legacy_key] = config[legacy_key]
            del config[legacy_key]

    if readonly:
        config.set_readonly()
    return config
def __call__(self, trainer):
    """Trainer extension: checkpoint the full trainer state.

    Writes the trainer snapshot to ``self.save_to`` plus a companion
    ``.config`` file describing the snapshot.
    """
    log.info("Saving current trainer state to file %s" % self.save_to)
    serializers.save_npz(self.save_to, trainer)

    # Describe the snapshot in a sibling ``.config`` file.
    session = self.config_training.copy(readonly=False)
    session.add_section("model_parameters", keep_at_bottom="metadata")
    for key, value in (("filename", self.save_to),
                       ("type", "snapshot"),
                       ("description", "checkpoint"),
                       ("infos", argument_parsing_tools.OrderedNamespace())):
        session["model_parameters"][key] = value
    session["model_parameters"]["infos"]["iteration"] = trainer.updater.iteration
    session.set_metadata_modified_time()
    session.save_to(self.save_to + ".config")
    log.info("Saved trainer snapshot to file %s" % self.save_to)
def load_config_eval(filename, readonly=True):
    """Load an evaluation configuration file, upgrading legacy formats.

    Files without a "metadata" section (pre-1.0) are re-parsed into the
    modern layout and tagged as version 0.9; files whose
    ``config_version_num`` is not 1.0 are rejected.

    Returns the (optionally read-only) OrderedNamespace configuration.
    Raises ValueError when the config version is unsupported.
    """
    config = argument_parsing_tools.OrderedNamespace.load_from(filename)

    if "metadata" not in config:  # older config file
        orderer = get_parse_option_orderer()
        upgraded = orderer.convert_args_to_ordered_dict(config, args_is_namespace=False)

        assert "metadata" not in upgraded
        upgraded["metadata"] = argument_parsing_tools.OrderedNamespace()
        upgraded["metadata"]["config_version_num"] = 0.9
        upgraded["metadata"]["command_line"] = None
        upgraded["metadata"]["knmt_version"] = None
        config = upgraded
    elif config["metadata"]["config_version_num"] != 1.0:
        raise ValueError(
            "The config version of %s is not supported by this version of the program" % filename)

    if readonly:
        config.set_readonly()
    return config
def train_on_data_chainer(
        encdec,
        optimizer,
        training_data,
        output_files_dict,
        src_indexer,
        tgt_indexer,
        eos_idx,
        config_training,
        stop_trigger=None,
        test_data=None,
        dev_data=None,
        valid_data=None,
):
    """Build and run a chainer Trainer over `training_data`.

    Wires up the training iterator (length-based or dynamic batching), the
    loss function (feed-forward or recurrent branch depending on
    ``encdec.encdec_type()``), periodic dev/test loss and BLEU evaluation
    extensions, sampling, checkpointing and sqlite logging, then calls
    ``trainer.run()``.  On any exception during training, a final trainer
    snapshot plus its ``.config`` description is saved (unless
    ``no_report_or_save``) before re-raising.

    NOTE(review): ``valid_data`` is accepted but never used in this body.
    """
    # Unpack the configuration into locals for readability below.
    # NOTE(review): several of these (randomized, do_not_save_data_for_resuming,
    # curiculum_training, use_memory_optimization, use_reinf) are read here but
    # never used later in this function.
    output_dir = config_training.training_management.save_prefix
    mb_size = config_training.training.mb_size
    nb_of_batch_to_sort = config_training.training.nb_batch_to_sort
    gpu = config_training.training_management.gpu
    report_every = config_training.training_management.report_every
    randomized = config_training.training.randomized_data
    reverse_src = config_training.training.reverse_src
    reverse_tgt = config_training.training.reverse_tgt
    do_not_save_data_for_resuming = config_training.training_management.no_resume
    noise_on_prev_word = config_training.training.noise_on_prev_word
    curiculum_training = config_training.training.curiculum_training
    use_previous_prediction = config_training.training.use_previous_prediction
    no_report_or_save = config_training.training_management.no_report_or_save
    use_memory_optimization = config_training.training_management.use_memory_optimization
    sample_every = config_training.training_management.sample_every
    use_reinf = config_training.training.use_reinf
    save_ckpt_every = config_training.training_management.save_ckpt_every
    trainer_snapshot = config_training.training_management.load_trainer_snapshot
    save_initial_model_to = config_training.training_management.save_initial_model_to
    reshuffle_every_epoch = config_training.training_management.reshuffle_every_epoch
    use_soft_prediction_feedback = config_training.training.use_soft_prediction_feedback
    use_gumbel_for_soft_predictions = config_training.training.use_gumbel_for_soft_predictions
    temperature_for_soft_predictions = config_training.training.temperature_for_soft_predictions
    generate_computation_graph = config_training.training_management.generate_computation_graph
    # Dynamic batching settings use .get() so that older configs (without
    # these keys) keep working.
    dynamic_batching = config_training.training.get("dynamic_batching", False)
    dynamic_batching_max_elems = config_training.training.get(
        "dynamic_batching_max_elems", 10000)
    dynamic_batching_nb_sent_to_sort = config_training.training.get(
        "dynamic_batching_nb_sent_to_sort", 5000)

    @chainer.training.make_extension()
    def sample_extension(trainer):
        # Periodically translate a peeked minibatch and log sample outputs.
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        mb_raw = iterator.peek()

        def s_unk_tag(num, utag):
            # Tag for unknown source-side words in sample output.
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            # Tag for unknown target-side words in sample output.
            return "T_UNK_%i" % utag

        try:
            if encdec.encdec_type() == "ff":
                src_seqs, tgt_seqs = zip(*mb_raw)
                sample_once_ff(encdec, src_seqs, tgt_seqs, src_indexer, tgt_indexer,
                               max_nb=20, s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
            else:
                src_batch, tgt_batch, src_mask = make_batch_src_tgt(
                    mb_raw, eos_idx=eos_idx, padding_idx=0, gpu=gpu,
                    need_arg_sort=False)
                sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer,
                            tgt_indexer, eos_idx, max_nb=20,
                            s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)
        except CudaException:
            # Sampling is best-effort only; never kill training over it.
            # NOTE(review): log.warn is deprecated in favor of log.warning.
            log.warn("CUDARuntimeError during sample. Skipping sample")

    if dynamic_batching:
        # NOTE(review): "dynamic matching" in this message is likely a typo
        # for "dynamic batching" (string kept unchanged here).
        log.info("using dynamic matching with %i %i", dynamic_batching_max_elems,
                 dynamic_batching_nb_sent_to_sort)
        iterator_training_data = DynamicLengthBasedSerialIterator(
            training_data,
            max_nb_elements=dynamic_batching_max_elems,
            nb_sent_sort=dynamic_batching_nb_sent_to_sort,
            sort_key=lambda x: len(x[1]),
            repeat=True,
            shuffle=reshuffle_every_epoch)
    else:
        iterator_training_data = LengthBasedSerialIterator(
            training_data,
            mb_size,
            nb_of_batch_to_sort=nb_of_batch_to_sort,
            sort_key=lambda x: len(x[0]),
            repeat=True,
            shuffle=reshuffle_every_epoch)

    # One-element list used as a mutable flag shared with the loss_func
    # closures: dump the computation graph on the first call only.
    generate_loss_computation_graph_on_first_call = [
        generate_computation_graph is not None
    ]

    if encdec.encdec_type() == "ff":
        def loss_func(src_seq, tgt_seq):
            # Feed-forward branch: per-sequence loss, averaged over the
            # number of predictions (+1 per sequence, presumably for EOS —
            # TODO confirm).
            # NOTE(review): time.clock() was removed in Python 3.8; consider
            # time.perf_counter() when migrating.
            t0 = time.clock()
            loss = encdec.compute_loss(src_seq, tgt_seq, reduce="no")
            total_loss = F.sum(loss)
            total_nb_predictions = sum(len(seq) + 1 for seq in tgt_seq)
            avg_loss = total_loss / total_nb_predictions
            t1 = time.clock()
            chainer.reporter.report({"forward_time": t1 - t0})
            chainer.reporter.report({"mb_loss": total_loss.data})
            chainer.reporter.report(
                {"mb_nb_predictions": total_nb_predictions})
            chainer.reporter.report({"trg_loss": avg_loss.data})
            log.info("batch infos: %i x [%i | %i]", len(src_seq),
                     max(len(s) for s in src_seq),
                     max(len(s) for s in tgt_seq))
            if generate_loss_computation_graph_on_first_call[0]:
                log.info("Writing loss computation graph to %s",
                         generate_computation_graph)
                import chainer.computational_graph as c
                g = c.build_computational_graph([avg_loss])
                with open(generate_computation_graph, 'w') as o:
                    o.write(g.dump())
                generate_loss_computation_graph_on_first_call[0] = False
            return avg_loss

        def convert_mb(mb_raw, device):
            # ff models consume raw (src, tgt) sequence tuples directly.
            return tuple(zip(*mb_raw))
    else:
        def loss_func(src_batch, tgt_batch, src_mask):
            # Recurrent branch: the model itself returns
            # (total_loss, nb_predictions) plus attention.
            # NOTE(review): time.clock() was removed in Python 3.8; consider
            # time.perf_counter() when migrating.
            t0 = time.clock()
            (total_loss, total_nb_predictions), attn = encdec(
                src_batch, tgt_batch, src_mask, raw_loss_info=True,
                noise_on_prev_word=noise_on_prev_word,
                use_previous_prediction=use_previous_prediction,
                use_soft_prediction_feedback=use_soft_prediction_feedback,
                use_gumbel_for_soft_predictions=use_gumbel_for_soft_predictions,
                temperature_for_soft_predictions=temperature_for_soft_predictions)
            avg_loss = total_loss / total_nb_predictions
            t1 = time.clock()
            chainer.reporter.report({"forward_time": t1 - t0})
            chainer.reporter.report({"mb_loss": total_loss.data})
            chainer.reporter.report(
                {"mb_nb_predictions": total_nb_predictions})
            chainer.reporter.report({"trg_loss": avg_loss.data})
            log.info("batch infos: %i x [%i | %i]",
                     src_batch[0].data.shape[0],
                     len(src_batch), len(tgt_batch))
            if generate_loss_computation_graph_on_first_call[0]:
                log.info("Writing loss computation graph to %s",
                         generate_computation_graph)
                import chainer.computational_graph as c
                g = c.build_computational_graph([avg_loss])
                with open(generate_computation_graph, 'w') as o:
                    o.write(g.dump())
                generate_loss_computation_graph_on_first_call[0] = False
            return avg_loss

        def convert_mb(mb_raw, device):
            # Pad/batch raw pairs into model-ready arrays on `device`.
            return make_batch_src_tgt(mb_raw, eos_idx=eos_idx, padding_idx=0,
                                      gpu=device, need_arg_sort=False)

    updater = Updater(iterator_training_data, optimizer,
                      converter=convert_mb,
                      device=gpu,
                      loss_func=loss_func,
                      need_to_convert_to_variables=False)

    trainer = chainer.training.Trainer(updater, stop_trigger, out=output_dir)

    # Dev-set loss/BLEU evaluation; each also tracks and saves the best model.
    if dev_data is not None and not no_report_or_save:
        dev_loss_extension = ComputeLossExtension(
            dev_data, eos_idx, mb_size, gpu, reverse_src, reverse_tgt,
            save_best_model_to=output_files_dict["model_best_loss"],
            observation_name="dev_loss", config_training=config_training)
        trainer.extend(dev_loss_extension, trigger=(report_every, "iteration"))
        dev_bleu_extension = ComputeBleuExtension(
            dev_data, eos_idx, src_indexer, tgt_indexer,
            output_files_dict["dev_translation_output"],
            output_files_dict["dev_src_output"],
            mb_size, gpu, reverse_src, reverse_tgt,
            save_best_model_to=output_files_dict["model_best"],
            observation_name="dev_bleu", config_training=config_training)
        trainer.extend(dev_bleu_extension, trigger=(report_every, "iteration"))

    # Test-set evaluation: reported only, no best-model saving.
    if test_data is not None and not no_report_or_save:
        test_loss_extension = ComputeLossExtension(
            test_data, eos_idx, mb_size, gpu, reverse_src, reverse_tgt,
            observation_name="test_loss")
        trainer.extend(test_loss_extension, trigger=(report_every, "iteration"))
        test_bleu_extension = ComputeBleuExtension(
            test_data, eos_idx, src_indexer, tgt_indexer,
            output_files_dict["test_translation_output"],
            output_files_dict["test_src_output"],
            mb_size, gpu, reverse_src, reverse_tgt,
            observation_name="test_bleu")
        trainer.extend(test_bleu_extension, trigger=(report_every, "iteration"))

    if not no_report_or_save:
        trainer.extend(sample_extension, trigger=(sample_every, "iteration"))
        trainer.extend(CheckpontSavingExtension(
            output_files_dict["model_ckpt"], config_training),
            trigger=(save_ckpt_every, "iteration"))
        trainer.extend(
            SqliteLogExtension(db_path=output_files_dict["sqlite_db"]))
        trainer.extend(
            TrainingLossSummaryExtension(trigger=(report_every, "iteration")))

    # Resume from an in-config snapshot reference, if requested.
    if config_training.training_management.resume:
        if "model_parameters" not in config_training:
            log.error("cannot find model parameters in config file")
            raise ValueError(
                "Config file do not contain model_parameters section")
        if config_training.model_parameters.type == "snapshot":
            model_filename = config_training.model_parameters.filename
            log.info("resuming from trainer parameters %s" % model_filename)
            serializers.load_npz(model_filename, trainer)

    # Explicit trainer snapshot file takes effect after (and on top of) resume.
    if trainer_snapshot is not None:
        log.info("loading trainer parameters from %s" % trainer_snapshot)
        serializers.load_npz(trainer_snapshot, trainer)

    try:
        if save_initial_model_to is not None:
            log.info("Saving initial parameters to %s" % save_initial_model_to)
            encdec = trainer.updater.get_optimizer("main").target
            serializers.save_npz(save_initial_model_to, encdec)
        trainer.run()
    except BaseException:
        # Best-effort rescue: save a final snapshot + config, then re-raise
        # (BaseException so KeyboardInterrupt also triggers the save).
        if not no_report_or_save:
            final_snapshot_fn = output_files_dict["model_final"]
            log.info(
                "Exception met. Trying to save current trainer state to file %s"
                % final_snapshot_fn)
            serializers.save_npz(final_snapshot_fn, trainer)
            config_session = config_training.copy(readonly=False)
            config_session.add_section("model_parameters",
                                       keep_at_bottom="metadata")
            config_session["model_parameters"]["filename"] = final_snapshot_fn
            config_session["model_parameters"]["type"] = "snapshot"
            config_session["model_parameters"]["description"] = "final"
            config_session["model_parameters"][
                "infos"] = argument_parsing_tools.OrderedNamespace()
            config_session["model_parameters"]["infos"][
                "iteration"] = trainer.updater.iteration
            config_session.set_metadata_modified_time()
            config_session.save_to(final_snapshot_fn + ".config")
            log.info("Saved trainer snapshot to file %s" % final_snapshot_fn)
        raise
def train_on_data_chainer(
        encdec,
        optimizer,
        training_data,
        output_files_dict,
        src_indexer,
        tgt_indexer,
        eos_idx,
        config_training,
        stop_trigger=None,
        test_data=None,
        dev_data=None,
        valid_data=None,
):
    """Build and run a chainer Trainer over `training_data` (older variant).

    NOTE(review): this looks like an older revision of the
    ``train_on_data_chainer`` defined earlier in this source (it still
    passes the Chainer-v1 ``volatile=`` argument and has no dynamic
    batching / ff-model / computation-graph support).  If both definitions
    live in the same module the later one silently overrides the former —
    confirm which revision is intended to survive.

    NOTE(review): ``valid_data`` is accepted but never used in this body.
    """
    # Unpack the configuration into locals for readability below.
    # NOTE(review): several of these (randomized, do_not_save_data_for_resuming,
    # curiculum_training, use_memory_optimization, use_reinf) are read here but
    # never used later in this function.
    output_dir = config_training.training_management.save_prefix
    mb_size = config_training.training.mb_size
    nb_of_batch_to_sort = config_training.training.nb_batch_to_sort
    gpu = config_training.training_management.gpu
    report_every = config_training.training_management.report_every
    randomized = config_training.training.randomized_data
    reverse_src = config_training.training.reverse_src
    reverse_tgt = config_training.training.reverse_tgt
    do_not_save_data_for_resuming = config_training.training_management.no_resume
    noise_on_prev_word = config_training.training.noise_on_prev_word
    curiculum_training = config_training.training.curiculum_training
    use_previous_prediction = config_training.training.use_previous_prediction
    no_report_or_save = config_training.training_management.no_report_or_save
    use_memory_optimization = config_training.training_management.use_memory_optimization
    sample_every = config_training.training_management.sample_every
    use_reinf = config_training.training.use_reinf
    save_ckpt_every = config_training.training_management.save_ckpt_every
    trainer_snapshot = config_training.training_management.load_trainer_snapshot
    save_initial_model_to = config_training.training_management.save_initial_model_to
    reshuffle_every_epoch = config_training.training_management.reshuffle_every_epoch
    use_soft_prediction_feedback = config_training.training.use_soft_prediction_feedback
    use_gumbel_for_soft_predictions = config_training.training.use_gumbel_for_soft_predictions
    temperature_for_soft_predictions = config_training.training.temperature_for_soft_predictions

    @chainer.training.make_extension()
    def sample_extension(trainer):
        # Periodically translate a peeked minibatch and log sample outputs.
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        mb_raw = iterator.peek()
        # NOTE(review): volatile= is a Chainer v1 API, removed in v2+.
        src_batch, tgt_batch, src_mask = make_batch_src_tgt(
            mb_raw, eos_idx=eos_idx, padding_idx=0, gpu=gpu, volatile="on",
            need_arg_sort=False)

        def s_unk_tag(num, utag):
            # Tag for unknown source-side words in sample output.
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            # Tag for unknown target-side words in sample output.
            return "T_UNK_%i" % utag

        sample_once(encdec, src_batch, tgt_batch, src_mask, src_indexer,
                    tgt_indexer, eos_idx, max_nb=20,
                    s_unk_tag=s_unk_tag, t_unk_tag=t_unk_tag)

    iterator_training_data = LengthBasedSerialIterator(
        training_data,
        mb_size,
        nb_of_batch_to_sort=nb_of_batch_to_sort,
        sort_key=lambda x: len(x[0]),
        repeat=True,
        shuffle=reshuffle_every_epoch)

    def loss_func(src_batch, tgt_batch, src_mask):
        # The model returns (total_loss, nb_predictions) plus attention.
        # NOTE(review): time.clock() was removed in Python 3.8; consider
        # time.perf_counter() when migrating.
        t0 = time.clock()
        (total_loss, total_nb_predictions), attn = encdec(
            src_batch, tgt_batch, src_mask, raw_loss_info=True,
            noise_on_prev_word=noise_on_prev_word,
            use_previous_prediction=use_previous_prediction,
            mode="train",
            use_soft_prediction_feedback=use_soft_prediction_feedback,
            use_gumbel_for_soft_predictions=use_gumbel_for_soft_predictions,
            temperature_for_soft_predictions=temperature_for_soft_predictions)
        avg_loss = total_loss / total_nb_predictions
        t1 = time.clock()
        chainer.reporter.report({"forward_time": t1 - t0})
        chainer.reporter.report({"mb_loss": total_loss.data})
        chainer.reporter.report({"mb_nb_predictions": total_nb_predictions})
        chainer.reporter.report({"trg_loss": avg_loss.data})
        return avg_loss

    def convert_mb(mb_raw, device):
        # Pad/batch raw pairs into model-ready arrays on `device`.
        return make_batch_src_tgt(mb_raw, eos_idx=eos_idx, padding_idx=0,
                                  gpu=device, volatile="off",
                                  need_arg_sort=False)

    updater = Updater(iterator_training_data, optimizer,
                      converter=convert_mb,
                      device=gpu,
                      loss_func=loss_func,
                      need_to_convert_to_variables=False)

    trainer = chainer.training.Trainer(updater, stop_trigger, out=output_dir)

    # Dev-set loss/BLEU evaluation; each also tracks and saves the best model.
    if dev_data is not None and not no_report_or_save:
        dev_loss_extension = ComputeLossExtension(
            dev_data, eos_idx, mb_size, gpu, reverse_src, reverse_tgt,
            save_best_model_to=output_files_dict["model_best_loss"],
            observation_name="dev_loss", config_training=config_training)
        trainer.extend(dev_loss_extension, trigger=(report_every, "iteration"))
        dev_bleu_extension = ComputeBleuExtension(
            dev_data, eos_idx, src_indexer, tgt_indexer,
            output_files_dict["dev_translation_output"],
            output_files_dict["dev_src_output"],
            mb_size, gpu, reverse_src, reverse_tgt,
            save_best_model_to=output_files_dict["model_best"],
            observation_name="dev_bleu", config_training=config_training)
        trainer.extend(dev_bleu_extension, trigger=(report_every, "iteration"))

    # Test-set evaluation: reported only, no best-model saving.
    if test_data is not None and not no_report_or_save:
        test_loss_extension = ComputeLossExtension(
            test_data, eos_idx, mb_size, gpu, reverse_src, reverse_tgt,
            observation_name="test_loss")
        trainer.extend(test_loss_extension, trigger=(report_every, "iteration"))
        test_bleu_extension = ComputeBleuExtension(
            test_data, eos_idx, src_indexer, tgt_indexer,
            output_files_dict["test_translation_output"],
            output_files_dict["test_src_output"],
            mb_size, gpu, reverse_src, reverse_tgt,
            observation_name="test_bleu")
        trainer.extend(test_bleu_extension, trigger=(report_every, "iteration"))

    if not no_report_or_save:
        trainer.extend(sample_extension, trigger=(sample_every, "iteration"))
        trainer.extend(CheckpontSavingExtension(
            output_files_dict["model_ckpt"], config_training),
            trigger=(save_ckpt_every, "iteration"))
        trainer.extend(
            SqliteLogExtension(db_path=output_files_dict["sqlite_db"]))
        trainer.extend(
            TrainingLossSummaryExtension(trigger=(report_every, "iteration")))

    # Resume from an in-config snapshot reference, if requested.
    if config_training.training_management.resume:
        if "model_parameters" not in config_training:
            log.error("cannot find model parameters in config file")
            raise ValueError(
                "Config file do not contain model_parameters section")
        if config_training.model_parameters.type == "snapshot":
            model_filename = config_training.model_parameters.filename
            log.info("resuming from trainer parameters %s" % model_filename)
            serializers.load_npz(model_filename, trainer)

    # Explicit trainer snapshot file takes effect after (and on top of) resume.
    if trainer_snapshot is not None:
        log.info("loading trainer parameters from %s" % trainer_snapshot)
        serializers.load_npz(trainer_snapshot, trainer)

    try:
        if save_initial_model_to is not None:
            log.info("Saving initial parameters to %s" % save_initial_model_to)
            encdec = trainer.updater.get_optimizer("main").target
            serializers.save_npz(save_initial_model_to, encdec)
        trainer.run()
    except BaseException:
        # Best-effort rescue: save a final snapshot + config, then re-raise
        # (BaseException so KeyboardInterrupt also triggers the save).
        if not no_report_or_save:
            final_snapshot_fn = output_files_dict["model_final"]
            log.info(
                "Exception met. Trying to save current trainer state to file %s"
                % final_snapshot_fn)
            serializers.save_npz(final_snapshot_fn, trainer)
            config_session = config_training.copy(readonly=False)
            config_session.add_section("model_parameters",
                                       keep_at_bottom="metadata")
            config_session["model_parameters"]["filename"] = final_snapshot_fn
            config_session["model_parameters"]["type"] = "snapshot"
            config_session["model_parameters"]["description"] = "final"
            config_session["model_parameters"][
                "infos"] = argument_parsing_tools.OrderedNamespace()
            config_session["model_parameters"]["infos"][
                "iteration"] = trainer.updater.iteration
            config_session.set_metadata_modified_time()
            config_session.save_to(final_snapshot_fn + ".config")
            log.info("Saved trainer snapshot to file %s" % final_snapshot_fn)
        raise