Example #1
    def transduce(
        self, es: 'expression_seqs.ExpressionSequence'
    ) -> 'expression_seqs.ExpressionSequence':
        mask = es.mask
        # first layer
        forward_es = self.forward_layers[0].transduce(es)
        rev_backward_es = self.backward_layers[0].transduce(
            expression_seqs.ReversedExpressionSequence(es))

        # subsequent layers: each direction receives the previous layer's forward and
        # (re-reversed) backward outputs, concatenated per timestep
        for layer_i in range(1, len(self.forward_layers)):
            new_forward_es = self.forward_layers[layer_i].transduce([
                forward_es,
                expression_seqs.ReversedExpressionSequence(rev_backward_es)
            ])
            rev_backward_es = expression_seqs.ExpressionSequence(
                self.backward_layers[layer_i].transduce([
                    expression_seqs.ReversedExpressionSequence(forward_es),
                    rev_backward_es
                ]).as_list(),
                mask=mask)
            forward_es = new_forward_es

        # final states: concatenate each layer's forward and backward main / cell expressions
        self._final_states = [
            transducers.FinalTransducerState(
                tt.concatenate([self.forward_layers[layer_i].get_final_states()[0].main_expr(),
                                self.backward_layers[layer_i].get_final_states()[0].main_expr()]),
                tt.concatenate([self.forward_layers[layer_i].get_final_states()[0].cell_expr(),
                                self.backward_layers[layer_i].get_final_states()[0].cell_expr()]))
            for layer_i in range(len(self.forward_layers))
        ]
        # per-timestep output: forward state for position i concatenated with the backward
        # state for the same position (rev_backward_es is stored last-to-first)
        return expression_seqs.ExpressionSequence(
            expr_list=[
                tt.concatenate([forward_es[i], rev_backward_es[-i - 1]])
                for i in range(forward_es.sent_len())
            ],
            mask=mask)
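The index arithmetic in the return statement is easy to miss: the backward outputs are kept in reversed (last-to-first) order, so rev_backward_es[-i - 1] is the backward state for the same source position as forward_es[i]. A minimal sketch of that alignment, with plain Python lists standing in for the xnmt expression sequences:

forward = ["f0", "f1", "f2", "f3"]       # forward states for positions 0..3
rev_backward = ["b3", "b2", "b1", "b0"]  # backward states, stored last position first
combined = [(forward[i], rev_backward[-i - 1]) for i in range(len(forward))]
assert combined == [("f0", "b0"), ("f1", "b1"), ("f2", "b2"), ("f3", "b3")]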
Example #2
  def transduce(self, es: expression_seqs.ExpressionSequence) -> expression_seqs.ExpressionSequence:
    """
    returns the list of output Expressions obtained by adding the given inputs
    to the current state, one by one, to both the forward and backward RNNs,
    and concatenating.

    Args:
      es: an ExpressionSequence
    """
    es_list = [es]

    for layer_i, (fb, bb) in enumerate(self.builder_layers):
      reduce_factor = self._reduce_factor_for_layer(layer_i)

      if es_list[0].mask is None: mask_out = None
      else: mask_out = es_list[0].mask.lin_subsampled(reduce_factor)

      if self.downsampling_method=="concat" and es_list[0].sent_len() % reduce_factor != 0:
        raise ValueError(f"For 'concat' subsampling, sequence lengths must be multiples of the total reduce factor, "
                         f"but got sequence length={es_list[0].sent_len()} for reduce_factor={reduce_factor}. "
                         f"Set Batcher's pad_src_to_multiple argument accordingly.")
      fs = fb.transduce(es_list)
      bs = bb.transduce([expression_seqs.ReversedExpressionSequence(es_item) for es_item in es_list])
      if layer_i < len(self.builder_layers) - 1:
        if self.downsampling_method=="skip":
          es_list = [expression_seqs.ExpressionSequence(expr_list=fs[::reduce_factor], mask=mask_out),
                     expression_seqs.ExpressionSequence(expr_list=bs[::reduce_factor][::-1], mask=mask_out)]
        elif self.downsampling_method=="concat":
          es_len = es_list[0].sent_len()
          es_list_fwd = []
          es_list_bwd = []
          for i in range(0, es_len, reduce_factor):
            for j in range(reduce_factor):
              if i==0:
                es_list_fwd.append([])
                es_list_bwd.append([])
              # es_list_fwd[j] / es_list_bwd[j] collect every reduce_factor-th state starting at
              # offset j; bs is in reversed order, so it is indexed from the end to stay aligned with fs
              es_list_fwd[j].append(fs[i+j])
              es_list_bwd[j].append(bs[es_len-reduce_factor+j-i])
          es_list = [expression_seqs.ExpressionSequence(expr_list=es_list_fwd[j], mask=mask_out) for j in range(reduce_factor)] + \
                    [expression_seqs.ExpressionSequence(expr_list=es_list_bwd[j], mask=mask_out) for j in range(reduce_factor)]
        else:
          raise RuntimeError(f"unknown downsampling_method {self.downsampling_method}")
      else:
        # concat final outputs
        ret_es = expression_seqs.ExpressionSequence(
          expr_list=[tt.concatenate([f, b]) for f, b in zip(fs, expression_seqs.ReversedExpressionSequence(bs))], mask=mask_out)

    # final states: concatenate each layer's forward and backward main / cell expressions
    self._final_states = [
      transducers.FinalTransducerState(
        tt.concatenate([fb.get_final_states()[0].main_expr(), bb.get_final_states()[0].main_expr()]),
        tt.concatenate([fb.get_final_states()[0].cell_expr(), bb.get_final_states()[0].cell_expr()]))
      for (fb, bb) in self.builder_layers]
    return ret_es
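The index arithmetic in the "concat" branch is the trickiest part: fs is in time order, bs in reversed order, and each of the 2 * reduce_factor collected lists ends up as a time-ordered subsequence that the next layer concatenates per position. A minimal sketch with plain Python lists (reduce_factor=2, sequence length 6; not xnmt objects) traces the indices:

reduce_factor, es_len = 2, 6
fs = ["f0", "f1", "f2", "f3", "f4", "f5"]  # forward states, time order
bs = ["b5", "b4", "b3", "b2", "b1", "b0"]  # backward states, reversed order
es_list_fwd = [[] for _ in range(reduce_factor)]
es_list_bwd = [[] for _ in range(reduce_factor)]
for i in range(0, es_len, reduce_factor):
  for j in range(reduce_factor):
    es_list_fwd[j].append(fs[i + j])
    es_list_bwd[j].append(bs[es_len - reduce_factor + j - i])
assert es_list_fwd == [["f0", "f2", "f4"], ["f1", "f3", "f5"]]
assert es_list_bwd == [["b1", "b3", "b5"], ["b0", "b2", "b4"]]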
Example #3
  def _calc_transform(self, mlp_dec_state: AutoRegressiveDecoderState) -> tt.Tensor:
    # concatenate the decoder RNN output with the attention context
    # (the torch backend keeps an extra singleton dim on the context, hence the squeeze)
    h = tt.concatenate([
        mlp_dec_state.rnn_state.output(),
        mlp_dec_state.context.squeeze(1) if xnmt.backend_torch else mlp_dec_state.context
    ])
    return self.transform.transform(h)
Example #4
    def transduce(
        self, expr_seq: 'expression_seqs.ExpressionSequence'
    ) -> 'expression_seqs.ExpressionSequence':
        """
    transduce the sequence, applying masks if given (masked timesteps simply copy previous h / c)

    Args:
      expr_seq: expression sequence or list of expression sequences (where each inner list will be concatenated)
    Returns:
      expression sequence
    """
        if isinstance(expr_seq, expression_seqs.ExpressionSequence):
            expr_seq = [expr_seq]
        concat_inputs = len(expr_seq) >= 2
        batch_size = tt.batch_size(expr_seq[0][0])
        seq_len = expr_seq[0].sent_len()
        mask = expr_seq[0].mask

        if self.dropout_rate > 0.0 and self.train:
            self.set_dropout_masks(batch_size=batch_size)

        cur_input = expr_seq
        self._final_states = []
        for layer_i in range(self.num_layers):
            h = [tt.zeroes(hidden_dim=self.hidden_dim, batch_size=batch_size)]
            c = [tt.zeroes(hidden_dim=self.hidden_dim, batch_size=batch_size)]
            for pos_i in range(seq_len):
                if concat_inputs and layer_i == 0:
                    x_t = tt.concatenate(
                        [cur_input[i][pos_i] for i in range(len(cur_input))])
                else:
                    x_t = cur_input[0][pos_i]
                h_tm1 = h[-1]
                if self.dropout_rate > 0.0 and self.train:
                    # apply dropout according to https://arxiv.org/abs/1512.05287 (tied weights)
                    x_t = torch.mul(x_t, self.dropout_mask_x[layer_i])
                    h_tm1 = torch.mul(h_tm1, self.dropout_mask_h[layer_i])
                h_t, c_t = self.layers[layer_i](x_t, (h_tm1, c[-1]))
                if mask is None or np.isclose(
                        np.sum(mask.np_arr[:, pos_i:pos_i + 1]), 0.0):
                    # no entries are masked at this timestep: take the new state as-is
                    c.append(c_t)
                    h.append(h_t)
                else:
                    # masked (padded) timesteps copy the previous c / h; unmasked ones take the new state
                    c.append(
                        mask.cmult_by_timestep_expr(c_t, pos_i, True) +
                        mask.cmult_by_timestep_expr(c[-1], pos_i, False))
                    h.append(
                        mask.cmult_by_timestep_expr(h_t, pos_i, True) +
                        mask.cmult_by_timestep_expr(h[-1], pos_i, False))
            self._final_states.append(
                transducers.FinalTransducerState(h[-1], c[-1]))
            cur_input = [h[1:]]

        return expression_seqs.ExpressionSequence(expr_list=h[1:], mask=mask)
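The masked update above selects, per batch element, between the freshly computed state and the previous one, so padded timesteps simply carry the state forward. A minimal numpy sketch of that selection (plain arrays stand in for xnmt expressions; the convention that a mask value of 1 marks a padded position is an assumption here):

import numpy as np
mask_col = np.array([[0.0], [1.0]])          # batch of 2; the second element is padding at this timestep
h_prev = np.array([[0.5, 0.5], [0.7, 0.7]])  # previous hidden states
h_new = np.array([[0.9, 0.1], [0.2, 0.8]])   # freshly computed hidden states
h_next = (1.0 - mask_col) * h_new + mask_col * h_prev  # keep h_new where valid, h_prev where padded
assert np.allclose(h_next, [[0.9, 0.1], [0.7, 0.7]])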
Example #5
    def initial_state(self, enc_final_states: Any,
                      ss: Any) -> AutoRegressiveDecoderState:
        """Get the initial state of the decoder given the encoder final states.

    Args:
      enc_final_states: The encoder final states. Usually but not necessarily an :class:`xnmt.expression_sequence.ExpressionSequence`
      ss: first input
    Returns:
      initial decoder state
    """
        rnn_state = self.rnn.initial_state()
        rnn_s = self.bridge.decoder_init(enc_final_states)
        rnn_state = rnn_state.set_s(rnn_s)
        ss_expr = self.embedder.embed(ss)
        # with input feeding, the first decoder input is the embedded start symbol concatenated
        # with an all-zero "context" vector, since no attention context exists yet
        zeros = tt.zeroes(
            hidden_dim=self.input_dim,
            batch_size=tt.batch_size(ss_expr)) if self.input_feeding else None
        rnn_state = rnn_state.add_input(
            tt.concatenate([ss_expr, zeros]) if self.input_feeding else ss_expr)
        return AutoRegressiveDecoderState(rnn_state=rnn_state, context=zeros)
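With input feeding enabled, the very first decoder input is therefore the start-symbol embedding concatenated with a zero placeholder for the attention context, so the decoder RNN must expect inputs of size emb_dim + input_dim. A minimal numpy sketch (made-up dimensions, plain arrays instead of xnmt expressions):

import numpy as np
emb_dim, input_dim = 4, 3
ss_expr = np.random.rand(emb_dim)  # stand-in for the embedded start symbol
zeros = np.zeros(input_dim)        # stand-in for the not-yet-existing attention context
first_input = np.concatenate([ss_expr, zeros])
assert first_input.shape == (emb_dim + input_dim,)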
Example #6
  def transduce(
      self, src: expression_seqs.ExpressionSequence
  ) -> expression_seqs.ExpressionSequence:
    sent_len = src.sent_len()
    batch_size = tt.batch_size(src[0])
    # look up one positional embedding per timestep, replicated across the batch
    embeddings = self.embeddings(
        torch.tensor([list(range(sent_len))] * batch_size).to(xnmt.device))
    # embeddings = dy.strided_select(dy.parameter(self.embedder), [1,1], [0,0], [self.input_dim, sent_len])
    if self.op == 'sum':
      output = embeddings + src.as_tensor()
    elif self.op == 'concat':
      output = tt.concatenate([embeddings, src.as_tensor()])
    else:
      raise ValueError(
          f'Illegal op {self.op} in PositionalTransducer (options are "sum"/"concat")')
    if self.train and self.dropout > 0.0:
      output = tt.dropout(output, self.dropout)
    output_seq = expression_seqs.ExpressionSequence(expr_tensor=output, mask=src.mask)
    self._final_states = [transducers.FinalTransducerState(output_seq[-1])]
    return output_seq
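The two op modes differ only in how the positional embeddings are combined with the source tensor: 'sum' requires the embeddings to match the source feature size, while 'concat' stacks them along the feature axis. A minimal numpy sketch (made-up dimensions; the feature-by-time layout is illustrative, not necessarily xnmt's actual tensor layout):

import numpy as np
sent_len, input_dim = 5, 8
src_tensor = np.random.rand(input_dim, sent_len)  # one column of features per position
pos_emb = np.random.rand(input_dim, sent_len)     # one positional embedding per position
summed = pos_emb + src_tensor                     # 'sum': shape stays (input_dim, sent_len)
stacked = np.concatenate([pos_emb, src_tensor])   # 'concat': shape (2 * input_dim, sent_len)
assert summed.shape == (input_dim, sent_len) and stacked.shape == (2 * input_dim, sent_len)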
Example #7
    def add_input(self, dec_state: AutoRegressiveDecoderState,
                  trg_word: Any) -> AutoRegressiveDecoderState:
        """
    Add an input and return a *new* update the state.

    Args:
      dec_state: An object containing the current state.
      trg_word: The word to input.
    Returns:
      The updated decoder state.
    """
        trg_embedding = self.embedder.embed(trg_word)
        inp = trg_embedding
        if self.input_feeding:
            # input feeding: concatenate the previous attention context onto the target embedding
            # (the torch backend keeps an extra singleton dim on the context, hence the squeeze)
            inp = tt.concatenate([
                inp,
                dec_state.context.squeeze(1)
                if xnmt.backend_torch else dec_state.context
            ])
        rnn_state = dec_state.rnn_state
        return AutoRegressiveDecoderState(rnn_state=rnn_state.add_input(inp),
                                          context=dec_state.context)