def calc_nll(self, src, trg):
  """Return the summed NLL of predicting each next token of ``src`` under this LM.

  Inputs are ``src[:-1]`` and targets ``src[1:]``; padded positions are masked
  out before the per-step losses are summed. ``trg`` is unused here.
  """
  if not batchers.is_batched(src):
    src = batchers.ListBatch([src])
  # Shift by one position: inputs drop the last token, targets drop the first.
  in_mask = batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None
  tgt_mask = batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None
  src_inputs = batchers.ListBatch([s[:-1] for s in src], mask=in_mask)
  src_targets = batchers.ListBatch([s[1:] for s in src], mask=tgt_mask)
  event_trigger.start_sent(src)
  encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
  enc_tensor = encodings.as_tensor()
  ((hidden_dim, seq_len), batch_size) = encodings.dim()
  # Fold the time dimension into the batch dimension so the scorer can
  # treat every timestep as an independent batch element.
  flat_enc = dy.reshape(enc_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
  outputs = self.transform.transform(flat_enc)
  refs = np.asarray([s.words for s in src_targets]).reshape((seq_len * batch_size,))
  step_losses = self.scorer.calc_loss(outputs, batchers.mark_as_batch(refs))
  # Unfold back to one loss per timestep per batch element.
  step_losses = dy.reshape(step_losses, (seq_len,), batch_size=batch_size)
  if src_targets.mask:
    # Mask entries are 1 for padding, so keep = 1 - mask.
    keep = dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True)
    step_losses = dy.cmult(step_losses, keep)
  return dy.sum_elems(step_losses)
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> tt.Tensor:
  """Return the masked, aggregated NLL of next-token prediction over ``src``.

  Backend-agnostic variant: uses the ``tt`` tensor-tools helpers to merge and
  unmerge the time/batch dimensions and to apply the target mask. ``trg`` is
  unused here.
  """
  if not batchers.is_batched(src):
    src = batchers.ListBatch([src])
  # Teacher forcing: feed src[:-1], predict src[1:], with masks sliced to match.
  in_mask = batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None
  tgt_mask = batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None
  src_inputs = batchers.ListBatch([s[:-1] for s in src], mask=in_mask)
  src_targets = batchers.ListBatch([s[1:] for s in src], mask=tgt_mask)
  event_trigger.start_sent(src)
  encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
  enc_tensor = encodings.as_tensor()
  # Collapse time into the batch dimension so every step is scored at once.
  flat_enc = tt.merge_time_batch_dims(enc_tensor)
  seq_len = tt.sent_len(enc_tensor)
  batch_size = tt.batch_size(enc_tensor)
  outputs = self.transform.transform(flat_enc)
  refs = np.asarray([s.words for s in src_targets]).reshape((seq_len * batch_size,))
  step_losses = self.scorer.calc_loss(outputs, batchers.mark_as_batch(refs))
  step_losses = tt.unmerge_time_batch_dims(step_losses, batch_size)
  return tt.aggregate_masked_loss(step_losses, src_targets.mask)
