Example #1
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
                assert self.copy_attn is None  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        attns = {}
        """
        assert self.copy_attn is None
        assert not self._coverage

        attns = {}
        index_select = [
            torch.index_select(a, 0, i).unsqueeze(0)
            for a, i in zip(torch.transpose(memory_bank, 0, 1),
                            torch.t(torch.squeeze(tgt, 2)))
        ]
        emb = torch.transpose(torch.cat(index_select), 0, 1)

        if isinstance(self.rnn, nn.GRU):
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
        else:
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)

        # Calculate the attention
        p_attn = self.attn(rnn_output.transpose(0, 1).contiguous(),
                           memory_bank.transpose(0, 1),
                           memory_lengths=memory_lengths)
        attns["std"] = p_attn
        return dec_state, None, attns
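
A note on a helper used throughout: every example calls aeq for its shape checks, but none of the snippets defines it. A minimal sketch consistent with how it is called here (an assert-all-equal helper in the spirit of onmt.utils.misc.aeq; treat it as an assumption, not the library's exact code):

def aeq(*args):
    """Assert that all arguments are equal (used here for tensor-shape checks)."""
    first = args[0]
    assert all(arg == first for arg in args[1:]), \
        "Not all arguments have the same value: " + str(args)
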
Example #2
    def forward(self, src, img_feats, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        #s_len, n_batch, emb_dim = emb.size()
        img_emb = self.img_to_emb(img_feats).unsqueeze(0)
        # prepend image "word"
        emb = torch.cat([img_emb, emb], dim=0)

        out = emb.transpose(0, 1).contiguous()
        words = src[:, :, 0].transpose(0, 1)
        # expand mask to account for image "word"
        words = torch.cat([words[:, 0:1], words], dim=1)

        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()
        aeq(out_batch, w_batch)
        aeq(out_len, w_len)
        # END CHECKS

        # Make mask.
        padding_idx = self.embeddings.word_padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, w_len, w_len)
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)

        return emb, out.transpose(0, 1).contiguous(), lengths
Example #3
    def forward(self, source, memory_bank, memory_lengths=None,
                memory_turns=None, coverage=None):
        # here we implement a hierarchical attention
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_tl, source_wl, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        # word level attention
        word_align = self.word_score(source, memory_bank.contiguous()
                           .view(batch, -1, dim))

        # transform align (b, 1, tl * wl) -> (b * tl, 1, wl)
        word_align = word_align.view(batch * source_tl, 1, source_wl)
        if memory_lengths is not None:
            word_mask = sequence_mask_herd(memory_lengths.view(-1), max_len=word_align.size(-1))
            word_mask = word_mask.unsqueeze(1)  # Make it broadcastable.
            word_align.masked_fill_(1 - word_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            word_align_vectors = F.softmax(word_align.view(batch * source_tl, source_wl), -1)
        else:
            word_align_vectors = sparsemax(word_align.view(batch * source_tl, source_wl), -1)

        # mask out sentences that are entirely padding
        sent_pad_mask = memory_lengths.view(-1).eq(0).unsqueeze(1)
        word_align_vectors = torch.mul(word_align_vectors,
                                       (1.0 - sent_pad_mask).type_as(word_align_vectors))
        word_align_vectors = word_align_vectors.view(batch * source_tl, target_l, source_wl)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        cw = torch.bmm(word_align_vectors, memory_bank.view(batch * source_tl, source_wl, -1))
        cw = cw.view(batch, source_tl, -1)
        # concat_cw = torch.cat([cw, source.repeat(1, source_tl, 1)], 2).view(batch*source_tl, -1)
        # attn_hw = self.word_linear_out(concat_cw).view(batch, source_tl, -1)
        # attn_hw = torch.tanh(attn_hw)

        # turn level attention
        turn_align = self.turn_score(source, cw)

        if memory_turns is not None:
            turn_mask = sequence_mask(memory_turns, max_len=turn_align.size(-1))
            turn_mask = turn_mask.unsqueeze(1)  # Make it broadcastable.
            turn_align.masked_fill_(1 - turn_mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            turn_align_vectors = F.softmax(turn_align.view(batch * target_l, source_tl), -1)
        else:
            turn_align_vectors = sparsemax(turn_align.view(batch * target_l, source_tl), -1)
        turn_align_vectors = turn_align_vectors.view(batch, target_l, source_tl)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        ct = torch.bmm(turn_align_vectors, cw)

        return ct.squeeze(1), None
Example #4
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        # here we do not need to calculate the align
        # because the answer vector is already averaged representations
        if source.dim() == 2:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(1 - mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        return c.squeeze(1), align_vectors
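
Examples #3 and #4 (and several below) mask the raw attention scores with sequence_mask before normalizing. Its definition is not part of these snippets; a minimal sketch matching the calls above (a per-row mask of valid positions, assumed to behave like the usual OpenNMT-style helper):

import torch

def sequence_mask(lengths, max_len=None):
    """Boolean mask from sequence lengths: True at positions < length."""
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    return positions.repeat(batch_size, 1).lt(lengths.unsqueeze(1))
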
Example #5
    def forward(self, src, lengths=None):
        """ See :obj:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        out = emb.transpose(0, 1).contiguous()
        words = src[:, :, 0].transpose(0, 1)
        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()
        aeq(out_batch, w_batch)
        aeq(out_len, w_len)
        # END CHECKS

        # Make mask.
        padding_idx = self.embeddings.word_padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, w_len, w_len)
        # Run the forward pass of every layer of the transformer.
        for i in range(self.num_layers):
            out = self.transformer[i](out, mask)
        out = self.layer_norm(out)

        return Variable(emb.data), out.transpose(0, 1).contiguous()
Example #6
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = self.state["input_feed"].squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        dec_outs = []
        attns = {"std": []}
        if self.copy_attn is not None or self._reuse_copy_attn:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        dec_state = self.state["hidden"]
        coverage = self.state["coverage"].squeeze(0) \
            if self.state["coverage"] is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        for emb_t in emb.split(1):
            decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
            rnn_output, dec_state = self.rnn(decoder_input, dec_state)
            decoder_output, p_attn = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths)
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(
                    decoder_input, rnn_output, decoder_output
                )
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            dec_outs += [decoder_output]
            attns["std"] += [p_attn]

            # Update the coverage attention.
            if self._coverage:
                coverage = p_attn if coverage is None else p_attn + coverage
                attns["coverage"] += [coverage]

            if self.copy_attn is not None:
                _, copy_attn = self.copy_attn(
                    decoder_output, memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._reuse_copy_attn:
                attns["copy"] = attns["std"]

        return dec_state, dec_outs, attns
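
Example #6 (and #36 further down) implements input feeding: the previous attentional output is concatenated to the current target embedding before every RNN step, so later attention decisions can condition on earlier ones. Stripped of attention, coverage and copy logic, the loop reduces to roughly the sketch below (sizes and the bare LSTMCell are made up for illustration; the real decoder feeds back the attentional output, not the raw hidden state):

import torch
import torch.nn as nn

hidden_size, emb_size, batch, tgt_len = 8, 6, 2, 3
rnn = nn.LSTMCell(emb_size + hidden_size, hidden_size)
emb = torch.randn(tgt_len, batch, emb_size)
input_feed = torch.zeros(batch, hidden_size)
state = (torch.zeros(batch, hidden_size), torch.zeros(batch, hidden_size))

outs = []
for emb_t in emb.split(1):
    decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
    h, c = rnn(decoder_input, state)
    state = (h, c)
    input_feed = h  # stands in for the attentional decoder_output above
    outs.append(h)
print(torch.stack(outs).shape)  # torch.Size([3, 2, 8])
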
Example #7
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.

        Args:
            tgt (LongTensor): a sequence of input tokens tensors
                ``(len, batch, nfeats)``.
            memory_bank (FloatTensor): output(tensor sequence) from the
                encoder RNN of size ``(src_len, batch, hidden_size)``.
            memory_lengths (LongTensor): the source memory_bank lengths.

        Returns:
            (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]]):

            * dec_state: final hidden state from the decoder.
            * dec_outs: an array of output of every time
              step from the decoder.
            * attns: a dictionary of different
              type of attention Tensor array of every time
              step from the decoder.
        """

        assert self.copy_attn is None  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        attns = {}
        emb = self.embeddings(tgt)

        if isinstance(self.rnn, nn.GRU):
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
        else:
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)

        # Calculate the attention.
        if not self.attentional:
            dec_outs = rnn_output
        else:
            dec_outs, p_attn = self.attn(rnn_output.transpose(0,
                                                              1).contiguous(),
                                         memory_bank.transpose(0, 1),
                                         memory_lengths=memory_lengths)
            attns["std"] = p_attn

        # Calculate the context gate.
        if self.context_gate is not None:
            dec_outs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                dec_outs.view(-1, dec_outs.size(2)))
            dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size)

        dec_outs = self.dropout(dec_outs)
        return dec_state, dec_outs, attns
Example #8
    def _compute_orthogonal_loss(self, sep_states):
        """
        The orthogonal loss computation function
        sep_states: a tuple (stacked_sep_states, sep_states_lens)
        :return: a scalar, the orthogonal loss
        """
        # stacked_sep_states: [b_size, max_sep_num, src_h_size]
        stacked_sep_states, sep_states_lens = sep_states
        b_size, max_sep_num, src_h_size = stacked_sep_states.size()
        b_size_ = len(sep_states_lens)
        aeq(b_size, b_size_)

        device = stacked_sep_states.device

        # obtain the mask
        # [b_size, max_sep_num]
        mask = sequence_mask(torch.Tensor(sep_states_lens)).to(device)
        mask = mask.float()
        # [b_size, 1, max_sep_num]
        mask = mask.unsqueeze(1)
        # [b_size, max_sep_num, max_sep_num]
        mask_2d = torch.bmm(mask.transpose(1, 2), mask)

        # compute the loss
        # [b_size, max_sep_num, max_sep_num]
        identity = torch.eye(max_sep_num).unsqueeze(0).repeat(b_size, 1, 1).to(device)
        # [b_size, max_sep_num, max_sep_num]
        orthogonal_loss_ = torch.bmm(stacked_sep_states, stacked_sep_states.transpose(1, 2)) - identity
        orthogonal_loss_ = orthogonal_loss_ * mask_2d
        # [b_size]
        orthogonal_loss = torch.norm(orthogonal_loss_.view(b_size, -1), p=2, dim=1)
        return orthogonal_loss
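
A sanity check on the idea in Example #8 (and #10 below): the loss is the norm of the masked Gram matrix's deviation from the identity, so (near-)orthonormal separator states should give a loss close to zero. A hedged toy check with made-up sizes:

import torch
import torch.nn.init as init

b_size, max_sep_num, src_h_size = 2, 3, 4
states = init.orthogonal_(torch.empty(max_sep_num, src_h_size))
stacked_sep_states = states.unsqueeze(0).repeat(b_size, 1, 1)   # [b, sep, h]
gram = torch.bmm(stacked_sep_states, stacked_sep_states.transpose(1, 2))
identity = torch.eye(max_sep_num).unsqueeze(0).repeat(b_size, 1, 1)
loss = torch.norm((gram - identity).view(b_size, -1), p=2, dim=1)
print(loss)  # close to zero for orthonormal rows
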
Example #9
    def forward(self, src1, src2):
        """ See :obj:`EncoderBase.forward()`"""
        # src: (seq_len, bsz, 1)
        emb1 = self.embeddings(src1)
        emb2 = self.embeddings(src2)
        # emb: (seq_len, bsz, dim)
        emb2_biased = emb2 + self.emb_bias
        emb = torch.cat([emb1, emb2_biased], dim=0)

        out = emb.transpose(0, 1).contiguous()
        src = torch.cat([src1, src2], dim=0)

        words = src[:, :, 0].transpose(0, 1)
        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()
        aeq(out_batch, w_batch)
        aeq(out_len, w_len)
        # END CHECKS

        # Make mask.
        padding_idx = self.embeddings.word_padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, w_len, w_len)
        # Run the forward pass of every layer of the transformer.
        for i in range(self.num_layers):
            out = self.transformer[i](out, mask)
        out = self.layer_norm(out)

        return Variable(emb.data), out.transpose(0, 1).contiguous()
Example #10
    def _compute_orthogonal_loss(self, batch, orthog_states):
        """
        The orthogonal loss computation function
        :param batch: the current batch
        :param orthog_states: the orthog_states from the sent level decoder
        :return: a scalar, the orthogonal loss
        """
        # [b_size, s_num, tgt_s_len-1]
        valid_tgt = batch.tgt[:, :, 1:]
        b_size, s_num, _ = valid_tgt.size()
        b_size1, s_num1, _ = orthog_states.size()
        aeq(b_size, b_size1)
        aeq(s_num, s_num1)

        # obtain the mask
        # [b_size, s_num]
        mask = valid_tgt.ne(self.padding_idx).sum(dim=-1).ne(0)
        mask = mask.float()
        # [b_size, 1, s_num]
        mask = mask.unsqueeze(1)
        # [b_size, s_num, s_num]
        mask_2d = torch.bmm(mask.transpose(1, 2), mask)

        # compute the loss
        # [b_size, s_num, s_num]
        identity = torch.eye(s_num).unsqueeze(0).repeat(b_size, 1, 1).to(orthog_states.device)
        # [b_size, s_num, s_num]
        orthogonal_loss_ = torch.bmm(orthog_states, orthog_states.transpose(1, 2)) - identity
        orthogonal_loss_ = orthogonal_loss_ * mask_2d
        # [b_size]
        orthogonal_loss = torch.norm(orthogonal_loss_.view(b_size, -1), p=2, dim=1)
        return orthogonal_loss
Example #11
    def _check_args(self, src, lengths=None, hidden=None):
        if isinstance(src, dict):
            src = src['src']
        n_batch = src.size(1)
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
Example #12
    def forward(self, query, memory_bank, memory_lengths=None, **kwargs):
        """
        query (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
        memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
        memory_lengths (`LongTensor`): the source context lengths `[batch]`

        returns attention distribution (batch x tgt_len x src_len)
        """
        src_batch, src_len, src_dim = memory_bank.size()
        tgt_batch, tgt_len, tgt_dim = query.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)

        align = self.score(query, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            align.masked_fill_(~mask.unsqueeze(1), -float('inf'))

        # It should not be necessary to view align as a 2d tensor, but
        # something is broken with sparsemax and it cannot handle a 3d tensor.
        return self.transform(align.view(-1, src_len)).view_as(align)
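
The 2-d view workaround in Example #12 (flatten the leading dims, apply the transform row-wise, then view_as back) is equivalent to normalizing over the last dimension of the 3-d tensor. A small hedged illustration, with softmax standing in for self.transform:

import torch
import torch.nn.functional as F

batch, tgt_len, src_len = 2, 3, 5
align = torch.randn(batch, tgt_len, src_len)
flat = F.softmax(align.view(-1, src_len), dim=-1).view_as(align)
full = F.softmax(align, dim=-1)
print(torch.allclose(flat, full))  # True
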
Example #13
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.
        Args:
            tgt (LongTensor): a sequence of input tokens tensors
                                 [len x batch x nfeats].
            memory_bank (FloatTensor): output(tensor sequence) from the
                          encoder RNN of size (src_len x batch x hidden_size).
            memory_lengths (LongTensor): the source memory_bank lengths.
        Returns:
            dec_state (Tensor): final hidden state from the decoder.
            dec_outs ([FloatTensor]): an array of output of every time
                                     step from the decoder.
            attns (dict of (str, [FloatTensor]): a dictionary of different
                            type of attention Tensor array of every time
                            step from the decoder.
        """
        assert not self._copy  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        # Initialize local and return variables.
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU):
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
        else:
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        dec_outs, p_attn = self.attn(
            rnn_output.transpose(0, 1).contiguous(),
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths
        )
        attns["std"] = p_attn

        # Calculate the context gate.
        if self.context_gate is not None:
            dec_outs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                dec_outs.view(-1, dec_outs.size(2))
            )
            dec_outs = \
                dec_outs.view(tgt_len, tgt_batch, self.hidden_size)

        dec_outs = self.dropout(dec_outs)
        return dec_state, dec_outs, attns
Example #14
    def _example_dict_iter(self, line, index):
        line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        if self.side == 'tgt':
            words, feats, n_feats = TextDataset.extract_text_features(line)
            example_dict = {self.side: words, "indices": index}
        else:
            feats = None
            n_feats = 0
            graph = AMR.extract_amr_features(line, reentrancies)
            words = graph.traverse

            example_dict = {
                self.side: words,
                self.side + "_graph": graph,
                "indices": index
            }
        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update(
                (prefix + str(j), f) for j, f in enumerate(feats))

        return example_dict
Example #15
    def score(self, h_t, h_s, type):
        """
        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`
          type: use word or sent matrix
        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`

        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)

        h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
        if type == 'qa_word':
            h_t_ = self.qa_word_linear_in(h_t_)
        elif type == 'qa_sent':
            h_t_ = self.qa_sent_linear_in(h_t_)
        elif type == 'pass':
            h_t_ = self.pass_linear_in(h_t_)
        h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
Example #16
    def _check_args(self, src, lengths=None, hidden=None):
        if isinstance(src, tuple):
            src = src[0]
        _, n_batch, _ = src.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
Example #17
    def _run_forward_pass(self,
                          tgt,
                          memory_bank,
                          state,
                          answer,
                          memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        decoder_outputs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []

        emb = self.embeddings(tgt.unsqueeze(-1))
        assert emb.dim() == 3  # len x batch x embedding_dim

        hidden = state.hidden
        coverage = None

        # Input feed concatenates hidden state with
        # input at every time step.
        for _, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)

            rnn_output, hidden = self.rnn(decoder_input, hidden)

            # construct query [h, ans] to interact with sources
            query = torch.cat([rnn_output, answer], 1)

            decoder_output, p_attn = self.attn(query,
                                               memory_bank.transpose(0, 1),
                                               memory_lengths=memory_lengths)

            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            decoder_outputs += [decoder_output]
            attns["std"] += [p_attn]

            # Run the forward pass of the copy attention layer.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._copy:
                attns["copy"] = attns["std"]
        # Return result.
        return hidden, decoder_outputs, attns
Example #18
    def forward(self, tgt, memory_bank, state, memory_lengths=None,
                step=None, sent_encoder=None, src_sents=None, dec=None):
        """
        Args:
            tgt (`LongTensor`): sequences of padded tokens
                 `[tgt_len x batch x nfeats]`.
            memory_bank (`FloatTensor`): vectors from the encoder
                 `[src_len x batch x hidden]`.
            state (:obj:`onmt.models.DecoderState`):
                 decoder state object to initialize the decoder
            memory_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
        Returns:
            (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
                * decoder_outputs: output from the decoder (after attn)
                         `[tgt_len x batch x hidden]`.
                * decoder_state: final hidden state from the decoder
                * attns: distribution over src at each tgt
                        `[tgt_len x batch x src_len]`.
        """
        # Check
        assert isinstance(state, RNNDecoderState)
        # tgt.size() returns tgt length and batch
        _, tgt_batch, _ = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
        # END

        # TODO: I changed this return value 'sent_decoder'

        # Run the forward pass of the RNN.
        decoder_final, decoder_outputs, attns = self._run_forward_pass(
            tgt, memory_bank, state, memory_lengths=memory_lengths,
            sent_encoder=sent_encoder, src_sents=src_sents, dec=dec)

        # Update the state with the result.
        final_output = decoder_outputs[-1]
        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

        # Concatenates sequence of tensors along a new dimension.
        # NOTE: v0.3 to 0.4: decoder_outputs / attns[*] may not be list
        #       (in particular in case of SRU) it was not raising error in 0.3
        #       since stack(Variable) was allowed.
        #       In 0.4, SRU returns a tensor that shouldn't be stacked.

        if type(decoder_outputs) == list:
            decoder_outputs = torch.stack(decoder_outputs)

            for k in attns:
                if type(attns[k]) == list:
                    attns[k] = torch.stack(attns[k])

        return decoder_outputs, state, attns
Example #19
    def forward(self,
                tgt,
                memory_bank,
                state,
                memory_lengths=None,
                wals_features=None,
                step=None):

        # Check
        assert isinstance(state, RNNDecoderStateDoublyAttentive)
        # tgt.size() returns tgt length and batch
        _, tgt_batch, _ = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        _, wals_features_batch, _ = wals_features.size()
        aeq(tgt_batch, memory_batch)
        aeq(tgt_batch, wals_features_batch)
        # END

        # Run the forward pass of the RNN.

        decoder_final, decoder_outputs, decoder_outputs_wals, attns = self._run_forward_pass(
            tgt,
            memory_bank,
            state,
            wals_features=wals_features,
            memory_lengths=memory_lengths)

        # Update the state with the result.
        final_output = decoder_outputs[-1]
        final_output_wals = decoder_outputs_wals[-1]
        coverage = None
        coverage_wals = None

        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        if "coverage_wals" in attns:
            coverage_wals = attns["coverage_wals"][-1].unsqueeze(0)

        state.update_state(decoder_final, final_output.unsqueeze(0),
                           final_output_wals.unsqueeze(0), coverage,
                           coverage_wals)

        # Concatenates sequence of tensors along a new dimension.
        # NOTE: v0.3 to 0.4: decoder_outputs / attns[*] may not be list
        #       (in particular in case of SRU) it was not raising error in 0.3
        #       since stack(Variable) was allowed.
        #       In 0.4, SRU returns a tensor that shouldn't be stacked.
        if type(decoder_outputs) == list:
            decoder_outputs = torch.stack(decoder_outputs)

        if type(decoder_outputs_wals) == list:
            decoder_outputs_wals = torch.stack(decoder_outputs_wals)

        for k in attns:
            if type(attns[k]) == list:
                attns[k] = torch.stack(attns[k])

        return decoder_outputs, decoder_outputs_wals, state, attns
Example #20
    def _check_args(self, src, lengths=None, hidden=None):
        n_batch = src.size(1)
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
        if src.size(0) != max(lengths):
            lengths -= max(lengths) - src.size(0)
        return lengths
Example #21
    def _run_forward_pass(self,
                          tgt,
                          memory_bank,
                          state,
                          wals_features,
                          memory_lengths=None):

        assert not self._copy
        assert not self._coverage

        # Initialize local and return variables.
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU):
            rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
        else:
            rnn_output, decoder_final = self.rnn(emb, state.hidden)

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        decoder_outputs, p_attn = self.attn(rnn_output.transpose(
            0, 1).contiguous(),
                                            memory_bank.transpose(0, 1),
                                            memory_lengths=memory_lengths)
        attns["std"] = p_attn

        decoder_outputs_wals, p_attn_wals = self.attn_wals(
            rnn_output.transpose(0, 1).contiguous(),
            wals_features.transpose(0, 1), None)
        attns["std_wals"] = p_attn_wals

        # Calculate the context gate.
        if self.context_gate is not None:
            decoder_outputs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                decoder_outputs.view(-1, decoder_outputs.size(2)))
            decoder_outputs = decoder_outputs.view(tgt_len, tgt_batch,
                                                   self.hidden_size)
            decoder_outputs = self.dropout(decoder_outputs)
        else:
            decoder_outputs = self.dropout(decoder_outputs)

        # no context gate on WALS features
        decoder_outputs_wals = self.dropout(decoder_outputs_wals)

        # Return result.
        return decoder_final, decoder_outputs, decoder_outputs_wals, attns
Example #22
    def _run_forward_pass(self, tgt, word_memory_bank, sent_memory_bank,
                          sent_context, state, word_memory_lengths,
                          sent_memory_lengths, static_attn):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        decoder_outputs = []
        attns = {"std": []}

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim
        # topic
        #topic_emb = self.topic_emb(tgt)
        #topic_emb = topic_emb.squeeze(2)
        #emb = torch.cat((emb, topic_emb), 2)
        #emb = self.norm_linear_topic(emb)

        hidden = state.hidden

        # Input feed concatenates hidden state with
        # input at every time step.
        for outidx, emb_t in enumerate(emb.split(1)):
            # logger.info('generate %d word' %outidx)
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)

            rnn_output, hidden = self.rnn(decoder_input, hidden)

            # attn
            decoder_output, attn = self.attn(
                rnn_output,
                word_memory_bank,
                word_memory_lengths,
                sent_memory_bank,
                sent_memory_lengths,
                sent_context,
                static_attn)

            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            decoder_outputs += [decoder_output]
            attns["std"] += [attn]

        # Return result.
        return hidden, decoder_outputs, attns
Example #23
    def forward(self, hidden, attn, src_map, align=None, ptrs=None, tags=None):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))

        if self.training and ptrs is not None:
            align_unk = align.eq(0).float().view(-1, 1)
            align_not_unk = align.ne(0).float().view(-1, 1)
            out_prob = torch.mul(prob, align_unk)
            mul_attn = torch.mul(attn, align_not_unk)
            mul_attn = torch.mul(mul_attn, ptrs.view(-1, slen_).float())
        else:
            out_prob = torch.mul(prob, 1 - p_copy)

            # Mask disallowed copies
            if tags is not None:
                mul_attn = torch.mul(attn, tags.t())*2
            else:
                mul_attn = attn

            mul_attn = torch.mul(mul_attn, p_copy)

        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)
        ).transpose(0, 1)
        # p_copy actually contains the importance of the word from the training decision.
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1), p_copy
Example #24
    def _run_forward_pass(self, tgt, word_memory_bank, sent_memory_bank, state,
                          word_memory_lengths, sent_memory_lengths, C):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.
        C_final = self.V(
            torch.tanh(C + input_feed.unsqueeze(1).expand(-1, C.size(1), -1))
        ).expand(-1, -1, C.size(2)) * C
        r_t = torch.sum(C_final, 1)

        # Initialize local and return variables.
        decoder_outputs = []
        attns = {"std": []}

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        hidden = state.hidden

        # Input feed concatenates hidden state with
        # input at every time step.
        for outidx, emb_t in enumerate(emb.split(1)):
            # logger.info('generate %d word' %outidx)
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed, r_t], 1)

            rnn_output, hidden = self.rnn(decoder_input, hidden)

            # attn
            decoder_output, attn = self.attn(rnn_output, word_memory_bank,
                                             word_memory_lengths,
                                             sent_memory_bank,
                                             sent_memory_lengths)

            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output
            C_final = self.V(
                torch.tanh(C + input_feed.unsqueeze(1).expand(
                    -1, C.size(1), -1))).expand(-1, -1, C.size(2)) * C
            r_t = torch.sum(C_final, 1)

            decoder_outputs += [decoder_output]
            attns["std"] += [attn]

        # Return result.
        return hidden, decoder_outputs, attns
Example #25
    def _example_dict_iter(self, line, index):
        sessions = line.strip('\n').split('||')
        for s in sessions:
            assert len(s.split('\t')) == 11
        session_id = [s.split('\t')[0] for s in sessions]
        item_sku_id = [s.split('\t')[1] for s in sessions]
        user_log = [s.split('\t')[2] for s in sessions]
        operator = [s.split('\t')[3] for s in sessions]
        user_site_cy = [s.split('\t')[4] for s in sessions]
        user_site_pro = [s.split('\t')[5] for s in sessions]
        user_site_ct = [s.split('\t')[6] for s in sessions]
        stm = [int(s.split('\t')[7]) for s in sessions]
        page_ts = [int(s.split('\t')[8]) for s in sessions]
        item_name = [s.split('\t')[9].split() for s in sessions]
        item_comment = [s.split('\t')[10].split() for s in sessions]

        line = []

        if self.line_truncate:
            for tmp_name, tmp_comment in zip(item_name, item_comment):
                line.extend(tmp_name[:self.line_truncate])
                line.extend(tmp_comment[:self.line_truncate])
        else:
            for tmp_name, tmp_comment in zip(item_name, item_comment):
                line.extend(tmp_name)
                line.extend(tmp_comment)

        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {
            self.side: words,
            self.side + "_session_id": session_id,
            self.side + "_item_sku": item_sku_id,
            self.side + "_user_log": user_log,
            self.side + "_operator": operator,
            self.side + "_site_cy": user_site_cy,
            self.side + "_site_pro": user_site_pro,
            self.side + "_site_ct": user_site_ct,
            self.side + "_stm": stm,
            self.side + "_page_ts": page_ts,
            "indices": index
        }

        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update(
                (prefix + str(j), f) for j, f in enumerate(feats))

        return example_dict
Example #26
    def _check_args(self, src, lengths=None, hidden=None):
        _, n_batch, _ = src.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
Example #27
    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """
        if self.conv_first:

            attn = torch.unsqueeze(attn, 1)
            original_seq_len = src_map.shape[0]

            if original_seq_len % 3 == 0:
                attn = self.conv_transpose(attn)
            elif original_seq_len % 3 == 1:
                attn = self.conv_transpose_pad1(attn)
            else:
                attn = self.conv_transpose_pad2(attn)

            attn = torch.squeeze(attn, 1)

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
Example #28
    def _example_dict_iter(self, line, index):
        line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {self.side: words, "indices": index}
        if feats:
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update(
                (prefix + str(j), f) for j, f in enumerate(feats))

        return example_dict
Example #29
    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        # hidden = (tgt_len * batch, hidden)
        # attn = (tgt_len * batch, src_len)
        # src_map = (src_len * batch, cvocab)
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        # logits = (tgt_len * batch, tvocab)
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        # prob = (tgt_len * batch, tvocab)
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        # p_copy = (tgt_len * batch, 1)
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        # out_prob = (tgt_len * batch, tvocab)
        out_prob = torch.mul(prob, 1 - p_copy)
        # mul_attn = (tgt_len * batch, src_len)
        mul_attn = torch.mul(attn, p_copy)
        # copy_prob = (batch, tgt_len, src_len) x (batch, src_len, cvocab) --> (batch, tgt_len, cvocab)
        # copy_prob --> (tgt_len, batch, cvocab)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)).transpose(0, 1)
        # copy_prob --> (tgt_len * batch, cvocab)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        # --> (tgt_len * batch, tvocab + cvocab)
        return torch.cat([out_prob, copy_prob], 1)
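
Examples #23, #27, #29 and #32 share the same copy/generate mixture: the generator distribution is scaled by (1 - p_copy), the copy attention by p_copy, and the attention mass is scattered into the extended vocabulary through src_map, so the concatenated result still sums to one per row. A hedged toy illustration with made-up shapes and values:

import torch

tvocab, cvocab, slen, batch, tlen = 5, 7, 3, 2, 1
prob = torch.softmax(torch.randn(batch * tlen, tvocab), 1)   # generator dist
attn = torch.softmax(torch.randn(batch * tlen, slen), 1)     # copy attention
src_map = torch.zeros(slen, batch, cvocab)
src_map[torch.arange(slen), :, torch.arange(slen)] = 1       # toy indicator matrix

p_copy = torch.sigmoid(torch.randn(batch * tlen, 1))
out_prob = prob * (1 - p_copy)
mul_attn = attn * p_copy
copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1),
                      src_map.transpose(0, 1)).transpose(0, 1)
copy_prob = copy_prob.contiguous().view(-1, cvocab)
mixed = torch.cat([out_prob, copy_prob], 1)
print(mixed.sum(1))  # ~1.0 per row
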
Example #30
    def forward(self, hidden, his_attn, cur_attn, his_mid, cur_mid, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           his_mid (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           cur_mid (FloatTensor): hidden outputs ``(batch x tlen, input_size)``

           his_attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           cur_attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = his_attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of lambda
        feature = self.hidden_dense(hidden) + self.his_dense(
            his_mid) + self.cur_dense(cur_mid)
        lambda_gate = torch.sigmoid(feature)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))

        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)

        attn = lambda_gate * his_attn + (1 - lambda_gate) * cur_attn
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1), attn
Example #31
    def forward(ctx, input, target):
        """
        input (FloatTensor): ``(n, num_classes)``.
        target (LongTensor): ``(n,)``, the indices of the target classes
        """
        input_batch, classes = input.size()
        target_batch = target.size(0)
        aeq(input_batch, target_batch)

        z_k = input.gather(1, target.unsqueeze(1)).squeeze()
        tau_z, support_size = _threshold_and_support(input, dim=1)
        support = input > tau_z
        x = torch.where(
            support, input**2 - tau_z**2,
            torch.tensor(0.0, device=input.device)
        ).sum(dim=1)
        ctx.save_for_backward(input, target, tau_z)
        # clamping necessary because of numerical errors: loss should be lower
        # bounded by zero, but negative values near zero are possible without
        # the clamp
        return torch.clamp(x / 2 - z_k + 0.5, min=0.0)
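
For reference, the quantity computed in Example #31 matches (assuming this is the standard sparsemax loss of Martins & Astudillo, 2016):

    \ell_{\mathrm{sparsemax}}(z, k) = \frac{1}{2} \sum_{j \in S(z)} \left( z_j^2 - \tau(z)^2 \right) - z_k + \frac{1}{2}

where S(z) is the support returned by _threshold_and_support and \tau(z) its threshold; the final clamp only guards against small negative values from floating-point error.
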
Example #32
    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)
        ).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
Example #33
    def forward(self, base_target_emb, input_from_dec, encoder_out_top,
                encoder_out_combine):
        """
        Args:
            base_target_emb: target emb tensor
            input_from_dec: output of decode conv
            encoder_out_top: the key matrix for calculation of attention weight,
                which is the top output of encode conv
            encoder_out_combine:
                the value matrix for the attention-weighted sum,
                which is the combination of base emb and top output of encode
        """

        # checks
        # batch, channel, height, width = base_target_emb.size()
        batch, _, height, _ = base_target_emb.size()
        # batch_, channel_, height_, width_ = input_from_dec.size()
        batch_, _, height_, _ = input_from_dec.size()
        aeq(batch, batch_)
        aeq(height, height_)

        # enc_batch, enc_channel, enc_height = encoder_out_top.size()
        enc_batch, _, enc_height = encoder_out_top.size()
        # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()
        enc_batch_, _, enc_height_ = encoder_out_combine.size()

        aeq(enc_batch, enc_batch_)
        aeq(enc_height, enc_height_)

        preatt = seq_linear(self.linear_in, input_from_dec)
        target = (base_target_emb + preatt) * SCALE_WEIGHT
        target = torch.squeeze(target, 3)
        target = torch.transpose(target, 1, 2)
        pre_attn = torch.bmm(target, encoder_out_top)

        if self.mask is not None:
            pre_attn.data.masked_fill_(self.mask, -float('inf'))

        attn = F.softmax(pre_attn, dim=2)

        context_output = torch.bmm(
            attn, torch.transpose(encoder_out_combine, 1, 2))
        context_output = torch.transpose(
            torch.unsqueeze(context_output, 3), 1, 2)
        return context_output, attn
Example #34
    def score(self, h_t, h_s):
        """
        Args:
          h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
          h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

        Returns:
          FloatTensor: raw attention scores (unnormalized) for each src index
            ``(batch, tgt_len, src_len)``
        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
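
The "dot"/"general" branch of Example #34 (like the score method in Example #15) is just a batched matrix product over the shared feature dimension. A minimal hedged shape check:

import torch

batch, tgt_len, src_len, dim = 2, 4, 6, 8
h_t = torch.randn(batch, tgt_len, dim)
h_s = torch.randn(batch, src_len, dim)
scores = torch.bmm(h_t, h_s.transpose(1, 2))
print(scores.shape)  # torch.Size([2, 4, 6])
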
Example #35
    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
          memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
          memory_lengths (LongTensor): the source context lengths ``(batch,)``
          coverage (FloatTensor): None (not supported yet)

        Returns:
          (FloatTensor, FloatTensor):

          * Computed vector ``(tgt_len, batch, dim)``
          * Attention distributions for each query
            ``(tgt_len, batch, src_len)``
        """

        # one step input
        if source.dim() == 2:
            one_step = True
            source = source.unsqueeze(1)
        else:
            one_step = False

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = torch.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(1 - mask, -float('inf'))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch*target_l, source_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch*target_l, source_l), -1)
        align_vectors = align_vectors.view(batch, target_l, source_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()
            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors
Example #36
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = self.state["input_feed"].squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        dec_outs = []
        attns = {}
        if self.attn is not None:
            attns["std"] = []
        if self.copy_attn is not None or self._reuse_copy_attn:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        dec_state = self.state["hidden"]
        coverage = self.state["coverage"].squeeze(0) \
            if self.state["coverage"] is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        for emb_t in emb.split(1):
            decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
            rnn_output, dec_state = self.rnn(decoder_input, dec_state)
            if self.attentional:
                decoder_output, p_attn = self.attn(
                    rnn_output,
                    memory_bank.transpose(0, 1),
                    memory_lengths=memory_lengths)
                attns["std"].append(p_attn)
            else:
                decoder_output = rnn_output
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(
                    decoder_input, rnn_output, decoder_output
                )
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            dec_outs += [decoder_output]

            # Update the coverage attention.
            if self._coverage:
                coverage = p_attn if coverage is None else p_attn + coverage
                attns["coverage"] += [coverage]

            if self.copy_attn is not None:
                _, copy_attn = self.copy_attn(
                    decoder_output, memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._reuse_copy_attn:
                attns["copy"] = attns["std"]

        return dec_state, dec_outs, attns
Example #37
    def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.

        Args:
            tgt (LongTensor): a sequence of input tokens tensors
                ``(len, batch, nfeats)``.
            memory_bank (FloatTensor): output(tensor sequence) from the
                encoder RNN of size ``(src_len, batch, hidden_size)``.
            memory_lengths (LongTensor): the source memory_bank lengths.

        Returns:
            (Tensor, List[FloatTensor], Dict[str, List[FloatTensor]]):

            * dec_state: final hidden state from the decoder.
            * dec_outs: an array of output of every time
              step from the decoder.
            * attns: a dictionary of different
              type of attention Tensor array of every time
              step from the decoder.
        """

        assert self.copy_attn is None  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        attns = {}
        emb = self.embeddings(tgt)

        if isinstance(self.rnn, nn.GRU):
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"][0])
        else:
            rnn_output, dec_state = self.rnn(emb, self.state["hidden"])

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)

        # Calculate the attention.
        if not self.attentional:
            dec_outs = rnn_output
        else:
            dec_outs, p_attn = self.attn(
                rnn_output.transpose(0, 1).contiguous(),
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths
            )
            attns["std"] = p_attn

        # Calculate the context gate.
        if self.context_gate is not None:
            dec_outs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                dec_outs.view(-1, dec_outs.size(2))
            )
            dec_outs = dec_outs.view(tgt_len, tgt_batch, self.hidden_size)

        dec_outs = self.dropout(dec_outs)
        return dec_state, dec_outs, attns