def run_training(self, save_fct: Callable) -> None:
    task_generators = OrderedDict()
    for task in self.tasks:
        task_generators[task] = task.next_minibatch()
    dev_zero = {i: self.dev_zero for i in range(len(self.tasks))}
    if self.tasks[0].run_for_epochs > 0:
        while True:
            tt.reset_graph()
            # Sample the next task in proportion to the configured task weights.
            cur_task_i = np.random.choice(range(len(self.tasks)), p=self.task_weights)
            cur_task = self.tasks[cur_task_i]
            task_gen = task_generators[cur_task]
            if dev_zero[cur_task_i]:
                self.checkpoint_and_save(cur_task, cur_task_i, save_fct, dev_zero)
            with cur_task.train_loss_tracker.time_tracker:
                # Accumulate gradients over several minibatches before a single update.
                for _ in range(self.update_every_within):
                    src, trg = next(task_gen)
                    event_trigger.set_train(True)
                    loss_builder = cur_task.training_step(src, trg)
                    self.backward(loss=loss_builder.compute(comb_method=self.loss_comb_method))
                self.update(trainer=self.trainer)
            cur_task.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val())
            self.checkpoint_and_save(cur_task, cur_task_i, save_fct, dev_zero)
            if self.tasks[0].should_stop_training():
                break
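# Illustrative sketch, not part of xnmt: how per-step task sampling with
# np.random.choice(p=task_weights) behaves. Assuming the weights sum to 1,
# each task is picked in proportion to its weight over many training steps.
import numpy as np
from collections import Counter

task_weights = [0.7, 0.2, 0.1]  # hypothetical weights for three tasks
counts = Counter(int(np.random.choice(range(3), p=task_weights)) for _ in range(10000))
print({task: counts[task] / 10000 for task in range(3)})  # roughly {0: 0.70, 1: 0.20, 2: 0.10}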
def run_training(self, save_fct: Callable) -> None:
    dev_zero = {i: self.dev_zero for i in range(len(self.tasks))}
    # Train the tasks serially: each task runs to completion before the next starts.
    for cur_task_id in range(len(self.tasks)):
        self.train = None
        cur_task = self.tasks[cur_task_id]
        task_gen = cur_task.next_minibatch()
        if cur_task.run_for_epochs > 0:
            while True:
                tt.reset_graph()
                src, trg = next(task_gen)
                if dev_zero[cur_task_id]:
                    self.checkpoint_and_save(cur_task, cur_task_id, save_fct, dev_zero)
                with cur_task.train_loss_tracker.time_tracker:
                    event_trigger.set_train(True)
                    loss_builder = cur_task.training_step(src, trg)
                    task_loss = loss_builder.compute(comb_method=self.loss_comb_method)
                    self.backward(task_loss)
                    self.update(self.trainer)
                cur_task.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val())
                self.checkpoint_and_save(cur_task, cur_task_id, save_fct, dev_zero)
                if cur_task.should_stop_training():
                    break
def test_py_lstm_encoder_len(self):
    layer_dim = 512
    model = DefaultTranslator(
        src_reader=self.src_reader,
        trg_reader=self.trg_reader,
        src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
        encoder=PyramidalLSTMSeqTransducer(input_dim=layer_dim,
                                           hidden_dim=layer_dim,
                                           layers=3),
        attender=MlpAttender(input_dim=layer_dim,
                             state_dim=layer_dim,
                             hidden_dim=layer_dim),
        decoder=AutoRegressiveDecoder(
            input_dim=layer_dim,
            embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
            rnn=UniLSTMSeqTransducer(input_dim=layer_dim,
                                     hidden_dim=layer_dim,
                                     decoder_input_dim=layer_dim,
                                     yaml_path="model.decoder.rnn"),
            transform=NonLinear(input_dim=layer_dim * 2, output_dim=layer_dim),
            scorer=Softmax(input_dim=layer_dim, vocab_size=100),
            bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
    )
    event_trigger.set_train(True)
    for sent_i in range(10):
        tt.reset_graph()
        # Pad the source length to a multiple of 4 so the 3-layer pyramidal
        # encoder can downsample it cleanly.
        src = self.src_data[sent_i].create_padded_sent(
            4 - (self.src_data[sent_i].sent_len() % 4))
        event_trigger.start_sent(src)
        embeddings = model.src_embedder.embed_sent(src)
        encodings = model.encoder.transduce(embeddings)
        self.assertEqual(int(math.ceil(len(embeddings) / float(4))), len(encodings))
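# Illustrative sketch, not part of xnmt: the length arithmetic behind the test
# above, assuming each pyramidal layer after the first halves the sequence
# length, so a 3-layer encoder downsamples by a factor of 4 (hence the padding
# to a multiple of 4 and the ceil(len / 4) assertion).
import math

def expected_pyramidal_len(src_len: int, layers: int = 3) -> int:
    downsample = 2 ** (layers - 1)  # two halving layers for a 3-layer encoder
    return int(math.ceil(src_len / float(downsample)))

assert expected_pyramidal_len(16, layers=3) == 4
assert expected_pyramidal_len(20, layers=3) == 5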
def _forced_decode_one_batch(self,
                             generator: 'models.GeneratorModel',
                             batcher: Optional[batchers.Batcher] = None,
                             src_batch: batchers.Batch = None,
                             ref_batch: batchers.Batch = None,
                             assert_scores: batchers.Batch = None,
                             max_src_len: Optional[int] = None):
    """
    Performs forced decoding for a single batch.
    """
    batch_size = len(src_batch)
    src_batches, ref_batches = batcher.pack(src_batch, ref_batch)
    src_batch = src_batches[0]
    src_len = src_batch.sent_len()
    # Only decode if the source does not exceed the length limit.
    if max_src_len is None or src_len <= max_src_len:
        with utils.ReportOnException({"src": src_batch,
                                      "graph": utils.print_cg_conditional}):
            tt.reset_graph()
            outputs = self.generate_one(generator, src_batch)
            if self.reporter:
                self._create_sent_report()
            for i in range(len(outputs)):
                if assert_scores is not None:
                    # If debugging forced decoding, make sure it matches
                    assert batch_size == len(outputs), \
                        "debug forced decoding not supported with nbest inference"
                    if (abs(outputs[i].score - assert_scores[i]) /
                            abs(assert_scores[i])) > 1e-5:
                        raise ValueError(
                            f'Forced decoding score {outputs[i].score} and loss {assert_scores[i]} do not match at '
                            f'sentence {i}')
def _generate_one_batch(self,
                        generator: 'models.GeneratorModel',
                        batcher: Optional[batchers.Batcher] = None,
                        src_batch: batchers.Batch = None,
                        max_src_len: Optional[int] = None,
                        fp: TextIO = None):
    """
    Generate outputs for a single batch and write them to the output file.
    """
    batch_size = len(src_batch)
    src_batches = batcher.pack(src_batch, None)
    src_batch = src_batches[0]
    src_len = src_batch.sent_len()
    if max_src_len is not None and src_len > max_src_len:
        # Source is too long to decode: emit placeholder lines instead.
        output_txt = "\n".join([NO_DECODING_ATTEMPTED] * batch_size)
        fp.write(f"{output_txt}\n")
    else:
        with utils.ReportOnException({"src": src_batch,
                                      "graph": utils.print_cg_conditional}):
            tt.reset_graph()
            # Disable gradient tracking during inference on the torch backend.
            with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
                outputs = self.generate_one(generator, src_batch)
                if self.reporter:
                    self._create_sent_report()
                for i in range(len(outputs)):
                    output_txt = outputs[i].sent_str(custom_output_procs=self.post_processor)
                    fp.write(f"{output_txt}\n")
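# Illustrative sketch, not part of xnmt: the conditional no-grad pattern used
# above. A do-nothing context manager (contextlib.nullcontext here, assumed to
# play the role of utils.dummy_context_mgr) stands in for torch.no_grad() when
# the torch backend is not active, so the `with` statement works either way.
import contextlib

backend_torch = False  # hypothetical backend flag

def maybe_no_grad():
    if backend_torch:
        import torch
        return torch.no_grad()
    return contextlib.nullcontext()

with maybe_no_grad():
    pass  # inference would run here without building a gradient graph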
def assert_in_out_len_equal(self, model):
    tt.reset_graph()
    event_trigger.set_train(True)
    src = self.src_data[0]
    event_trigger.start_sent(src)
    embeddings = model.src_embedder.embed_sent(src)
    encodings = model.encoder.transduce(embeddings)
    self.assertEqual(len(embeddings), len(encodings))
def run_training(self, save_fct: Callable) -> None:
    """
    Main training loop (overwrites TrainingRegimen.run_training())
    """
    tt.reset_graph()
    if self.run_for_epochs is None or self.run_for_epochs > 0:
        total_loss = losses.FactoredLossExpr()
        total_trg = []  # Needed for report
        for src, trg in self.next_minibatch():
            if self.dev_zero:
                self.checkpoint_and_save(save_fct)
                self.dev_zero = False
            with utils.ReportOnException({"src": src,
                                          "trg": trg,
                                          "graph": utils.print_cg_conditional}):
                with self.train_loss_tracker.time_tracker:
                    event_trigger.set_train(True)
                    total_trg.append(trg[0])
                    loss_builder = self.training_step(src, trg)
                    total_loss.add_factored_loss_expr(loss_builder)
                    # num_updates_skipped is incremented in update but
                    # we need to call backward before update
                    if self.num_updates_skipped == self.update_every - 1:
                        self.backward(total_loss.compute(comb_method=self.loss_comb_method))
                    self.update(self.trainer)
                if self.num_updates_skipped == 0:
                    total_loss_val = total_loss.get_factored_loss_val()
                    reported_trg = batchers.ListBatch(total_trg)
                    self.train_loss_tracker.report(reported_trg, total_loss_val)
                    total_loss = losses.FactoredLossExpr()
                    total_trg = []
                    tt.reset_graph()
            if self.checkpoint_needed():
                # Do a last update before checkpoint
                # Force forward-backward for the last batch even if it's smaller than update_every
                self.num_updates_skipped = self.update_every - 1
                self.backward(total_loss.compute(comb_method=self.loss_comb_method))
                self.update(self.trainer)
                total_loss_val = total_loss.get_factored_loss_val()
                reported_trg = batchers.ListBatch(total_trg)
                self.train_loss_tracker.report(reported_trg, total_loss_val)
                total_loss = losses.FactoredLossExpr()
                total_trg = []
                self.checkpoint_and_save(save_fct)
            if self.should_stop_training():
                break
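# Illustrative sketch, not part of xnmt: the update_every accumulation logic in
# isolation. Losses from update_every minibatches are summed before a single
# backward/update, so the effective batch size grows without extra memory.
update_every = 4
num_updates_skipped = 0
accumulated_loss = 0.0  # stands in for the FactoredLossExpr accumulator

for step in range(8):
    accumulated_loss += 0.5  # hypothetical per-minibatch loss
    if num_updates_skipped == update_every - 1:
        print(f"step {step}: backward + update on accumulated loss {accumulated_loss}")
        accumulated_loss = 0.0
        num_updates_skipped = 0
    else:
        num_updates_skipped += 1  # keep accumulating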
def test_single(self):
    tt.reset_graph()
    outputs = self.model.generate(
        batchers.mark_as_batch([self.src_data[0]]), GreedySearch())
    output_score = outputs[0].score
    tt.reset_graph()
    train_loss = tt.npvalue(
        self.model.calc_nll(src=self.src_data[0], trg=outputs[0]))
    # The negative log-likelihood of the greedy output should equal its search score.
    self.assertAlmostEqual(-output_score, train_loss[0], places=3)
def test_greedy_vs_beam(self):
    tt.reset_graph()
    outputs = self.model.generate(
        batchers.mark_as_batch([self.src_data[0]]), BeamSearch(beam_size=1))
    output_score1 = outputs[0].score
    tt.reset_graph()
    outputs = self.model.generate(
        batchers.mark_as_batch([self.src_data[0]]), GreedySearch())
    output_score2 = outputs[0].score
    # Beam search with beam size 1 must reduce to greedy search.
    self.assertAlmostEqual(output_score1, output_score2)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1): """ Tests whether single loss equals batch loss. Here we don't truncate the target side and use masking. """ batch_size = 5 src_sents = self.src_data[:batch_size] src_min = min([x.sent_len() for x in src_sents]) src_sents_trunc = [s.words[:src_min] for s in src_sents] for single_sent in src_sents_trunc: single_sent[src_min - 1] = Vocab.ES while len(single_sent) % pad_src_to_multiple != 0: single_sent.append(Vocab.ES) trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: x.sent_len(), reverse=True) trg_max = max([x.sent_len() for x in trg_sents]) np_arr = np.zeros([batch_size, trg_max]) for i in range(batch_size): for j in range(trg_sents[i].sent_len(), trg_max): np_arr[i, j] = 1.0 trg_masks = Mask(np_arr) trg_sents_padded = [[w for w in s] + [Vocab.ES] * (trg_max - s.sent_len()) for s in trg_sents] src_sents_trunc = [ sent.SimpleSentence(words=s) for s in src_sents_trunc ] trg_sents_padded = [ sent.SimpleSentence(words=s) for s in trg_sents_padded ] single_loss = 0.0 for sent_id in range(batch_size): tt.reset_graph() train_loss = MLELoss().calc_loss(model=model, src=src_sents_trunc[sent_id], trg=trg_sents[sent_id]).value() single_loss += train_loss[0] tt.reset_graph() batched_loss = MLELoss().calc_loss(model=model, src=mark_as_batch(src_sents_trunc), trg=mark_as_batch( trg_sents_padded, trg_masks)).value() self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)
def _compute_losses(self, generator, ref_corpus, src_corpus,
                    max_num_sents) -> List[numbers.Real]:
    batched_src, batched_ref = self.batcher.pack(src_corpus, ref_corpus)
    ref_scores = []
    for sent_count, (src, ref) in enumerate(zip(batched_src, batched_ref)):
        if max_num_sents and sent_count >= max_num_sents:
            break
        tt.reset_graph()
        loss = self.compute_losses_one(generator, src, ref)
        if isinstance(loss.value(), collections.abc.Iterable):
            ref_scores.extend(loss.value())
        else:
            ref_scores.append(loss.value())
    # Scores are log-likelihoods, i.e. negated losses.
    ref_scores = [-x for x in ref_scores]
    return ref_scores
def run_training(self, save_fct: Callable) -> None:
    task_generators = OrderedDict()
    for task in self.tasks:
        task_generators[task] = task.next_minibatch()
    if self.tasks[0].run_for_epochs > 0:
        while True:
            task_losses = []
            task_src_trg = []
            for (task, task_gen), task_n in zip(task_generators.items(), self.n_task_steps):
                for _ in range(task_n):
                    src, trg = next(task_gen)
                    task_src_trg.append((task, src, trg))
            if self.dev_zero:  # True only in first iteration
                self.checkpoint_and_save(save_fct)
            tt.reset_graph()
            task_trg_loss_stats = {}
            # Use an exit stack to control whether to use global or per-task time tracking.
            with contextlib.ExitStack() as stack:
                if not self.per_task_backward:
                    stack.enter_context(self.tasks[0].train_loss_tracker.time_tracker)
                event_trigger.set_train(True)
                for task, src, trg in task_src_trg:
                    with contextlib.ExitStack() as stack2:
                        if self.per_task_backward:
                            stack2.enter_context(task.train_loss_tracker.time_tracker)
                        loss_builder = task.training_step(src, trg)
                        task_trg_loss_stats[task] = (trg, loss_builder.get_factored_loss_val())
                        if self.per_task_backward:
                            self.backward(loss_builder.compute(comb_method=self.loss_comb_method))
                            tt.reset_graph(zero_grad=False)
                        else:
                            task_losses.append(loss_builder.compute(comb_method=self.loss_comb_method))
                if not self.per_task_backward:
                    self.backward(sum(task_losses))
                self.update(self.trainer)
            for task, (trg, stats) in task_trg_loss_stats.items():
                task.train_loss_tracker.report(trg, stats)
            self.checkpoint_and_save(save_fct)
            if self.tasks[0].should_stop_training():
                break
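# Illustrative sketch, not part of xnmt: entering a context manager only when a
# condition holds via contextlib.ExitStack, the same pattern the regimen above
# uses to switch between global and per-task time tracking.
import contextlib

class Timer:
    def __enter__(self):
        print("timer started")
        return self
    def __exit__(self, *exc):
        print("timer stopped")
        return False

per_task_backward = False  # hypothetical flag
with contextlib.ExitStack() as stack:
    if not per_task_backward:
        stack.enter_context(Timer())  # timed only in the global-tracking case
    print("training steps would run here")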
def test_single(self):
    tt.reset_graph()
    outputs = self.model.generate(
        batchers.mark_as_batch([self.src_data[0]]), BeamSearch())
    # Make sure the output of beam search is the same as the target sentence
    # (this is a very overfit model on exactly this data)
    self.assertEqual(outputs[0].sent_len(), self.trg_data[0].sent_len())
    for i in range(outputs[0].sent_len()):
        self.assertEqual(outputs[0][i], self.trg_data[0][i])
    # Verify that the loss we get from beam search is the same as the loss
    # we get if we call model.calc_nll
    tt.reset_graph()
    train_loss = self.model.calc_nll(src=self.src_data[0],
                                     trg=outputs[0]).value()
    self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1): """ Tests whether single loss equals batch loss. Truncating src / trg sents to same length so no masking is necessary """ batch_size = 5 src_sents = self.src_data[:batch_size] src_min = min([x.sent_len() for x in src_sents]) src_sents_trunc = [s.words[:src_min] for s in src_sents] for single_sent in src_sents_trunc: single_sent[src_min - 1] = Vocab.ES while len(single_sent) % pad_src_to_multiple != 0: single_sent.append(Vocab.ES) trg_sents = self.trg_data[:batch_size] trg_min = min([x.sent_len() for x in trg_sents]) trg_sents_trunc = [s.words[:trg_min] for s in trg_sents] for single_sent in trg_sents_trunc: single_sent[trg_min - 1] = Vocab.ES src_sents_trunc = [ sent.SimpleSentence(words=s) for s in src_sents_trunc ] trg_sents_trunc = [ sent.SimpleSentence(words=s) for s in trg_sents_trunc ] single_loss = 0.0 for sent_id in range(batch_size): tt.reset_graph() train_loss = MLELoss().calc_loss( model=model, src=src_sents_trunc[sent_id], trg=trg_sents_trunc[sent_id]).value() single_loss += train_loss[0] tt.reset_graph() batched_loss = MLELoss().calc_loss( model=model, src=mark_as_batch(src_sents_trunc), trg=mark_as_batch(trg_sents_trunc)).value() self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)
def test_py_lstm_mask(self):
    layer_dim = 512
    model = DefaultTranslator(
        src_reader=self.src_reader,
        trg_reader=self.trg_reader,
        src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
        encoder=PyramidalLSTMSeqTransducer(input_dim=layer_dim,
                                           hidden_dim=layer_dim,
                                           layers=1),
        attender=MlpAttender(input_dim=layer_dim,
                             state_dim=layer_dim,
                             hidden_dim=layer_dim),
        decoder=AutoRegressiveDecoder(
            input_dim=layer_dim,
            embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
            rnn=UniLSTMSeqTransducer(input_dim=layer_dim,
                                     hidden_dim=layer_dim,
                                     decoder_input_dim=layer_dim,
                                     yaml_path="model.decoder.rnn"),
            transform=NonLinear(input_dim=layer_dim * 2, output_dim=layer_dim),
            scorer=Softmax(input_dim=layer_dim, vocab_size=100),
            bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
    )
    batcher = batchers.TrgBatcher(batch_size=3)
    train_src, _ = batcher.pack(self.src_data, self.trg_data)
    event_trigger.set_train(True)
    for sent_i in range(3):
        tt.reset_graph()
        src = train_src[sent_i]
        event_trigger.start_sent(src)
        embeddings = model.src_embedder.embed_sent(src)
        encodings = model.encoder.transduce(embeddings)
        if train_src[sent_i].mask is None:
            assert encodings.mask is None
        else:
            np.testing.assert_array_almost_equal(train_src[sent_i].mask.np_arr,
                                                 encodings.mask.np_arr)
def eval(self) -> 'metrics.EvalScore':
    """
    Perform evaluation task.

    Returns:
      Evaluated score
    """
    event_trigger.set_train(False)
    if self.src_data is None:
        self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
            input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
                                               trg_reader=self.model.trg_reader,
                                               src_file=self.src_file,
                                               trg_file=self.ref_file,
                                               batcher=self.batcher,
                                               max_num_sents=self.max_num_sents,
                                               max_src_len=self.max_src_len,
                                               max_trg_len=self.max_trg_len)
    loss_val = losses.FactoredLossVal()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
        with utils.ReportOnException({"src": src,
                                      "trg": trg,
                                      "graph": utils.print_cg_conditional}):
            tt.reset_graph()
            with torch.no_grad() if xnmt.backend_torch else utils.dummy_context_mgr():
                loss = self.loss_calculator.calc_loss(self.model, src, trg)
                ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
                loss_val += loss.get_factored_loss_val()
        if settings.PRETEND:
            break
    # Normalize each loss factor by the number of reference words.
    loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}
    # Release the cached corpus so the next eval re-reads the data.
    self.src_data, self.ref_data, self.src_batches, self.ref_batches = None, None, None, None
    return metrics.LossScore(sum(loss_stats.values()),
                             loss_stats=loss_stats,
                             num_ref_words=ref_words_cnt,
                             desc=self.desc)
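# Illustrative sketch, not part of xnmt: the per-word loss normalization done
# above. Summed factored losses are divided by the reference word count so the
# resulting LossScore is comparable across evaluation sets of different sizes.
loss_val = {"mle": 240.0}  # hypothetical summed loss, keyed by loss factor
ref_words_cnt = 120
loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}
print(loss_stats)                # {'mle': 2.0}
print(sum(loss_stats.values()))  # 2.0, the aggregate score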
def run_training(self, save_fct: Callable) -> None: """ Main training loop (overwrites TrainingRegimen.run_training()) """ if self.run_for_epochs is None or self.run_for_epochs > 0: for src, trg in self.next_minibatch(): if self.dev_zero: self.checkpoint_and_save(save_fct) self.dev_zero = False with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}), \ self.skip_out_of_memory: tt.reset_graph() with self.train_loss_tracker.time_tracker: event_trigger.set_train(True) loss_builder = self.training_step(src, trg) loss = loss_builder.compute( comb_method=self.loss_comb_method) self.backward(loss) self.update(self.trainer) self.train_loss_tracker.report( trg, loss_builder.get_factored_loss_val()) if self.checkpoint_needed(): self.checkpoint_and_save(save_fct) if self.should_stop_training(): break