def run_training(self, save_fct: Callable) -> None:
  """
  Main training loop (overwrites TrainingRegimen.run_training())

  Iterates over minibatches, accumulating losses over ``update_every``
  batches before performing a backward pass and parameter update, and
  checkpoints (and possibly stops) when ``checkpoint_needed()`` fires.

  Args:
    save_fct: callback used to persist the model at each checkpoint.
  """
  dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
  if self.run_for_epochs is None or self.run_for_epochs > 0:
    total_loss = losses.FactoredLossExpr()
    # Accumulated targets are needed for loss reporting.
    total_trg = []
    for src, trg in self.next_minibatch():
      if self.dev_zero:
        # Optionally evaluate on dev before any training step.
        self.checkpoint_and_save(save_fct)
        self.dev_zero = False
      with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
        with self.train_loss_tracker.time_tracker:
          event_trigger.set_train(True)
          total_trg.append(trg[0])
          loss_builder = self.training_step(src, trg)
          total_loss.add_factored_loss_expr(loss_builder)
          # num_updates_skipped is incremented inside update(), but backward
          # must run before update(), so test for update_every - 1 here.
          if self.num_updates_skipped == self.update_every - 1:
            self.backward(total_loss.compute(), self.dynet_profiling)
            self.update(self.trainer)
            # update() resets the counter to 0 once an update is actually
            # performed; only then do we report and reset the accumulators.
            if self.num_updates_skipped == 0:
              total_loss_val = total_loss.get_factored_loss_val(comb_method=self.loss_comb_method)
              reported_trg = batchers.ListBatch(total_trg)
              self.train_loss_tracker.report(reported_trg, total_loss_val)
              total_loss = losses.FactoredLossExpr()
              total_trg = []
              # Start a fresh computation graph for the next accumulation window.
              dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
      if self.checkpoint_needed():
        # Do a last update before the checkpoint: force forward-backward for
        # the pending batch even if fewer than update_every were accumulated.
        self.num_updates_skipped = self.update_every - 1
        self.backward(total_loss.compute(), self.dynet_profiling)
        self.update(self.trainer)
        total_loss_val = total_loss.get_factored_loss_val(comb_method=self.loss_comb_method)
        reported_trg = batchers.ListBatch(total_trg)
        self.train_loss_tracker.report(reported_trg, total_loss_val)
        total_loss = losses.FactoredLossExpr()
        total_trg = []
        self.checkpoint_and_save(save_fct)
        if self.should_stop_training():
          break
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> dy.Expression:
  """Return the classification loss of scoring the encoded ``src`` against ``trg`` labels."""
  encoded = self._encode_src(src)
  if batchers.is_batched(trg):
    labels = batchers.ListBatch([trg_i.value for trg_i in trg])
  else:
    labels = trg.value
  return self.scorer.calc_loss(encoded, labels)
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> dy.Expression:
  """Return the per-sentence LM loss over ``src`` as a ``LossExpr``.

  Scores ``src[1:]`` given ``src[:-1]``, masks padded steps, sums per-step
  losses, and pairs the result with the unpadded token counts. ``trg`` is
  unused here.
  """
  # One-step shift for teacher forcing, with masks sliced accordingly.
  in_mask = batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None
  tgt_mask = batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None
  src_inputs = batchers.ListBatch([s[:-1] for s in src], mask=in_mask)
  src_targets = batchers.ListBatch([s[1:] for s in src], mask=tgt_mask)
  encodings = self.rnn.transduce(self.src_embedder.embed_sent(src_inputs))
  enc_tensor = encodings.as_tensor()
  ((hidden_dim, seq_len), batch_size) = encodings.dim()
  # Fold time into the batch dimension so all steps are scored in one call.
  flat_enc = dy.reshape(enc_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
  outputs = self.transform.transform(flat_enc)
  refs = np.asarray([s.words for s in src_targets]).reshape((seq_len * batch_size,))
  step_losses = self.scorer.calc_loss(outputs, batchers.mark_as_batch(refs))
  step_losses = dy.reshape(step_losses, (seq_len,), batch_size=batch_size)
  if src_targets.mask:
    # Mask is 1 at padding, so multiply by its complement to zero those steps.
    keep = dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True)
    step_losses = dy.cmult(step_losses, keep)
  loss = dy.sum_elems(step_losses)
  units = [s.len_unpadded() for s in src]
  return LossExpr(loss, units)
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> LossExpr:
  """Return the classification loss plus unit counts as a ``LossExpr``."""
  if batchers.is_batched(trg):
    ids = batchers.ListBatch([t.value for t in trg])
    units = [t.len_unpadded() for t in trg]
  else:
    ids = trg.value
    units = trg.len_unpadded()
  encoded = self._encode_src(src)
  return LossExpr(self.scorer.calc_loss(encoded, ids), units)
