def calc_nll(self, src, trg):
  if not batchers.is_batched(src):
    src = batchers.ListBatch([src])

  src_inputs = batchers.ListBatch([s[:-1] for s in src],
                                  mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
  src_targets = batchers.ListBatch([s[1:] for s in src],
                                   mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

  event_trigger.start_sent(src)
  embeddings = self.src_embedder.embed_sent(src_inputs)
  encodings = self.rnn.transduce(embeddings)
  encodings_tensor = encodings.as_tensor()

  # fold the time dimension into the batch dimension so that transform and
  # scorer can be applied to all positions at once
  ((hidden_dim, seq_len), batch_size) = encodings.dim()
  encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
  outputs = self.transform.transform(encoding_reshaped)

  ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
  loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
  loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
  if src_targets.mask:
    loss_expr_perstep = dy.cmult(loss_expr_perstep,
                                 dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True))
  loss = dy.sum_elems(loss_expr_perstep)

  return loss
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> tt.Tensor:
  if not batchers.is_batched(src):
    src = batchers.ListBatch([src])

  src_inputs = batchers.ListBatch([s[:-1] for s in src],
                                  mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
  src_targets = batchers.ListBatch([s[1:] for s in src],
                                   mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

  event_trigger.start_sent(src)
  embeddings = self.src_embedder.embed_sent(src_inputs)
  encodings = self.rnn.transduce(embeddings)
  encodings_tensor = encodings.as_tensor()

  # backend-agnostic equivalent of the explicit reshapes in the DyNet version
  encoding_reshaped = tt.merge_time_batch_dims(encodings_tensor)
  seq_len = tt.sent_len(encodings_tensor)
  batch_size = tt.batch_size(encodings_tensor)
  outputs = self.transform.transform(encoding_reshaped)

  ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
  loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
  loss_expr_perstep = tt.unmerge_time_batch_dims(loss_expr_perstep, batch_size)
  loss = tt.aggregate_masked_loss(loss_expr_perstep, src_targets.mask)

  return loss
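For orientation, the tensor-tools version above replaces the backend-specific reshapes of the first listing with generic helpers. A minimal sketch of how those helpers could look for the DyNet backend follows; the bodies are an assumption inferred from the first listing, not the actual xnmt tensor-tools implementation.

# Sketch only: assumed implementations of the tt helpers used above,
# inferred from the explicit dy.reshape calls in the first listing.
import dynet as dy

def merge_time_batch_dims(t):
  # ((hidden_dim, seq_len), batch_size) -> ((hidden_dim,), seq_len * batch_size)
  ((hidden_dim, seq_len), batch_size) = t.dim()
  return dy.reshape(t, (hidden_dim,), batch_size=batch_size * seq_len)

def unmerge_time_batch_dims(t, batch_size):
  # ((1,), seq_len * batch_size) -> ((seq_len,), batch_size)
  seq_len = t.dim()[1] // batch_size
  return dy.reshape(t, (seq_len,), batch_size=batch_size)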
def generate_search_output(self,
                           src: batchers.Batch,
                           search_strategy: search_strategies.SearchStrategy,
                           forced_trg_ids: batchers.Batch = None) -> List[search_strategies.SearchOutput]:
  """
  Takes in a batch of source sentences and outputs a list of search outputs.

  Args:
    src: The source sentences
    search_strategy: The strategy with which to perform the search
    forced_trg_ids: The target IDs to generate if performing forced decoding

  Returns:
    A list of search outputs including scores, etc.
  """
  if src.batch_size() != 1:
    raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                              "Specify inference batcher with batch size 1.")
  event_trigger.start_sent(src)
  all_src = src
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]

  # Generating outputs
  cur_forced_trg = None
  src_sent = src[0]
  sent_mask = None
  if src.mask:
    sent_mask = batchers.Mask(np_arr=src.mask.np_arr[0:1])
  sent_batch = batchers.mark_as_batch([src_sent], mask=sent_mask)

  # Encode the sentence
  initial_state = self._encode_src(all_src)

  if forced_trg_ids is not None:
    cur_forced_trg = forced_trg_ids[0]

  search_outputs = search_strategy.generate_output(self,
                                                   initial_state,
                                                   src_length=[src_sent.sent_len()],
                                                   forced_trg_ids=cur_forced_trg)
  return search_outputs
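A hypothetical usage sketch, showing the singleton-batch contract enforced by the batch-size check above; `translator` and `src_sent` are placeholders, and `GreedySearch` stands in for any `SearchStrategy`.

# Hypothetical usage; assumes a trained `translator` and a prepared `src_sent`.
src_batch = batchers.mark_as_batch([src_sent])  # batch size must be 1
outputs = translator.generate_search_output(src_batch, search_strategies.GreedySearch())
best = outputs[0]  # a SearchOutput carrying word IDs, scores, etc.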
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
        -> LossExpr:
  src_inputs = batchers.ListBatch([s[:-1] for s in src],
                                  mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
  src_targets = batchers.ListBatch([s[1:] for s in src],
                                   mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

  embeddings = self.src_embedder.embed_sent(src_inputs)
  encodings = self.rnn.transduce(embeddings)
  encodings_tensor = encodings.as_tensor()

  ((hidden_dim, seq_len), batch_size) = encodings.dim()
  encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
  outputs = self.transform.transform(encoding_reshaped)

  ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
  loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
  loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
  if src_targets.mask:
    loss_expr_perstep = dy.cmult(loss_expr_perstep,
                                 dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True))
  loss = dy.sum_elems(loss_expr_perstep)
  units = [s.len_unpadded() for s in src]
  return LossExpr(loss, units)
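Unlike the earlier variants, this one returns the loss wrapped together with per-sentence unit counts. As a rough sketch of why that pairing is useful (the class below is an illustrative assumption, not xnmt's actual LossExpr):

# Illustrative stand-in for LossExpr: pairs a batched loss with the number of
# real (unpadded) tokens per sentence so losses can be normalized per token.
class LossExpr:
  def __init__(self, expr, units):
    self.expr = expr    # batched dy.Expression, one loss value per sentence
    self.units = units  # list of unpadded sentence lengths

  def loss_per_token(self):
    return dy.sum_batches(self.expr).value() / sum(self.units)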
def _cut_or_pad_targets(self, seq_len, trg):
  old_mask = trg.mask
  if len(trg[0]) > seq_len:
    trunc_len = len(trg[0]) - seq_len
    trg = batchers.mark_as_batch([trg_sent.get_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
    if old_mask:
      trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
  else:
    pad_len = seq_len - len(trg[0])
    trg = batchers.mark_as_batch([trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
    if old_mask:
      trg.mask = batchers.Mask(np_arr=np.pad(old_mask.np_arr,
                                             pad_width=((0, 0), (0, pad_len)),
                                             mode="constant", constant_values=1))
  return trg
def _cut_or_pad_targets(self, seq_len: numbers.Integral, trg: batchers.Batch) -> batchers.Batch:
  old_mask = trg.mask
  if trg.sent_len() > seq_len:
    trunc_len = trg.sent_len() - seq_len
    trg = batchers.mark_as_batch([trg_sent.create_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
    if old_mask:
      trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
  else:
    pad_len = seq_len - trg.sent_len()
    trg = batchers.mark_as_batch([trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
    if old_mask:
      trg.mask = batchers.Mask(np_arr=np.pad(old_mask.np_arr,
                                             pad_width=((0, 0), (0, pad_len)),
                                             mode="constant", constant_values=1))
  return trg
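The mask arithmetic in both variants follows the convention that 1 marks a padded position (as also implied by the `1.0 - mask` factor in the loss computations above). A small self-contained example of the padding branch:

import numpy as np

# Two target sentences, lengths 3 and 2, currently padded to length 3
# (mask convention: 0 = real token, 1 = padding):
old_np_arr = np.array([[0, 0, 0],
                       [0, 0, 1]])
# Extending the batch by pad_len=2 marks all new positions as masked:
new_np_arr = np.pad(old_np_arr, pad_width=((0, 0), (0, 2)),
                    mode="constant", constant_values=1)
# new_np_arr:
# [[0 0 0 1 1]
#  [0 0 1 1 1]]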
def transduce(self, x):
  # some preparations
  output_states = []
  current_state = self._encode_src(x, apply_emb=False)
  if self.mode_transduce == "split":
    first_state = SymmetricDecoderState(rnn_state=current_state.rnn_state, context=current_state.context)
  batch_size = x.dim()[1]
  done = [False] * batch_size
  out_mask = batchers.Mask(np_arr=np.zeros((batch_size, self.max_dec_len)))
  out_mask.np_arr.flags.writeable = True

  # teacher / split mode: unfold guided by reference targets
  # -> feed everything up to (but excluding) the last token back into the LSTM
  # other modes: unfold until EOS is output or max len is reached
  max_dec_len = self.cur_src.batches[1].sent_len() if self.mode_transduce in ["teacher", "split"] \
                else self.max_dec_len

  atts_list = []
  generated_word_ids = []
  for pos in range(max_dec_len):
    if self.train and self.mode_transduce in ["teacher", "split"]:
      # unroll RNN guided by reference
      prev_ref_action, ref_action = None, None
      if pos > 0:
        prev_ref_action = self._batch_ref_action(pos - 1)
      if self.transducer_loss:
        ref_action = self._batch_ref_action(pos)
      step_loss = self.calc_loss_one_step(dec_state=current_state,
                                          batch_size=batch_size,
                                          mode=self.mode_transduce,
                                          ref_action=ref_action,
                                          prev_ref_action=prev_ref_action)
      self.transducer_losses.append(step_loss)
    else:  # inference
      # unroll RNN guided by model predictions
      if self.mode_transduce in ["teacher", "split"]:
        prev_ref_action = self._batch_max_action(batch_size, current_state, pos)
      else:
        prev_ref_action = None
      out_scores = self.generate_one_step(dec_state=current_state,
                                          mask=out_mask,
                                          cur_step=pos,
                                          batch_size=batch_size,
                                          mode=self.mode_transduce,
                                          prev_ref_action=prev_ref_action)
      word_id = np.argmax(out_scores.npvalue(), axis=0)
      word_id = word_id.reshape((word_id.size,))
      generated_word_ids.append(word_id[0])
      for batch_i in range(batch_size):
        if self._terminate_rnn(batch_i=batch_i, pos=pos, batched_word_id=word_id):
          done[batch_i] = True
          out_mask.np_arr[batch_i, pos + 1:] = 1.0
      if pos > 0 and all(done):
        atts_list.append(self.attender.get_last_attention())
        output_states.append(current_state.rnn_state.h()[-1])
        break
    output_states.append(current_state.rnn_state.h()[-1])
    atts_list.append(self.attender.get_last_attention())

  if self.mode_transduce == "split":
    # split mode: use attentions to compute context, then run RNNs over these context inputs
    if self.split_regularizer:
      assert len(atts_list) == len(self._chosen_rnn_inputs), \
        f"{len(atts_list)} != {len(self._chosen_rnn_inputs)}"
    split_output_states = []
    split_rnn_state = first_state.rnn_state
    for pos, att in enumerate(atts_list):
      lstm_input_context = self.attender.curr_sent.as_tensor() * att  # TODO: better reuse the already computed context vecs
      lstm_input_context = dy.reshape(lstm_input_context, (lstm_input_context.dim()[0][0],),
                                      batch_size=batch_size)
      if self.split_dual:
        lstm_input_label = self._chosen_rnn_inputs[pos]
        if self.split_dual[0] > 0.0 and self.train:
          lstm_input_context = dy.dropout_batch(lstm_input_context, self.split_dual[0])
        if self.split_dual[1] > 0.0 and self.train:
          lstm_input_label = dy.dropout_batch(lstm_input_label, self.split_dual[1])
        if self.split_context_transform:
          lstm_input_context = self.split_context_transform.transform(lstm_input_context)
        lstm_input_context = self.split_dual_proj.transform(
          dy.concatenate([lstm_input_context, lstm_input_label]))
      if self.split_regularizer and pos < len(self._chosen_rnn_inputs):
        # _chosen_rnn_inputs does not contain first (empty) input, so this is in fact like comparing to pos-1:
        penalty = dy.squared_norm(lstm_input_context - self._chosen_rnn_inputs[pos])
        if self.split_regularizer != 1:
          penalty = self.split_regularizer * penalty
        self.split_reg_penalty_expr = penalty
      split_rnn_state = split_rnn_state.add_input(lstm_input_context)
      split_output_states.append(split_rnn_state.h()[-1])
    assert len(output_states) == len(split_output_states)
    output_states = split_output_states

  out_mask.np_arr = out_mask.np_arr[:, :len(output_states)]
  self._final_states = []
  if self.compute_report:
    # for symmetric reporter (this can only be run at inference time)
    assert batch_size == 1
    atts_matrix = np.asarray([att.npvalue() for att in atts_list]) \
                    .reshape(len(atts_list), atts_list[0].dim()[0][0]).T
    self.report_sent_info({
      "symm_att": atts_matrix,
      "symm_out": sent.SimpleSentence(words=generated_word_ids,
                                      idx=self.cur_src.batches[0][0].idx,
                                      vocab=self.cur_src.batches[1][0].vocab,
                                      output_procs=self.cur_src.batches[1][0].output_procs),
      "symm_ref": self.cur_src.batches[1][0] if isinstance(self.cur_src, batchers.CompoundBatch) else None
    })

  # prepare final outputs
  for layer_i in range(len(current_state.rnn_state.h())):
    self._final_states.append(
      transducers.FinalTransducerState(main_expr=current_state.rnn_state.h()[layer_i],
                                       cell_expr=current_state.rnn_state._c[layer_i]))
  out_mask.np_arr.flags.writeable = False
  return expression_seqs.ExpressionSequence(expr_list=output_states, mask=out_mask)
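The out_mask bookkeeping in transduce uses the same 0/1 convention: once _terminate_rnn fires for a batch element at position pos, every later position is flagged as masked. A toy example of that update:

import numpy as np

out_mask_np = np.zeros((2, 6))  # batch_size=2, max_dec_len=6; all positions valid
# batch element 0 terminates at pos=3 -> everything after pos 3 is masked:
pos = 3
out_mask_np[0, pos + 1:] = 1.0
# out_mask_np:
# [[0. 0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 0. 0.]]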