Example #1
    def calc_nll(self, src, trg):
        if not batchers.is_batched(src):
            src = batchers.ListBatch([src])

        src_inputs = batchers.ListBatch(
            [s[:-1] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
        src_targets = batchers.ListBatch(
            [s[1:] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

        event_trigger.start_sent(src)
        embeddings = self.src_embedder.embed_sent(src_inputs)
        encodings = self.rnn.transduce(embeddings)
        encodings_tensor = encodings.as_tensor()
        ((hidden_dim, seq_len), batch_size) = encodings.dim()
        encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim, ),
                                       batch_size=batch_size * seq_len)
        outputs = self.transform.transform(encoding_reshaped)

        ref_action = np.asarray([sent.words for sent in src_targets]).reshape(
            (seq_len * batch_size, ))
        loss_expr_perstep = self.scorer.calc_loss(
            outputs, batchers.mark_as_batch(ref_action))
        loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len, ),
                                       batch_size=batch_size)
        if src_targets.mask:
            loss_expr_perstep = dy.cmult(
                loss_expr_perstep,
                dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True))
        loss = dy.sum_elems(loss_expr_perstep)

        return loss
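
The slicing at the top is the standard next-token language-modeling setup: src_inputs (s[:-1]) at position t predicts src_targets (s[1:]) at the same position, and the mask is sliced identically. A minimal numpy sketch of that shift, with toy values and the 1-means-padding convention assumed from the masking code above:

import numpy as np

# Hypothetical padded ID batch (batch x seq_len) and matching mask (1 = padding).
seqs = np.array([[5, 8, 2, 0],
                 [7, 2, 0, 0]])
mask = np.array([[0, 0, 0, 1],
                 [0, 0, 1, 1]])

inputs, targets = seqs[:, :-1], seqs[:, 1:]    # s[:-1] predicts s[1:]
in_mask, tgt_mask = mask[:, :-1], mask[:, 1:]  # masks are sliced the same way
print(inputs)   # [[5 8 2] [7 2 0]]
print(targets)  # [[8 2 0] [2 0 0]]
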
Example #2
    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
            -> tt.Tensor:
        if not batchers.is_batched(src):
            src = batchers.ListBatch([src])

        src_inputs = batchers.ListBatch(
            [s[:-1] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
        src_targets = batchers.ListBatch(
            [s[1:] for s in src],
            mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

        event_trigger.start_sent(src)
        embeddings = self.src_embedder.embed_sent(src_inputs)
        encodings = self.rnn.transduce(embeddings)
        encodings_tensor = encodings.as_tensor()

        encoding_reshaped = tt.merge_time_batch_dims(encodings_tensor)
        seq_len = tt.sent_len(encodings_tensor)
        batch_size = tt.batch_size(encodings_tensor)

        outputs = self.transform.transform(encoding_reshaped)

        ref_action = np.asarray([sent.words for sent in src_targets]).reshape(
            (seq_len * batch_size, ))
        loss_expr_perstep = self.scorer.calc_loss(
            outputs, batchers.mark_as_batch(ref_action))

        loss_expr_perstep = tt.unmerge_time_batch_dims(loss_expr_perstep,
                                                       batch_size)

        loss = tt.aggregate_masked_loss(loss_expr_perstep, src_targets.mask)

        return loss
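
Example #2 is a backend-agnostic rewrite of Example #1: the explicit dy.reshape calls are replaced by tensor-tools helpers. Assuming from their names that tt.merge_time_batch_dims folds the time axis into the batch axis so the transform and scorer see one batch element per timestep, and that tt.unmerge_time_batch_dims inverts this (the exact semantics live in xnmt's tensor tools), a numpy sketch of the round trip:

import numpy as np

# Assumed shapes: (hidden, seq_len, batch) -> (hidden, seq_len*batch) and back.
hidden, seq_len, batch = 4, 3, 2
enc = np.random.rand(hidden, seq_len, batch)

merged = enc.reshape(hidden, seq_len * batch)      # each timestep becomes a batch element
restored = merged.reshape(hidden, seq_len, batch)  # inverse of the merge
assert np.array_equal(enc, restored)
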
Example #3
  def generate_search_output(self,
                             src: batchers.Batch,
                             search_strategy: search_strategies.SearchStrategy,
                             forced_trg_ids: batchers.Batch = None) -> List[search_strategies.SearchOutput]:
    """
    Takes in a batch of source sentences and outputs a list of search outputs.
    Args:
      src: The source sentences
      search_strategy: The strategy with which to perform the search
      forced_trg_ids: The target IDs to generate if performing forced decoding
    Returns:
      A list of search outputs including scores, etc.
    """
    if src.batch_size() != 1:
      raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
                                "Specify inference batcher with batch size 1.")
    event_trigger.start_sent(src)
    all_src = src
    if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
    # Generating outputs
    cur_forced_trg = None
    src_sent = src[0]
    sent_mask = None
    if src.mask: sent_mask = batchers.Mask(np_arr=src.mask.np_arr[0:1])
    sent_batch = batchers.mark_as_batch([src_sent], mask=sent_mask)

    # Encode the sentence
    initial_state = self._encode_src(all_src)

    if forced_trg_ids is not None: cur_forced_trg = forced_trg_ids[0]
    search_outputs = search_strategy.generate_output(self, initial_state,
                                                     src_length=[src_sent.sent_len()],
                                                     forced_trg_ids=cur_forced_trg)
    return search_outputs
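
A detail worth noting: the mask slice uses np_arr[0:1] rather than np_arr[0] so the result stays two-dimensional, i.e. a batch of one. A tiny illustration with toy mask values (not from xnmt):

import numpy as np

batch_mask = np.array([[0, 0, 1, 1],
                       [0, 0, 0, 0]])
print(batch_mask[0].shape)    # (4,)   -> the batch dimension is lost
print(batch_mask[0:1].shape)  # (1, 4) -> still (batch, seq_len), which is what
                              #           batchers.Mask appears to expect here
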
Example #4
  def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
          -> dy.Expression:
    src_inputs = batchers.ListBatch([s[:-1] for s in src], mask=batchers.Mask(src.mask.np_arr[:, :-1]) if src.mask else None)
    src_targets = batchers.ListBatch([s[1:] for s in src], mask=batchers.Mask(src.mask.np_arr[:, 1:]) if src.mask else None)

    embeddings = self.src_embedder.embed_sent(src_inputs)
    encodings = self.rnn.transduce(embeddings)
    encodings_tensor = encodings.as_tensor()
    ((hidden_dim, seq_len), batch_size) = encodings.dim()
    encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
    outputs = self.transform.transform(encoding_reshaped)

    ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
    loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
    loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
    if src_targets.mask:
      loss_expr_perstep = dy.cmult(loss_expr_perstep, dy.inputTensor(1.0 - src_targets.mask.np_arr.T, batched=True))
    loss = dy.sum_elems(loss_expr_perstep)
    units = [s.len_unpadded() for s in src]
    return LossExpr(loss, units)
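
Unlike Example #1, this version returns a LossExpr that also carries each sentence's unpadded length, which a loss tracker can use to normalize per token. The transpose in dy.cmult is needed because the mask array is laid out (batch, seq_len) while the loss expression is (seq_len,) per batch element. A toy numpy rerun of just the masking step, with hypothetical values:

import numpy as np

loss_perstep = np.array([[0.7, 0.3, 0.9],
                         [0.2, 0.5, 0.8]])  # (batch, seq_len)
pad_mask = np.array([[0, 0, 1],
                     [0, 0, 0]])            # 1 = padded position
masked = loss_perstep * (1.0 - pad_mask)    # zero the loss on padding
print(masked.sum(axis=1))                   # per-sentence loss: [1.  1.5]
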
Example #5
 def _cut_or_pad_targets(self, seq_len, trg):
   old_mask = trg.mask
   if len(trg[0]) > seq_len:
     trunc_len = len(trg[0]) - seq_len
     trg = batchers.mark_as_batch([trg_sent.create_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
     if old_mask:
       trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
   else:
     pad_len = seq_len - len(trg[0])
     trg = batchers.mark_as_batch([trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
     if old_mask:
       trg.mask = batchers.Mask(np_arr=np.pad(old_mask.np_arr, pad_width=((0, 0), (0, pad_len)),
                                              mode="constant", constant_values=1))
   return trg
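
The constant_values=1 in the padding branch matters: under the 1-means-padding convention, every newly added position must be flagged as padding. A direct numpy check of that call with toy values:

import numpy as np

old = np.array([[0, 0],
                [0, 1]])
pad_len = 2
padded = np.pad(old, pad_width=((0, 0), (0, pad_len)),
                mode="constant", constant_values=1)
print(padded)
# [[0 0 1 1]
#  [0 1 1 1]]
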
Example #6
 def _cut_or_pad_targets(self, seq_len: numbers.Integral,
                         trg: batchers.Batch) -> batchers.Batch:
     old_mask = trg.mask
     if trg.sent_len() > seq_len:
         trunc_len = trg.sent_len() - seq_len
         trg = batchers.mark_as_batch([
             trg_sent.create_truncated_sent(trunc_len=trunc_len)
             for trg_sent in trg
         ])
         if old_mask:
             trg.mask = batchers.Mask(
                 np_arr=old_mask.np_arr[:, :-trunc_len])
     else:
         pad_len = seq_len - trg.sent_len()
         trg = batchers.mark_as_batch([
             trg_sent.create_padded_sent(pad_len=pad_len)
             for trg_sent in trg
         ])
         if old_mask:
             trg.mask = batchers.Mask(
                 np_arr=np.pad(old_mask.np_arr,
                               pad_width=((0, 0), (0, pad_len)),
                               mode="constant",
                               constant_values=1))
     return trg
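
The truncation branch is the mirror image: old_mask.np_arr[:, :-trunc_len] drops the last trunc_len mask columns along with the truncated tokens. The same slice in plain numpy, with toy values:

import numpy as np

old = np.array([[0, 0, 1, 1, 1],
                [0, 0, 0, 0, 1]])
trunc_len = 2
print(old[:, :-trunc_len])
# [[0 0 1]
#  [0 0 0]]
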
Example #7
 def transduce(self, x):
     # some preparations
     output_states = []
     current_state = self._encode_src(x, apply_emb=False)
     if self.mode_transduce == "split":
         first_state = SymmetricDecoderState(
             rnn_state=current_state.rnn_state,
             context=current_state.context)
     batch_size = x.dim()[1]
     done = [False] * batch_size
     out_mask = batchers.Mask(np_arr=np.zeros((batch_size,
                                               self.max_dec_len)))
     out_mask.np_arr.flags.writeable = True
      # teacher / split mode: unfold guided by reference targets
      #  -> feed everything up to (but excluding) the last token back into the LSTM
      # other modes: unfold until EOS is output or the maximum length is reached
      if self.mode_transduce in ["teacher", "split"]:
          max_dec_len = self.cur_src.batches[1].sent_len()
      else:
          max_dec_len = self.max_dec_len
     atts_list = []
     generated_word_ids = []
     for pos in range(max_dec_len):
         if self.train and self.mode_transduce in ["teacher", "split"]:
             # unroll RNN guided by reference
             prev_ref_action, ref_action = None, None
             if pos > 0:
                 prev_ref_action = self._batch_ref_action(pos - 1)
             if self.transducer_loss:
                 ref_action = self._batch_ref_action(pos)
             step_loss = self.calc_loss_one_step(
                 dec_state=current_state,
                 batch_size=batch_size,
                 mode=self.mode_transduce,
                 ref_action=ref_action,
                 prev_ref_action=prev_ref_action)
             self.transducer_losses.append(step_loss)
         else:  # inference
             # unroll RNN guided by model predictions
             if self.mode_transduce in ["teacher", "split"]:
                 prev_ref_action = self._batch_max_action(
                     batch_size, current_state, pos)
             else:
                 prev_ref_action = None
             out_scores = self.generate_one_step(
                 dec_state=current_state,
                 mask=out_mask,
                 cur_step=pos,
                 batch_size=batch_size,
                 mode=self.mode_transduce,
                 prev_ref_action=prev_ref_action)
             word_id = np.argmax(out_scores.npvalue(), axis=0)
             word_id = word_id.reshape((word_id.size, ))
             generated_word_ids.append(word_id[0])
             for batch_i in range(batch_size):
                 if self._terminate_rnn(batch_i=batch_i,
                                        pos=pos,
                                        batched_word_id=word_id):
                     done[batch_i] = True
                     out_mask.np_arr[batch_i, pos + 1:] = 1.0
             if pos > 0 and all(done):
                 atts_list.append(self.attender.get_last_attention())
                 output_states.append(current_state.rnn_state.h()[-1])
                 break
         output_states.append(current_state.rnn_state.h()[-1])
         atts_list.append(self.attender.get_last_attention())
     if self.mode_transduce == "split":
         # split mode: use attentions to compute context, then run RNNs over these context inputs
         if self.split_regularizer:
              assert len(atts_list) == len(self._chosen_rnn_inputs), \
                  f"{len(atts_list)} != {len(self._chosen_rnn_inputs)}"
         split_output_states = []
         split_rnn_state = first_state.rnn_state
         for pos, att in enumerate(atts_list):
              # TODO: better to reuse the already computed context vectors
              lstm_input_context = self.attender.curr_sent.as_tensor() * att
              lstm_input_context = dy.reshape(
                  lstm_input_context, (lstm_input_context.dim()[0][0], ),
                  batch_size=batch_size)
             if self.split_dual:
                 lstm_input_label = self._chosen_rnn_inputs[pos]
                 if self.split_dual[0] > 0.0 and self.train:
                     lstm_input_context = dy.dropout_batch(
                         lstm_input_context, self.split_dual[0])
                 if self.split_dual[1] > 0.0 and self.train:
                     lstm_input_label = dy.dropout_batch(
                         lstm_input_label, self.split_dual[1])
                 if self.split_context_transform:
                     lstm_input_context = self.split_context_transform.transform(
                         lstm_input_context)
                 lstm_input_context = self.split_dual_proj.transform(
                     dy.concatenate([lstm_input_context, lstm_input_label]))
             if self.split_regularizer and pos < len(
                     self._chosen_rnn_inputs):
                 # _chosen_rnn_inputs does not contain first (empty) input, so this is in fact like comparing to pos-1:
                 penalty = dy.squared_norm(lstm_input_context -
                                           self._chosen_rnn_inputs[pos])
                 if self.split_regularizer != 1:
                     penalty = self.split_regularizer * penalty
                 self.split_reg_penalty_expr = penalty
             split_rnn_state = split_rnn_state.add_input(lstm_input_context)
             split_output_states.append(split_rnn_state.h()[-1])
         assert len(output_states) == len(split_output_states)
         output_states = split_output_states
     out_mask.np_arr = out_mask.np_arr[:, :len(output_states)]
     self._final_states = []
     if self.compute_report:
         # for symmetric reporter (this can only be run at inference time)
         assert batch_size == 1
          atts_matrix = np.asarray([att.npvalue() for att in atts_list]) \
              .reshape(len(atts_list), atts_list[0].dim()[0][0]).T
          self.report_sent_info({
              "symm_att": atts_matrix,
              "symm_out": sent.SimpleSentence(
                  words=generated_word_ids,
                  idx=self.cur_src.batches[0][0].idx,
                  vocab=self.cur_src.batches[1][0].vocab,
                  output_procs=self.cur_src.batches[1][0].output_procs),
              "symm_ref": self.cur_src.batches[1][0]
                          if isinstance(self.cur_src, batchers.CompoundBatch) else None
          })
     # prepare final outputs
     for layer_i in range(len(current_state.rnn_state.h())):
         self._final_states.append(
             transducers.FinalTransducerState(
                 main_expr=current_state.rnn_state.h()[layer_i],
                 cell_expr=current_state.rnn_state._c[layer_i]))
     out_mask.np_arr.flags.writeable = False
     return expression_seqs.ExpressionSequence(expr_list=output_states,
                                               mask=out_mask)
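
At inference time the loop above performs greedy decoding with per-sentence termination bookkeeping: scores come out as (vocab, batch), argmax picks one word ID per batch element, and every position after a finished sentence is flagged as padding in out_mask. A standalone sketch of just that bookkeeping; the EOS ID, vocabulary size, and random scores are stand-ins for illustration:

import numpy as np

EOS_ID = 0
batch_size, max_len, vocab = 2, 4, 5
out_mask = np.zeros((batch_size, max_len))
done = [False] * batch_size

for pos in range(max_len):
    scores = np.random.rand(vocab, batch_size)  # stand-in for out_scores.npvalue()
    word_id = np.argmax(scores, axis=0)         # one word ID per batch element
    for b in range(batch_size):
        if not done[b] and word_id[b] == EOS_ID:
            done[b] = True
            out_mask[b, pos + 1:] = 1.0         # everything after EOS is padding
    if all(done):
        break
print(out_mask)
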