def generate(self, src, forced_trg_ids=None, **kwargs):
  """Greedily decode (or force-decode) output sentences for a batch of inputs.

  Args:
    src: batched source input; if a CompoundBatch, only the first sub-batch
      is used.
    forced_trg_ids: if given, decode exactly these target ids (teacher
      forcing) and accumulate their log-scores instead of taking argmax.

  Returns:
    A list of ``sent.SimpleSentence`` outputs, one per batch element.
  """
  event_trigger.start_sent(src)
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]
  outputs = []
  batch_size = src.batch_size()
  # Per-batch-element accumulators for word scores and generated word ids.
  score = batchers.ListBatch([[] for _ in range(batch_size)])
  words = batchers.ListBatch([[] for _ in range(batch_size)])
  done = [False] * batch_size
  initial_state = self._encode_src(src)
  current_state = initial_state
  attentions = []
  for pos in range(self.max_dec_len):
    # Determine the previous reference action fed to the decoder
    # (None at step 0 and in "context" mode).
    prev_ref_action = None
    if pos > 0 and self.mode_translate != "context":
      if forced_trg_ids is not None:
        # Teacher forcing: previous token comes from the forced sequence.
        prev_ref_action = batchers.mark_as_batch([forced_trg_ids[batch_i][pos - 1]
                                                  for batch_i in range(batch_size)])
      elif batch_size > 1:
        prev_ref_action = batchers.mark_as_batch(np.argmax(current_state.out_prob.npvalue(), axis=0))
      else:
        prev_ref_action = batchers.mark_as_batch([np.argmax(current_state.out_prob.npvalue(), axis=0)])
    logsoft = self.generate_one_step(dec_state=current_state,
                                     batch_size=batch_size,
                                     mode=self.mode_translate,
                                     cur_step=pos,
                                     prev_ref_action=prev_ref_action)
    # NOTE(review): current_state is never reassigned here — presumably
    # generate_one_step advances the decoder state in place; confirm.
    attentions.append(self.attender.get_last_attention().npvalue())
    logsoft = logsoft.npvalue()
    logsoft = logsoft.reshape(logsoft.shape[0], batch_size)
    if forced_trg_ids is None:
      # Greedy choice: highest-scoring word per batch element.
      batched_word_id = np.argmax(logsoft, axis=0)
      batched_word_id = batched_word_id.reshape((batched_word_id.size,))
    else:
      batched_word_id = [forced_trg_batch_elem[pos] for forced_trg_batch_elem in forced_trg_ids]
    for batch_i in range(batch_size):
      if done[batch_i]:
        # Finished elements keep emitting end-of-sentence with zero score.
        batch_word = vocabs.Vocab.ES
        batch_score = 0.0
      else:
        batch_word = batched_word_id[batch_i]
        batch_score = logsoft[batch_word, batch_i]
        if self._terminate_rnn(batch_i=batch_i, pos=pos, batched_word_id=batched_word_id):
          done[batch_i] = True
      score[batch_i].append(batch_score)
      words[batch_i].append(batch_word)
    if all(done):
      break
  # Assemble one SimpleSentence per batch element; total score is the sum
  # of per-step word scores.
  for batch_i in range(batch_size):
    batch_elem_score = sum(score[batch_i])
    outputs.append(sent.SimpleSentence(words=words[batch_i],
                                       idx=src[batch_i].idx,
                                       vocab=getattr(self.trg_reader, "vocab", None),
                                       score=batch_elem_score,
                                       output_procs=self.trg_reader.output_procs))
    if self.compute_report:
      # Slice this element's attention out of each step's (batched) matrix.
      if batch_size > 1:
        cur_attentions = [x[:, :, batch_i] for x in attentions]
      else:
        cur_attentions = attentions
      attention = np.concatenate(cur_attentions, axis=1)
      self.report_sent_info({"attentions": attention, "src": src[batch_i], "output": outputs[-1]})
  return outputs
def calc_nll(self, src, trg):
  """Return the classification loss of the encoded ``src`` scored against ``trg`` labels."""
  encoded = self._encode_src(src)
  if batchers.is_batched(trg):
    label_ids = batchers.ListBatch([trg_i.value for trg_i in trg])
  else:
    label_ids = trg.value
  return self.scorer.calc_loss(encoded, label_ids)