def eval(self) -> 'metrics.EvalScore':
  """
  Perform evaluation task.

  Returns:
    Evaluated score
  """
  event_trigger.set_train(False)
  if self.src_data is None:
    self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
      input_readers.read_parallel_corpus(src_reader=self.model.src_reader, trg_reader=self.model.trg_reader,
                                         src_file=self.src_file, trg_file=self.ref_file,
                                         batcher=self.batcher, max_num_sents=self.max_num_sents,
                                         max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
  loss_val = losses.FactoredLossVal()
  ref_words_cnt = 0
  for src, trg in zip(self.src_batches, self.ref_batches):
    with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      loss = self.loss_calculator.calc_loss(self.model, src, trg)
      ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
      loss_val += loss.get_factored_loss_val(comb_method=self.loss_comb_method)
  loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}
  return metrics.LossScore(sum(loss_stats.values()),
                           loss_stats=loss_stats,
                           num_ref_words=ref_words_cnt,
                           desc=self.desc)
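# The normalization step above divides each summed loss factor by the total reference
# word count. A minimal standalone sketch of that computation (not part of xnmt; the
# function name and example numbers are hypothetical):
def normalize_factored_loss(loss_sums: dict, ref_words_cnt: int) -> dict:
  # Divide each accumulated loss component by the number of reference words,
  # turning corpus-level sums into comparable per-word figures.
  return {name: total / ref_words_cnt for name, total in loss_sums.items()}

# Example: an 'mle' loss summing to 5400.0 over 1200 reference words reports 4.5 per word.
assert normalize_factored_loss({'mle': 5400.0}, 1200) == {'mle': 4.5}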
def run_training(self, save_fct: Callable) -> None:
  """
  Main training loop (overwrites TrainingRegimen.run_training())
  """
  if self.run_for_epochs is None or self.run_for_epochs > 0:
    for src, trg in self.next_minibatch():
      if self.dev_zero:
        self.checkpoint_and_save(save_fct)
        self.dev_zero = False
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
        with self.train_loss_tracker.time_tracker:
          event_trigger.set_train(True)
          loss_builder = self.training_step(src, trg)
          loss = loss_builder.compute()
          self.backward(loss, self.dynet_profiling)
          self.update(self.trainer)
        self.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
      if self.checkpoint_needed():
        self.checkpoint_and_save(save_fct)
      if self.should_stop_training():
        break
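# utils.ReportOnException wraps each step so that debugging artifacts are dumped when
# the wrapped block raises. A rough sketch of such a context manager, assuming it prints
# each entry (invoking callables such as utils.print_cg_conditional) before re-raising;
# the real xnmt implementation may differ:
class ReportOnExceptionSketch:
  def __init__(self, artifacts: dict):
    self.artifacts = artifacts

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    if exc_type is not None:
      for name, value in self.artifacts.items():
        # Callables are evaluated lazily, only on failure.
        print(name, value() if callable(value) else value)
    return False  # propagate the original exception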
def _generate_one_batch(self, generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.
  """
  batch_size = len(src_batch)
  src_batches = batcher.pack(src_batch, None)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      outputs = self.generate_one(generator, src_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
        fp.write(f"{output_txt}\n")
def _forced_decode_one_batch(self, generator: 'models.GeneratorModel',
                             batcher: Optional[batchers.Batcher] = None,
                             src_batch: batchers.Batch = None,
                             ref_batch: batchers.Batch = None,
                             assert_scores: batchers.Batch = None,
                             max_src_len: Optional[int] = None):
  """
  Performs forced decoding for a single batch.
  """
  batch_size = len(src_batch)
  src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is None or src_len <= max_src_len:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      tt.reset_graph()
      outputs = self.generate_one(generator, src_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        if assert_scores is not None:
          # If debugging forced decoding, make sure it matches
          assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
          if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
            raise ValueError(
              f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
              f'sentence {i}')
def _generate_one_batch(self, generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.
  """
  batch_size = len(src_batch)
  src_batches = batcher.pack(src_batch, None)
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      tt.reset_graph()
      with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
        outputs = self.generate_one(generator, src_batch)
        if self.reporter:
          self._create_sent_report()
        for i in range(len(outputs)):
          output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
          fp.write(f"{output_txt}\n")
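# The `torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr()` idiom above
# selects a context manager at runtime: gradient tracking is disabled under the PyTorch
# backend, while the DyNet backend gets a no-op. A minimal sketch of such a no-op
# fallback, assuming contextlib semantics (the real utils.dummy_context_mgr may differ):
import contextlib

@contextlib.contextmanager
def dummy_context_mgr():
  # Does nothing on enter or exit, so the same `with` statement works for both backends.
  yield None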
def run_training(self, save_fct: Callable) -> None:
  """
  Main training loop (overwrites TrainingRegimen.run_training())
  """
  dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
  if self.run_for_epochs is None or self.run_for_epochs > 0:
    total_loss = losses.FactoredLossExpr()  # Needed for report
    total_trg = []
    for src, trg in self.next_minibatch():
      if self.dev_zero:
        self.checkpoint_and_save(save_fct)
        self.dev_zero = False
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
        with self.train_loss_tracker.time_tracker:
          event_trigger.set_train(True)
          total_trg.append(trg[0])
          loss_builder = self.training_step(src, trg)
          total_loss.add_factored_loss_expr(loss_builder)
          # num_updates_skipped is incremented in update but
          # we need to call backward before update
          if self.num_updates_skipped == self.update_every - 1:
            self.backward(total_loss.compute(), self.dynet_profiling)
          self.update(self.trainer)
        if self.num_updates_skipped == 0:
          total_loss_val = total_loss.get_factored_loss_val(comb_method=self.loss_comb_method)
          reported_trg = batchers.ListBatch(total_trg)
          self.train_loss_tracker.report(reported_trg, total_loss_val)
          total_loss = losses.FactoredLossExpr()
          total_trg = []
          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      if self.checkpoint_needed():
        # Do a last update before checkpoint
        # Force forward-backward for the last batch even if it's smaller than update_every
        self.num_updates_skipped = self.update_every - 1
        self.backward(total_loss.compute(), self.dynet_profiling)
        self.update(self.trainer)
        total_loss_val = total_loss.get_factored_loss_val(comb_method=self.loss_comb_method)
        reported_trg = batchers.ListBatch(total_trg)
        self.train_loss_tracker.report(reported_trg, total_loss_val)
        total_loss = losses.FactoredLossExpr()
        total_trg = []
        self.checkpoint_and_save(save_fct)
      if self.should_stop_training():
        break
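# The update_every / num_updates_skipped bookkeeping above implements gradient
# accumulation: several minibatch losses are summed into total_loss before a single
# backward pass and optimizer update. A minimal backend-agnostic sketch of the same
# counting scheme (function name hypothetical):
def should_update(step: int, update_every: int) -> bool:
  # True on exactly the steps where the optimizer should run, i.e. once every
  # `update_every` minibatches (steps are 0-indexed).
  return step % update_every == update_every - 1

# Example: with update_every=4, updates fire after minibatches 3, 7, 11, ...
assert [s for s in range(8) if should_update(s, 4)] == [3, 7]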
def _generate_output(self, generator: 'models.GeneratorModel', src_corpus: Sequence[sent.Sentence],
                     trg_file: str, batcher: Optional[batchers.Batcher] = None,
                     max_src_len: Optional[int] = None,
                     forced_ref_corpus: Optional[Sequence[sent.Sentence]] = None,
                     assert_scores: Optional[Sequence[float]] = None) -> None:
  """
  Generate outputs and write them to file.

  Args:
    generator: generator model to use
    src_corpus: src-side inputs to generate outputs for
    trg_file: file to write outputs to
    batcher: necessary with some cases of input pre-processing such as padding or truncation
    max_src_len: if given, skip inputs that are too long
    forced_ref_corpus: if given, perform forced decoding with the given trg-side inputs
    assert_scores: if given, raise exception if the scores for generated outputs don't match the given scores
  """
  with open(trg_file, 'wt', encoding='utf-8') as fp:  # Saving the translated output to a trg file
    if forced_ref_corpus:
      src_batches, ref_batches = batcher.pack(src_corpus, forced_ref_corpus)
    else:
      src_batches = batcher.pack(src_corpus, None)
    cur_sent_i = 0
    ref_batch = None
    for batch_i, src_batch in enumerate(src_batches):
      batch_size = src_batch.batch_size()
      src_len = src_batch.sent_len()
      if max_src_len is not None and src_len > max_src_len:
        output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
        fp.write(f"{output_txt}\n")
      else:
        if forced_ref_corpus:
          ref_batch = ref_batches[batch_i]
        with utils.ReportOnException({"batchno": batch_i, "src": src_batch, "graph": utils.print_cg_conditional}):
          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
          outputs = self.generate_one(generator, src_batch, ref_batch)
          if self.reporter:
            self._create_sent_report()
          for i in range(len(outputs)):
            if assert_scores is not None:
              # If debugging forced decoding, make sure it matches
              assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
              if (abs(outputs[i].score - assert_scores[cur_sent_i + i]) / abs(assert_scores[cur_sent_i + i])) > 1e-5:
                raise ValueError(
                  f'Forced decoding score {outputs[i].score} and loss {assert_scores[cur_sent_i + i]} do not match at '
                  f'sentence {cur_sent_i + i}')
            output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
            fp.write(f"{output_txt}\n")
      cur_sent_i += batch_size
      if self.max_num_sents and cur_sent_i >= self.max_num_sents:
        break
  if self.reporter:
    self._conclude_report()
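# The forced-decoding assertion above accepts a relative deviation of at most 1e-5
# between the decoder's score and the reference loss. The same check in isolation
# (function name hypothetical):
def scores_match(decoded: float, expected: float, rel_tol: float = 1e-5) -> bool:
  # Relative difference |decoded - expected| / |expected|, as in the in-line check above.
  return abs(decoded - expected) / abs(expected) <= rel_tol

# math.isclose(decoded, expected, rel_tol=1e-5, abs_tol=0.0) is nearly equivalent,
# but scales the tolerance by the larger of the two magnitudes rather than by |expected|.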
def eval(self) -> 'metrics.EvalScore':
  """
  Perform evaluation task.

  Returns:
    Evaluated score
  """
  event_trigger.set_train(False)
  if self.src_data is None:
    self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
      input_readers.read_parallel_corpus(src_reader=self.model.src_reader, trg_reader=self.model.trg_reader,
                                         src_file=self.src_file, trg_file=self.ref_file,
                                         batcher=self.batcher, max_num_sents=self.max_num_sents,
                                         max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
  ref_words_cnt = 0
  loss_maps = defaultdict(float)
  loss_wrds = defaultdict(float)
  for src, trg in zip(self.src_batches, self.ref_batches):
    with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      loss_expr = self.loss_calculator.calc_loss(self.model, src, trg)
      loss, loss_value = loss_expr.compute(comb_method=self.loss_comb_method)
      # Count reference words for the num_ref_words field of the returned score
      ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
      for k, (value, unit) in loss_value.items():
        loss_maps[k] += value
        loss_wrds[k] += unit
  loss_stats = {k: loss_maps[k] / loss_wrds[k] for k in loss_maps.keys()}
  return metrics.LossScore(sum(loss_stats.values()),
                           loss_stats=loss_stats,
                           num_ref_words=ref_words_cnt,
                           desc=self.desc)
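# The loop above keeps, per loss factor, a running value sum (loss_maps) and a running
# unit count (loss_wrds), then averages value over units per factor. A minimal sketch of
# that accumulation with hypothetical data (function name and numbers are illustrative):
from collections import defaultdict

def average_factored(batches):
  # Each batch item maps factor name -> (summed value, number of units, e.g. words).
  sums, units = defaultdict(float), defaultdict(float)
  for factored in batches:
    for name, (value, unit) in factored.items():
      sums[name] += value
      units[name] += unit
  return {name: sums[name] / units[name] for name in sums}

# Example: two batches of an 'mle' loss, 300.0 over 100 words and 150.0 over 50 words,
# average to 3.0 per word.
assert average_factored([{'mle': (300.0, 100)}, {'mle': (150.0, 50)}]) == {'mle': 3.0}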
def eval(self) -> 'metrics.EvalScore':
  """
  Perform evaluation task.

  Returns:
    Evaluated score
  """
  event_trigger.set_train(False)
  if self.src_data is None:
    self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
      input_readers.read_parallel_corpus(src_reader=self.model.src_reader, trg_reader=self.model.trg_reader,
                                         src_file=self.src_file, trg_file=self.ref_file,
                                         batcher=self.batcher, max_num_sents=self.max_num_sents,
                                         max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
  loss_val = losses.FactoredLossVal()
  ref_words_cnt = 0
  for src, trg in zip(self.src_batches, self.ref_batches):
    with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
      tt.reset_graph()
      with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
        loss = self.loss_calculator.calc_loss(self.model, src, trg)
        ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
        loss_val += loss.get_factored_loss_val()
    if settings.PRETEND:
      break
  loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}
  # Release the cached corpus so the next call re-reads it
  self.src_data, self.ref_data, self.src_batches, self.ref_batches = None, None, None, None
  return metrics.LossScore(sum(loss_stats.values()),
                           loss_stats=loss_stats,
                           num_ref_words=ref_words_cnt,
                           desc=self.desc)
def _generate_one_batch(self, generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        ref_batch: Optional[batchers.Batch] = None,
                        assert_scores: Optional[List[int]] = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
  """
  Generate outputs for a single batch and write them to the output file.
  """
  batch_size = len(src_batch)
  # Entries of ref_batch are None when no forced decoding is requested, so check the first
  if ref_batch[0] is not None:
    src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
    ref_batch = ref_batches[0]
  else:
    src_batches = batcher.pack(src_batch, None)
    ref_batch = None
  src_batch = src_batches[0]
  src_len = src_batch.sent_len()
  if max_src_len is not None and src_len > max_src_len:
    output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
    fp.write(f"{output_txt}\n")
  else:
    with utils.ReportOnException({"src": src_batch, "graph": utils.print_cg_conditional}):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      outputs = self.generate_one(generator, src_batch, ref_batch)
      if self.reporter:
        self._create_sent_report()
      for i in range(len(outputs)):
        if assert_scores[0] is not None:
          # If debugging forced decoding, make sure it matches
          assert batch_size == len(outputs), "debug forced decoding not supported with nbest inference"
          if (abs(outputs[i].score - assert_scores[i]) / abs(assert_scores[i])) > 1e-5:
            raise ValueError(
              f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
              f'sentence {i}')
        output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
        fp.write(f"{output_txt}\n")
def run_training(self, save_fct: Callable) -> None:
  """
  Main training loop (overwrites TrainingRegimen.run_training())
  """
  if self.run_for_epochs is None or self.run_for_epochs > 0:
    for src, trg in self.next_minibatch():
      if self.dev_zero:
        self.checkpoint_and_save(save_fct)
        self.dev_zero = False
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}), \
           self.skip_out_of_memory:
        tt.reset_graph()
        with self.train_loss_tracker.time_tracker:
          event_trigger.set_train(True)
          loss_builder = self.training_step(src, trg)
          loss = loss_builder.compute(comb_method=self.loss_comb_method)
          self.backward(loss)
          self.update(self.trainer)
        self.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val())
      if self.checkpoint_needed():
        self.checkpoint_and_save(save_fct)
      if self.should_stop_training():
        break