Example #1
    def forward(self, input, lengths=None, hidden=None):
        """ See :obj:`EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)
        s_len, n_batch, emb_dim = emb.size()

        out = emb.transpose(0, 1).contiguous()
        words = input[:, :, 0].transpose(0, 1)
        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()
        aeq(out_batch, w_batch)
        aeq(out_len, w_len)
        # END CHECKS

        # Make mask.
        padding_idx = self.embeddings.word_padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, w_len, w_len)
        # Run the forward pass of every layer of the transformer.
        for i in range(self.num_layers):
            out = self.transformer[i](out, mask)
        out = self.layer_norm(out)

        return Variable(emb.data), out.transpose(0, 1).contiguous()
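
All of the CHECKS blocks in these examples rely on aeq, OpenNMT-py's small assert-equal utility. A minimal sketch of the assumed helper:

    def aeq(*args):
        """Assert that all arguments are equal (a sketch of onmt's utility)."""
        first = args[0]
        assert all(arg == first for arg in args[1:]), \
            "Not all arguments have the same value: " + str(args)
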
Example #2
    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.
        Args:
            tgt (LongTensor): a sequence of input token tensors
                                 [len x batch x nfeats].
            memory_bank (FloatTensor): output(tensor sequence) from the encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the encoder RNN for
                                 initializing the decoder.
            memory_lengths (LongTensor): the source memory_bank lengths.
        Returns:
            decoder_final (Variable): final hidden state from the decoder.
            decoder_outputs ([FloatTensor]): a list of the output tensors from
                                     every time step of the decoder.
            attns (dict of (str, [FloatTensor])): a dictionary mapping each
                            attention type to a list of attention tensors
                            from every time step of the decoder.
        """
        assert not self._copy  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        # Initialize local and return variables.
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU):
            rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
        else:
            rnn_output, decoder_final = self.rnn(emb, state.hidden)

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        decoder_outputs, p_attn = self.attn(
            rnn_output.transpose(0, 1).contiguous(),
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths
        )
        attns["std"] = p_attn

        # Calculate the context gate.
        if self.context_gate is not None:
            decoder_outputs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                decoder_outputs.view(-1, decoder_outputs.size(2))
            )
            decoder_outputs = \
                decoder_outputs.view(tgt_len, tgt_batch, self.hidden_size)

        decoder_outputs = self.dropout(decoder_outputs)
        return decoder_final, decoder_outputs, attns
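
The isinstance(self.rnn, nn.GRU) branch above exists because nn.GRU expects a single hidden tensor while nn.LSTM expects an (h, c) tuple. A standalone illustration with hypothetical sizes:

    import torch
    import torch.nn as nn

    emb = torch.randn(5, 2, 16)              # [len x batch x emb_dim]
    h0 = torch.zeros(1, 2, 16)               # [layers x batch x hidden]

    gru = nn.GRU(16, 16)
    _, h_gru = gru(emb, h0)                  # GRU hidden: one tensor

    lstm = nn.LSTM(16, 16)
    _, (h_n, c_n) = lstm(emb, (h0, h0))      # LSTM hidden: an (h, c) tuple
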
Example #3
    def _example_dict_iter(self, line, index):
        line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {self.side: words, "indices": index}
        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update((prefix + str(j), f)
                                for j, f in enumerate(feats))

        return example_dict
Example #4
    def _example_dict_iter(self, line, index):
        line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {self.side: words, "indices": index}
        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update(
                (prefix + str(j), f) for j, f in enumerate(feats))

        return example_dict
Example #5
    def forward(self, input, context, state, context_lengths=None):
        """
        Forward through the decoder.
        Args:
            input (LongTensor): a sequence of input token tensors
                                of size (len x batch x nfeats).
            context (FloatTensor): output(tensor sequence) from the encoder
                        RNN of size (src_len x batch x hidden_size).
            state (FloatTensor): hidden state from the encoder RNN for
                                 initializing the decoder.
            context_lengths (LongTensor): the source context lengths.
        Returns:
            outputs (FloatTensor): a Tensor sequence of output from the decoder
                                   of shape (len x batch x hidden_size).
            state (FloatTensor): final hidden state from the decoder.
            attns (dict of (str, FloatTensor)): a dictionary mapping each
                                attention type to its tensor from the decoder,
                                of shape (src_len x batch).
        """
        # Args Check
        assert isinstance(state, RNNDecoderState)
        input_len, input_batch, _ = input.size()
        contxt_len, contxt_batch, _ = context.size()
        aeq(input_batch, contxt_batch)
        # END Args Check

        # Run the forward pass of the RNN.
        hidden, outputs, attns, coverage, rnn_output, emb = \
            self._run_forward_pass(input, context, state,
                                   context_lengths=context_lengths)

        # Update the state with the result.
        final_output = outputs[-1]
        state.update_state(hidden, final_output.unsqueeze(0),
                           coverage.unsqueeze(0)
                           if coverage is not None else None)

        # Concatenates sequence of tensors along a new dimension.
        outputs = torch.stack(outputs)
        rnn_output = torch.stack(rnn_output)
        for k in attns:
            attns[k] = torch.stack(attns[k])

        return (
            outputs, state, attns,
            # pointer_gen
            rnn_output, emb
        )
Example #6
    def _check_args_double_enc(self,
                               input,
                               lengths_src=None,
                               lengths_inter=None,
                               hidden_src=None,
                               hidden_inter=None):
        # SRC
        s_len, n_batch, n_feats = input[0].size()
        if lengths_src is not None:
            n_batch_, = lengths_src.size()
            aeq(n_batch, n_batch_)
        # INTER
        s_len, n_batch, n_feats = input[1].size()
        if lengths_inter is not None:
            n_batch_, = lengths_inter.size()
            aeq(n_batch, n_batch_)
Example #7
    def forward(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Args:
            tgt (`LongTensor`): sequences of padded tokens
                                `[tgt_len x batch x nfeats]`.
            memory_bank (`FloatTensor`): vectors from the encoder
                 `[src_len x batch x hidden]`.
            state (:obj:`onmt.Models.DecoderState`):
                 decoder state object to initialize the decoder
            memory_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
        Returns:
            (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
                * decoder_outputs: output from the decoder (after attn)
                         `[tgt_len x batch x hidden]`.
                * decoder_state: final hidden state from the decoder
                * attns: distribution over src at each tgt
                        `[tgt_len x batch x src_len]`.
        """
        # Check
        assert isinstance(state, RNNDecoderState)
        tgt_len, tgt_batch, _ = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
        # END

        # Run the forward pass of the RNN.
        decoder_final, decoder_outputs, attns = self._run_forward_pass(
            tgt, memory_bank, state, memory_lengths=memory_lengths)

        # Update the state with the result.
        final_output = decoder_outputs[-1]
        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

        # Concatenates sequence of tensors along a new dimension.
        # Change for torch0.4
        if type(decoder_outputs) is not torch.Tensor:
            # If input feeding is being used.
            decoder_outputs = torch.stack(decoder_outputs)
        for k in attns:
            if type(attns[k]) is not torch.Tensor:
                attns[k] = torch.stack(attns[k])

        return decoder_outputs, state, attns
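
The torch.stack calls turn a Python list of per-step outputs into one tensor with a new leading time dimension. A quick sketch with hypothetical sizes:

    import torch

    tgt_len, batch, hidden = 5, 2, 8
    step_outputs = [torch.randn(batch, hidden) for _ in range(tgt_len)]

    stacked = torch.stack(step_outputs)      # new dim 0 is time
    assert stacked.shape == (tgt_len, batch, hidden)
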
Example #8
    def forward(self, input, context, state,
                context_lengths=None, **kwargs):
        """
        Args:
            input (`LongTensor`): sequences of padded tokens
                                `[tgt_len x batch x nfeats]`.
            context (`FloatTensor`): vectors from the encoder
                 `[src_len x batch x hidden]`.
            state (:obj:`onmt.Models.DecoderState`):
                 decoder state object to initialize the decoder
            context_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
        Returns:
            (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
                * outputs: output from the decoder
                         `[tgt_len x batch x hidden]`.
                * state: final hidden state from the decoder
                * attns: distribution over src at each tgt
                        `[tgt_len x batch x src_len]`.
        """
        # Args Check
        assert isinstance(state, RNNDecoderState)
        input_len, input_batch, _ = input.size()
        contxt_len, contxt_batch, _ = context.size()
        aeq(input_batch, contxt_batch)
        # END Args Check

        # Run the forward pass of the RNN.
        # All the latent variables and additional inputs are found in kwargs
        hidden, outputs, attns, coverage = self._run_forward_pass(
            input, context, state,
            context_lengths=context_lengths,
            **kwargs)

        # Update the state with the result.
        final_output = outputs[-1]
        state.update_state(hidden, final_output.unsqueeze(0),
                           coverage.unsqueeze(0)
                           if coverage is not None else None)

        # Concatenates sequence of tensors along a new dimension.
        outputs = torch.stack(outputs)
        for k in attns:
            if k not in ["q_latent", "p_latent"]:
                attns[k] = torch.stack(attns[k])

        return outputs, state, attns
Example #9
    def forward_mm(self, input, img_proj, lengths=None, hidden=None):
        """
                Args:
                    input (:obj:`LongTensor`):
                       padded sequences of sparse indices `[src_len x batch x nfeat]`
                    lengths (:obj:`LongTensor`): length of each sequence `[batch]`
                    hidden (class specific):
                       initial hidden state.

                Returns:k
                    (tuple of :obj:`FloatTensor`, :obj:`FloatTensor`):
                        * final encoder state, used to initialize decoder
                           `[layers x batch x hidden]`
                        * contexts for attention, `[src_len x batch x hidden]`
                """
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)
        s_len, n_batch, emb_dim = emb.size()

        emb = emb.transpose(0, 1).contiguous()  # (batch, src_len, nfeat)
        input_mm = torch.cat([emb, img_proj], dim=1)
        out = input_mm
        # words = input[:, :, 0].transpose(0, 1) # (batch, src_len)
        words = emb[:, :, 0]
        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()

        aeq(out_batch, w_batch)
        aeq(s_len, w_len)
        # END CHECKS

        # Make mask (the mask here is actually unused).
        padding_idx = self.embeddings.word_padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, out_len, s_len)

        # assert not words.data.eq(padding_idx).max(), "there are some mask items equal to 1"

        # Run the forward pass of every layer of the transformer.
        for i in range(self.num_layers):
            out, attn = self.transformer[i](emb, out, mask)  # attn 1x49x10
        out = self.layer_norm(out)

        return Variable(input_mm.data), out.transpose(0, 1).contiguous(), attn
Example #10
    def forward(self, hidden, attn, src_map, align=None, ptrs=None):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
           attn (`FloatTensor`): attn for each `[batch*tlen, input_size]`
           src_map (`FloatTensor`):
             A sparse indicator matrix mapping each source word to
             its index in the "extended" vocabulary.
             `[src_len, batch, extra_words]`
        """
        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
        prob = F.softmax(logits)

        # Probability of copying p(z=1) batch.
        p_copy = F.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))

        if self.training:
            align_unk = align.eq(0).float().view(-1, 1)
            align_not_unk = align.ne(0).float().view(-1, 1)
            out_prob = torch.mul(prob, align_unk.expand_as(prob))
            mul_attn = torch.mul(attn, align_not_unk.expand_as(attn))
            mul_attn = torch.mul(mul_attn, ptrs.view(-1, slen_).float())
        else:
            out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
            mul_attn = torch.mul(attn, p_copy.expand_as(attn))
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)

        return torch.cat([out_prob, copy_prob], 1), p_copy
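
At inference time the else branch mixes the generation and copy distributions as p(w) = (1 - p_copy) * p_vocab(w) + p_copy * (attn projected through src_map). A toy, self-contained sketch of that mixing, with all sizes hypothetical:

    import torch

    batch, tlen, slen, vocab, cvocab = 2, 3, 4, 10, 3
    batch_by_tlen = batch * tlen
    prob = torch.softmax(torch.randn(batch_by_tlen, vocab), dim=-1)
    attn = torch.softmax(torch.randn(batch_by_tlen, slen), dim=-1)
    src_map = torch.zeros(slen, batch, cvocab)          # one-hot in practice
    p_copy = torch.sigmoid(torch.randn(batch_by_tlen, 1))

    out_prob = prob * (1 - p_copy)                      # mass kept for generation
    mul_attn = attn * p_copy                            # mass routed to copying
    copy_prob = torch.bmm(mul_attn.view(-1, batch, slen).transpose(0, 1),
                          src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    mixed = torch.cat([out_prob, copy_prob], 1)         # [batch*tlen, vocab+cvocab]
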
Example #11
    def score(self, h_t, h_s, entity_attn=False):
        """
        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`

        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(self.dim, tgt_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                if entity_attn:
                    h_t_ = self.linear_in_entity(h_t_)
                else:
                    h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, src_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
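
For the "general" attention type, the query is first passed through a dim x dim projection and then batch-multiplied against the sources. A shape-level sketch with hypothetical sizes:

    import torch
    import torch.nn as nn

    batch, tgt_len, src_len, dim = 2, 3, 5, 8
    h_t = torch.randn(batch, tgt_len, dim)
    h_s = torch.randn(batch, src_len, dim)
    linear_in = nn.Linear(dim, dim, bias=False)

    h_t_ = linear_in(h_t.view(batch * tgt_len, dim)).view(batch, tgt_len, dim)
    scores = torch.bmm(h_t_, h_s.transpose(1, 2))
    assert scores.shape == (batch, tgt_len, src_len)
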
Example #12
    def forward(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Args:
            tgt (`LongTensor`): sequences of padded tokens
                                `[tgt_len x batch x nfeats]`.
            memory_bank (`FloatTensor`): vectors from the encoder
                 `[src_len x batch x hidden]`.
            state (:obj:`onmt.Models.DecoderState`):
                 decoder state object to initialize the decoder
            memory_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
        Returns:
            (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
                * decoder_outputs: output from the decoder (after attn)
                         `[tgt_len x batch x hidden]`.
                * decoder_state: final hidden state from the decoder
                * attns: distribution over src at each tgt
                        `[tgt_len x batch x src_len]`.
        """
        # Check
        assert isinstance(state, RNNDecoderState)
        tgt_len, tgt_batch, _ = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
        # END

        # Run the forward pass of the RNN.
        decoder_final, decoder_outputs, attns = self._run_forward_pass(
            tgt, memory_bank, state, memory_lengths=memory_lengths)

        # Update the state with the result.
        final_output = decoder_outputs[-1]
        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

        # Concatenates sequence of tensors along a new dimension.
        decoder_outputs = torch.stack(decoder_outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])

        return decoder_outputs, state, attns
Example #13
    def forward(self, input):
        """
        Computes the partly linked embeddings for words.

        Args:
            input (`LongTensor`): index tensor `[len x batch x 1]`
        Return:
            `FloatTensor`: word embeddings `[len x batch x embedding_size]`
        """
        in_length, in_batch, nfeat = input.size()
        aeq(nfeat, 1)

        flat_input = input.view(-1)
        cluster_indices = self.cluster_mapping.index_select(0, flat_input)
        concat = torch.cat([input, cluster_indices.view(input.shape)], dim=-1)

        emb = self.make_embedding(concat)

        return emb
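
The cluster lookup flattens the index tensor, gathers each word's cluster id with index_select, and feeds (word, cluster) index pairs into the embedding stack. A toy illustration of the gather, with a hypothetical mapping:

    import torch

    cluster_mapping = torch.tensor([0, 0, 1, 1, 2])     # word id -> cluster id
    inp = torch.tensor([[[3]], [[1]]])                  # [len=2 x batch=1 x 1]

    flat_input = inp.view(-1)                           # tensor([3, 1])
    clusters = cluster_mapping.index_select(0, flat_input)  # tensor([1, 0])
    concat = torch.cat([inp, clusters.view(inp.shape)], dim=-1)
    assert concat.shape == (2, 1, 2)                    # word id and cluster id
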
Example #14
    def _get_word_context(self, query, context, index, mask_word):
        """  Verify sizes  """
        b_size, t_size, d_size = query.size()
        b_size_, s_size, d_size_ = context.size()
        aeq(d_size, d_size_)
        b_size__, c_size = index.size()
        aeq(b_size, b_size__)
        b_size__, t_size_, s_size_ = mask_word.size()
        aeq(b_size_, b_size__)
        aeq(s_size, s_size_)
        aeq(t_size, t_size_)
        """  Padding index of previous invalid sentences index (<0) to 0, and saving mask for sentences  """
        mask_sent = index < 0
        index_ = copy.deepcopy(index)
        index_[mask_sent] = 0
        """  Select context with index vector  """
        context_ = context.view(b_size_, -1).expand(b_size, b_size_,
                                                    s_size * d_size)
        index__ = index_.unsqueeze(2).expand(b_size, c_size, s_size * d_size)
        context_word = torch.gather(context_, 1,
                                    Variable(index__,
                                             requires_grad=False)).view(
                                                 b_size * c_size, s_size,
                                                 d_size)
        """  Create complete mask for context: word padding + sentence padding """
        mask_ = mask_word.contiguous().view(b_size_,
                                            -1).expand(b_size, b_size_,
                                                       t_size_ * s_size)
        index__ = index_.unsqueeze(2).expand(b_size, c_size, t_size_ * s_size)
        context_pad_mask = torch.gather(mask_, 1,
                                        index__).view(b_size * c_size, t_size_,
                                                      s_size)
        mask_sent_ = mask_sent.unsqueeze(2).expand(
            b_size, c_size,
            t_size_ * s_size).contiguous().view(b_size * c_size, t_size_,
                                                s_size)
        context_pad_mask[mask_sent_] = self.padding_idx
        """  Reshape query for future operations  """
        query_ = query.unsqueeze(1).expand(b_size, c_size, t_size,
                                           d_size).contiguous().view(
                                               b_size * c_size, t_size, d_size)

        return query_, context_word, context_pad_mask
Example #15
    def forward(self,
                input,
                lengths=None,
                hidden=None,
                contexts=None,
                neg=None,
                tau=0.5,
                scale=0.5):
        """ See :obj:`EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)
        emb = self.embeddings(input,
                              contexts=contexts,
                              neg=neg,
                              tau=tau,
                              scale=scale)
        if neg is not None:
            sense_loss = emb[1]
            emb = emb[0]
        s_len, n_batch, emb_dim = emb.size()

        out = emb.transpose(0, 1).contiguous()
        words = input[:, :, 0].transpose(0, 1)
        # CHECKS
        out_batch, out_len, _ = out.size()
        w_batch, w_len = words.size()
        aeq(out_batch, w_batch)
        aeq(out_len, w_len)
        # END CHECKS

        # Make mask.
        padding_idx = self.embeddings.word_padding_idx
        mask = words.data.eq(padding_idx).unsqueeze(1) \
            .expand(w_batch, w_len, w_len)
        # Run the forward pass of every layer of the transformer.
        for i in range(self.num_layers):
            out = self.transformer[i](out, mask)
        out = self.layer_norm(out)
        if neg is None:
            return Variable(emb.data), out.transpose(0, 1).contiguous()
        else:
            return Variable(emb.data), out.transpose(
                0, 1).contiguous(), sense_loss
Example #16
    def forward(self,
                tgt,
                memory_bank,
                state,
                memory_lengths=None,
                q_scores=None,
                tgt_emb=None):
        # Check
        assert isinstance(state, RNNDecoderState)
        tgt_len, tgt_batch, _ = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
        # END

        # Run the forward pass of the RNN.
        decoder_final, decoder_outputs, input_feed, attns, dist_info, decoder_outputs_baseline = self._run_forward_pass(
            tgt,
            memory_bank,
            state,
            memory_lengths=memory_lengths,
            q_scores=q_scores,
            tgt_emb=tgt_emb)

        # Update the state with the result.
        final_output = decoder_outputs[-1]
        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        state.update_state(decoder_final, input_feed.unsqueeze(0), coverage)

        # Concatenates sequence of tensors along a new dimension.
        # T x K x N x H
        decoder_outputs = torch.stack(decoder_outputs, dim=0)
        if len(decoder_outputs_baseline) > 0:
            decoder_outputs_baseline = torch.stack(decoder_outputs_baseline,
                                                   dim=0)
        else:
            decoder_outputs_baseline = None
        for k in attns:
            attns[k] = torch.stack(attns[k])

        return decoder_outputs, state, attns, dist_info, decoder_outputs_baseline
Example #17
    def coalesce_datasets(datasets):
        """Coalesce all dataset instances. """
        final = datasets[0]
        for d in datasets[1:]:
            # `src_vocabs` is a list of `torchtext.vocab.Vocab`.
            # Each sentence is turned into one Vocab.
            # Coalesce them into one big list.
            final.src_vocabs += d.src_vocabs

            # All datasets have same number of features.
            aeq(final.n_src_feats, d.n_src_feats)
            aeq(final.n_tgt_feats, d.n_tgt_feats)

            # `examples` is a list of `torchtext.data.Example`.
            # Coalesce them into one big list.
            final.examples += d.examples

            # All datasets have same fields, no need to update.

        return final
Example #18
    def forward(self, input):

        in_length, in_batch, nfeat = input.size()
        aeq(nfeat, len(self.emb_luts))

        emb = self.make_embedding(input)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
Example #19
    def forward(self, src, lengths=None, encoder_state=None, entities_list=None, entities_len=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths, encoder_state)

        emb = self.dropout(self.embeddings(src))
        s_len, batch, emb_dim = emb.size()
        mean = emb.mean(0).expand(self.num_layers, batch, emb_dim)
        memory_bank = emb
        encoder_final = (mean, mean)
        s_len_batch, batch_ent, num_entities = entities_list.size()
        s_len_entities, batch_len_ent = entities_len.size()
        aeq(batch_len_ent, batch_ent)
        aeq(num_entities, s_len_entities)
        ent_emb = emb.unsqueeze(1).expand(-1, num_entities, -1, -1)
        s_len, batch, emb_dim = emb.size()
        ent_dim = entities_list.transpose(1, 2).unsqueeze(3).expand(-1, -1, -1, emb_dim)
        ent_len_dim = entities_len.unsqueeze(2).expand(-1, -1, emb_dim)
        ent_emb = (ent_emb[:s_len_batch, :, :, :] * ent_dim).sum(0)
        ent_emb = ent_emb / ent_len_dim

        return encoder_final, memory_bank, ent_emb
Example #20
    def get_normal_scores(self, h_s, h_t):
        """ h_s: [batch x src_length x rnn_size]
            h_t: [batch x tgt_length x rnn_size]
        """
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.rnn_size, src_dim)

        h_t_expand = h_t.unsqueeze(2).expand(-1, -1, src_len, -1)
        h_s_expand = h_s.unsqueeze(1).expand(-1, tgt_len, -1, -1)
        # [batch, tgt_len, src_len, src_dim]
        h_expand = torch.cat((h_t_expand, h_s_expand), dim=3)
        h_fold = h_expand.contiguous().view(-1, src_dim + tgt_dim)

        h_enc = self.softplus(self.linear_1(h_fold))
        h_enc = self.softplus(self.linear_2(h_enc))

        h_mean = self.softplus(self.mean_out(h_enc))
        h_var = self.softplus(self.var_out(h_enc))

        h_mean = h_mean.view(tgt_batch, tgt_len, src_len)
        h_var = h_var.view(tgt_batch, tgt_len, src_len)
        return [h_mean, h_var]
Example #21
    def score(self, h_t, h_s):

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
Example #22
    def forward(self, hidden, attn, src_map, rnn_output, src_emb):
        """
        Computes p(w) = p(z=1) p_{copy}(w|z=0)  +  p(z=0) * p_{softmax}(w|z=0)
        """
        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.tgt_dict.stoi[onmt.IO.PAD_WORD]] = -float('inf')
        prob = F.softmax(logits)

        # Probability of copying p(z=1) batch.
        if self.pointer_gen:
            """
            p_gen = sigm(w1*hidden + w2*decoder_state + w3*decoder_input)
            
            hidden = post-attention hidden_state
            decoder_state = pre-attention hidden_state 
            """
            copy = F.sigmoid(
                self.linear_hidden(hidden) +
                self.linear_decoder_state(rnn_output) +
                self.linear_decoder_input(src_emb))
        else:
            copy = F.sigmoid(self.linear_copy(hidden))

        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob,
                             1 - copy.expand_as(prob.unsqueeze(0))).squeeze(0)
        mul_attn = torch.mul(attn, copy.expand_as(attn.unsqueeze(0)))
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
Example #23
    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):

        assert not self._copy  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        # Initialize local and return variables.
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN.
        if isinstance(self.rnn, nn.GRU):
            rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
        else:
            rnn_output, decoder_final = self.rnn(emb, state.hidden)

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        decoder_outputs, p_attn = self.attn(rnn_output.transpose(
            0, 1).contiguous(),
                                            memory_bank.transpose(0, 1),
                                            memory_lengths=memory_lengths)
        attns["std"] = p_attn

        # Calculate the context gate.
        if self.context_gate is not None:
            decoder_outputs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                decoder_outputs.view(-1, decoder_outputs.size(2)))
            decoder_outputs = \
                decoder_outputs.view(tgt_len, tgt_batch, self.hidden_size)

        decoder_outputs = self.dropout(decoder_outputs)
        return decoder_final, decoder_outputs, attns
Example #24
    def score(self, h_t, h_s):
        """
        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`

        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch*tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
Example #25
    def _get_word_context(self, query, context, index, mask_word):

        b_size, t_size, d_size = query.size()
        b_size_, s_size, d_size_ = context.size()
        aeq(d_size, d_size_)
        b_size__, c_size = index.size()
        aeq(b_size, b_size__)
        b_size__, t_size_, s_size_ = mask_word.size()
        aeq(b_size_, b_size__)
        aeq(s_size, s_size_)
        aeq(t_size, t_size_)

        # Create padding mask for previous sentences
        mask_sent = index < 0
        index_ = copy.deepcopy(index)
        index_[mask_sent] = 0

        # Get context
        context_ = context.view(b_size_, -1).expand(b_size, b_size_, s_size * d_size)
        index__ = index_.unsqueeze(2).expand(b_size, c_size, s_size * d_size)
        context_word = torch.gather(context_, 1, Variable(index__,
                                                          requires_grad=False)).view(b_size * c_size, s_size, d_size)

        # Get mask for context
        mask_ = mask_word.contiguous().view(b_size_, -1).expand(b_size, b_size_, t_size_ * s_size)
        index__ = index_.unsqueeze(2).expand(b_size, c_size, t_size_ * s_size)
        context_pad_mask = torch.gather(mask_, 1, index__).view(b_size * c_size, t_size_, s_size)

        # Mask previous sentences
        mask_sent_ = mask_sent.unsqueeze(2).expand(b_size,
                                                   c_size, t_size_ * s_size).contiguous().view(b_size * c_size, t_size_,
                                                                                               s_size)
        context_pad_mask[mask_sent_] = self.padding_idx

        # Expand query for each context sentence
        query_ = query.unsqueeze(1).expand(b_size, c_size,
                                           t_size, d_size).contiguous().view(b_size * c_size, t_size, d_size)

        return query_, context_word, context_pad_mask
Example #26
    def forward(self, input, contexts=None, neg=None, tau=0.5, scale=0.5):
        """
        Computes the embeddings for words and features.

        Args:
            input (`LongTensor`): index tensor `[len x batch x nfeat]`
        Return:
            `FloatTensor`: word embeddings `[len x batch x embedding_size]`
        """

        in_length, in_batch, nfeat = input.size()

        word_emb, sense_loss = None, None
        if self.SenseModule is not None:
            aeq(nfeat - 1, len(self.emb_luts))
            assert contexts is not None
            if neg is None:
                word_emb = self.SenseModule(input[:, :, 0],
                                            contexts,
                                            tau=tau,
                                            scale=scale)
            else:
                word_emb, sense_loss = self.SenseModule(input[:, :, 0],
                                                        contexts,
                                                        neg=neg,
                                                        tau=tau,
                                                        scale=scale)
        else:
            aeq(nfeat, len(self.emb_luts))
        emb = self.make_embedding((input, word_emb))

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)
        if neg is None:
            return emb
        else:
            return emb, sense_loss
Example #27
    def forward(self,
                input,
                context,
                state,
                context_lengths=None,
                r_input=None):
        # Args Check
        assert isinstance(state, RNNDecoderState)
        input_len, input_batch, _ = input.size()
        contxt_len, contxt_batch, _ = context.size()
        aeq(input_batch, contxt_batch)
        # END Args Check

        # Run the forward pass of the RNN.
        context = self.context_mlp(context)
        if self.training:
            bk_att_output, bk_rnn_output = self._run_backward_pass(
                r_input, context, state)
            self.bk_rnn_output = bk_rnn_output.detach()
            # self.bk_rnn_output = bk_rnn_output

        hidden, outputs, attns, coverage = self._run_forward_pass(
            input, context, state, context_lengths=context_lengths)

        # Update the state with the result.
        final_output = outputs[-1]
        state.update_state(
            hidden, final_output.unsqueeze(0),
            coverage.unsqueeze(0) if coverage is not None else None)

        # Concatenates sequence of tensors along a new dimension.
        outputs = torch.stack(outputs)
        for k in attns:
            attns[k] = torch.stack(attns[k])
        if self.training:
            return outputs, bk_att_output, state, attns
        else:
            return outputs, state, attns
Example #28
    def _example_dict_iter(self, line, index):
        if self.symbol_representation == "char":
            line = list(line.strip())
        elif self.symbol_representation == "word":
            line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        words, feats, n_feats = TextDataset.extract_text_features(line)

        if self.revert:
            words = tuple(reversed(words))
            feats = tuple(reversed(feats))

        example_dict = {self.side: words, "indices": index}
        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update(
                (prefix + str(j), f) for j, f in enumerate(feats))

        return example_dict
Example #29
    def _example_dict_iter(self, line, index):
        line = line.split()
        if self.line_truncate:
            line = line[:self.line_truncate]
        words, feats, n_feats = TextDataset.extract_text_features(line)
        example_dict = {self.side: words, "indices": index}
        if self.side == 'tgt1':
            example_dict = {
                self.side: words,
                'tgt1_planning': [int(word) for word in words],
                'player_row_indices': [int(word) for word in words],
                'team_row_indices': [int(word) for word in words],
                "indices": index
            }
        if feats:
            # All examples must have same number of features.
            aeq(self.n_feats, n_feats)

            prefix = self.side + "_feat_"
            example_dict.update(
                (prefix + str(j), f) for j, f in enumerate(feats))

        return example_dict
Example #30
    def _get_sent_context(self, query, context_word, context_index, attn_word):

        b_size, t_size, d_size = query.size()
        _, c_size = context_index.size()

        # Rearrange so that, for each target position, dim 1 is the context size
        context_sent = context_word.view(b_size, c_size, t_size, d_size).transpose(1, 2).contiguous().view(
            b_size * t_size, c_size, d_size)

        # Creating the mask for padding by word and sentence
        mask_sent = context_index < 0
        context_pad_mask = mask_sent.unsqueeze(1).expand(b_size,
                                                         t_size, c_size).contiguous().view(b_size * t_size, -1)
        context_pad_mask = context_pad_mask.unsqueeze(1).contiguous()

        # Re-arrange the query
        query_ = query.view(b_size * t_size, 1, d_size)

        _, h, t, s = attn_word.size()
        aeq(t, t_size)
        attn_word = attn_word.view(b_size, c_size, h, t, s)

        return query_, context_sent, context_pad_mask, attn_word
Example #31
    def forward(self, base_target_emb, input, encoder_out_top,
                encoder_out_combine):
        """
        It is similar to Luong attention.
        Conv attention takes a key matrix, a value matrix and a query vector.
        Attention weights are computed from the key matrix and the query
        vector, then used for a weighted sum over the value matrix. The same
        operation is applied in each decoder conv layer.
        Args:
            base_target_emb: target emb tensor
            input: output of the decoder conv
            encoder_out_top: the key matrix for computing the attention
                weights, which is the top output of the encoder conv
            encoder_out_combine: the value matrix for the attention-weighted
                sum, which combines the base emb and the top encoder output

        """
        # checks
        batch, channel, height, width = base_target_emb.size()
        batch_, channel_, height_, width_ = input.size()
        aeq(batch, batch_)
        aeq(height, height_)

        enc_batch, enc_channel, enc_height = encoder_out_top.size()
        enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()

        aeq(enc_batch, enc_batch_)
        aeq(enc_height, enc_height_)

        preatt = seq_linear(self.linear_in, input)
        target = (base_target_emb + preatt) * SCALE_WEIGHT
        target = torch.squeeze(target, 3)
        target = torch.transpose(target, 1, 2)
        pre_attn = torch.bmm(target, encoder_out_top)

        if self.mask is not None:
            pre_attn.data.masked_fill_(self.mask, -float('inf'))

        pre_attn = pre_attn.transpose(0, 2)
        attn = F.softmax(pre_attn)
        attn = attn.transpose(0, 2).contiguous()
        context_output = torch.bmm(
            attn, torch.transpose(encoder_out_combine, 1, 2))
        context_output = torch.transpose(
            torch.unsqueeze(context_output, 3), 1, 2)
        return context_output, attn
Example #32
    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
           attn (`FloatTensor`): attn for each `[batch*tlen, input_size]`
           src_map (`FloatTensor`):
             A sparse indicator matrix mapping each source word to
             its index in the "extended" vocabulary.
             `[src_len, batch, extra_words]`
        """
        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
        prob = F.softmax(logits)

        # Probability of copying p(z=1) batch.
        p_copy = F.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob,  1 - p_copy.expand_as(prob))
        mul_attn = torch.mul(attn, p_copy.expand_as(attn))
        copy_prob = torch.bmm(mul_attn.view(-1, batch, slen)
                              .transpose(0, 1),
                              src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
Example #33
    def forward(self, hidden, attn, src_map):

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
        prob = F.softmax(logits)

        # Probability of copying p(z=1) batch.
        p_copy = F.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
        mul_attn = torch.mul(attn, p_copy.expand_as(attn))
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
Example #34
    def forward(self, input):
        """
        Return the embeddings for words, and features if there are any.
        Args:
            input (LongTensor): len x batch x nfeat
        Return:
            emb (FloatTensor): len x batch x self.embedding_size
        """
        in_length, in_batch, nfeat = input.size()
        aeq(nfeat, len(self.emb_luts))

        emb = self.make_embedding(input)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
Example #35
    def forward(self, input):
        """
        Computes the embeddings for words and features.
        Args:
            input (`LongTensor`): index tensor `[len x batch x nfeat]`
        Return:
            `FloatTensor`: word embeddings `[len x batch x embedding_size]`
        """
        in_length, in_batch, nfeat = input.size()
        aeq(nfeat, len(self.emb_luts))

        emb = self.make_embedding(input)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
Example #36
    def forward(self, base_target_emb, input, encoder_out_top,
                encoder_out_combine):
        """
        Args:
            base_target_emb: target emb tensor
            input: output of the decoder conv
            encoder_out_top: the key matrix for computing the attention
                weights, which is the top output of the encoder conv
            encoder_out_combine:
                the value matrix for the attention-weighted sum,
                which combines the base emb and the top encoder output

        """
        # checks
        batch, channel, height, width = base_target_emb.size()
        batch_, channel_, height_, width_ = input.size()
        aeq(batch, batch_)
        aeq(height, height_)

        enc_batch, enc_channel, enc_height = encoder_out_top.size()
        enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size()

        aeq(enc_batch, enc_batch_)
        aeq(enc_height, enc_height_)

        preatt = seq_linear(self.linear_in, input)
        target = (base_target_emb + preatt) * SCALE_WEIGHT
        target = torch.squeeze(target, 3)
        target = torch.transpose(target, 1, 2)
        pre_attn = torch.bmm(target, encoder_out_top)

        if self.mask is not None:
            pre_attn.data.masked_fill_(self.mask, -float('inf'))

        pre_attn = pre_attn.transpose(0, 2)
        attn = F.softmax(pre_attn)
        attn = attn.transpose(0, 2).contiguous()
        context_output = torch.bmm(
            attn, torch.transpose(encoder_out_combine, 1, 2))
        context_output = torch.transpose(
            torch.unsqueeze(context_output, 3), 1, 2)
        return context_output, attn
Example #37
    def forward(self, input):
        """
        Computes the embeddings for words and features.

        Args:
            input (`LongTensor`): index tensor `[len x batch x nfeat]`
        Return:
            `FloatTensor`: word embeddings `[len x batch x embedding_size]`
        """
        in_length, in_batch, nfeat = input.size()
        aeq(nfeat, len(self.emb_luts))

        emb = self.make_embedding(input)

        out_length, out_batch, emb_size = emb.size()
        aeq(in_length, out_length)
        aeq(in_batch, out_batch)
        aeq(emb_size, self.embedding_size)

        return emb
Example #38
    def score(self, h_t, h_s):
        """
        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`

        """

        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t_ = h_t.view(tgt_batch*tgt_len, tgt_dim)
                h_t_ = self.linear_in(h_t_)
                h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t.view(-1, dim))
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous().view(-1, dim))
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = self.tanh(wq + uh)

            return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
Example #39
    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        tgt_len, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        decoder_outputs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        hidden = state.hidden
        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # Input feed concatenates hidden state with
        # input at every time step.
        for i, emb_t in enumerate(emb.split(1)):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)

            rnn_output, hidden = self.rnn(decoder_input, hidden)
            decoder_output, p_attn = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths)
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(
                    decoder_input, rnn_output, decoder_output
                )
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            decoder_outputs += [decoder_output]
            attns["std"] += [p_attn]

            # Update the coverage attention.
            if self._coverage:
                coverage = coverage + p_attn \
                    if coverage is not None else p_attn
                attns["coverage"] += [coverage]

            # Run the forward pass of the copy attention layer.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank.transpose(0, 1))
                attns["copy"] += [copy_attn]
            elif self._copy:
                attns["copy"] = attns["std"]
        # Return result.
        return hidden, decoder_outputs, attns
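
Because input feeding concatenates the previous attentional output with the current target embedding, the decoder RNN's input size must be emb_dim + hidden. A minimal single-step sketch with hypothetical sizes (the real decoder uses a stacked RNN):

    import torch
    import torch.nn as nn

    emb_dim, hidden, batch = 16, 32, 2
    rnn = nn.LSTMCell(emb_dim + hidden, hidden)

    input_feed = torch.zeros(batch, hidden)             # previous attn output
    state = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))
    emb_t = torch.randn(batch, emb_dim)                 # one decoding step

    decoder_input = torch.cat([emb_t, input_feed], 1)   # [batch, emb_dim + hidden]
    h, c = rnn(decoder_input, state)
    # In the loop above, input_feed is then set to the attentional output.
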
Example #40
    def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
        """

        Args:
          input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`
          coverage (`FloatTensor`): None (not supported yet)

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
          * Attention distributions for each query
             `[tgt_len x batch x src_len]`
        """

        # one step input
        if input.dim() == 2:
            one_step = True
            input = input.unsqueeze(1)
        else:
            one_step = False

        batch, sourceL, dim = memory_bank.size()
        batch_, targetL, dim_ = input.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)
        if coverage is not None:
            batch_, sourceL_ = coverage.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            memory_bank += self.linear_cover(cover).view_as(memory_bank)
            memory_bank = self.tanh(memory_bank)

        # compute attention scores, as in Luong et al.
        align = self.score(input, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths)
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.data.masked_fill_(1 - mask, -float('inf'))

        # Softmax to normalize attention weights
        align_vectors = self.sm(align.view(batch*targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate
        concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2)
        attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, sourceL_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)
        else:
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            targetL_, batch_, dim_ = attn_h.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            targetL_, batch_, sourceL_ = align_vectors.size()
            aeq(targetL, targetL_)
            aeq(batch, batch_)
            aeq(sourceL, sourceL_)

        return attn_h, align_vectors
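
A shape-level walkthrough of the core math above (score, softmax, weighted sum, output projection), with hypothetical sizes; the linear_out layer below is a stand-in for the module's own projection:

    import torch
    import torch.nn as nn

    batch, tgt_len, src_len, dim = 2, 3, 5, 8
    memory_bank = torch.randn(batch, src_len, dim)
    query = torch.randn(batch, tgt_len, dim)

    align = torch.randn(batch, tgt_len, src_len)        # would come from score()
    align_vectors = torch.softmax(
        align.view(batch * tgt_len, src_len), dim=-1).view(batch, tgt_len, src_len)

    c = torch.bmm(align_vectors, memory_bank)           # context vectors
    concat_c = torch.cat([c, query], 2).view(batch * tgt_len, dim * 2)
    linear_out = nn.Linear(dim * 2, dim)                # assumed projection
    attn_h = torch.tanh(linear_out(concat_c)).view(batch, tgt_len, dim)
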
Example #41
    def _check_args(self, input, lengths=None, hidden=None):
        s_len, n_batch, n_feats = input.size()
        if lengths is not None:
            n_batch_, = lengths.size()
            aeq(n_batch, n_batch_)
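The `aeq` helper used throughout these examples is never shown; it is a small assert-all-equal utility. A minimal sketch consistent with how it is called here:

    def aeq(*args):
        """Assert that all arguments are equal (used for tensor shape checks)."""
        first = args[0]
        assert all(arg == first for arg in args[1:]), \
            "Not all arguments have the same value: " + str(args)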
Exemplo n.º 42
0
    def forward(self, key, value, query, mask=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
           value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
           query (`FloatTensor`): set of `query_len`
                 query vectors  `[batch, query_len, dim]`
           mask: binary mask marking key positions that must receive
                 no attention (1 = masked) `[batch, query_len, key_len]`
        Returns:
           (`FloatTensor`, `FloatTensor`) :

           * output context vectors `[batch, query_len, dim]`
           * the first head's attention distribution `[batch, query_len, key_len]`
        """

        # CHECKS
        batch, k_len, d = key.size()
        batch_, k_len_, d_ = value.size()
        aeq(batch, batch_)
        aeq(k_len, k_len_)
        aeq(d, d_)
        batch_, q_len, d_ = query.size()
        aeq(batch, batch_)
        aeq(d, d_)
        aeq(self.model_dim % self.head_count, 0)
        if mask is not None:
            batch_, q_len_, k_len_ = mask.size()
            aeq(batch_, batch)
            aeq(k_len_, k_len)
            aeq(q_len_, q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        key_up = shape(self.linear_keys(key))
        value_up = shape(self.linear_values(value))
        query_up = shape(self.linear_query(query))

        # 2) Calculate and scale scores.
        query_up = query_up / math.sqrt(dim_per_head)
        scores = torch.matmul(query_up, key_up.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(Variable(mask), -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.sm(scores)
        drop_attn = self.dropout(attn)
        context = unshape(torch.matmul(drop_attn, value_up))

        output = self.final_linear(context)
        # CHECK
        batch_, q_len_, d_ = output.size()
        aeq(q_len, q_len_)
        aeq(batch, batch_)
        aeq(d, d_)
        # END CHECK

        # Return a single attention distribution (head 0) for visualization.
        top_attn = attn \
            .view(batch_size, head_count,
                  query_len, key_len)[:, 0, :, :] \
            .contiguous()
        return output, top_attn
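The `shape`/`unshape` helpers above split the model dimension across heads and merge it back. A standalone sketch of that round trip (dimensions are illustrative):

    import torch

    batch, seq_len, heads, dim_per_head = 2, 5, 8, 64
    model_dim = heads * dim_per_head
    x = torch.randn(batch, seq_len, model_dim)

    # shape(): [batch, seq_len, model_dim] -> [batch, heads, seq_len, dim_per_head]
    shaped = x.view(batch, -1, heads, dim_per_head).transpose(1, 2)
    # unshape(): inverse transform back to [batch, seq_len, model_dim]
    restored = shaped.transpose(1, 2).contiguous().view(batch, -1, model_dim)
    assert torch.equal(x, restored)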
Exemplo n.º 43
0
    def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
                previous_input=None):
        # Args Checks
        input_batch, input_len, _ = inputs.size()
        if previous_input is not None:
            pi_batch, _, _ = previous_input.size()
            aeq(pi_batch, input_batch)
        contxt_batch, contxt_len, _ = memory_bank.size()
        aeq(input_batch, contxt_batch)

        src_batch, t_len, s_len = src_pad_mask.size()
        tgt_batch, t_len_, t_len__ = tgt_pad_mask.size()
        aeq(input_batch, contxt_batch, src_batch, tgt_batch)
        # aeq(t_len, t_len_, t_len__, input_len)
        aeq(s_len, contxt_len)
        # END Args Checks

        # Combine the padding mask with the subsequent-position mask so a
        # query can attend neither to padding nor to future positions.
        dec_mask = torch.gt(tgt_pad_mask +
                            self.mask[:, :tgt_pad_mask.size(1),
                                      :tgt_pad_mask.size(1)], 0)
        # Self-attention sub-layer (pre-norm) with a residual connection.
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            # Incremental decoding: keys/values cover all positions seen so far.
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None
        query, attn = self.self_attn(all_input, all_input, input_norm,
                                     mask=dec_mask)
        query = self.drop(query) + inputs

        # Encoder-decoder attention over the memory bank, then the
        # position-wise feed-forward network, each with dropout + residual.
        query_norm = self.layer_norm_2(query)
        mid, attn = self.context_attn(memory_bank, memory_bank, query_norm,
                                      mask=src_pad_mask)
        output = self.feed_forward(self.drop(mid) + query)

        # CHECKS
        output_batch, output_len, _ = output.size()
        aeq(input_len, output_len)
        aeq(contxt_batch, output_batch)

        n_batch_, t_len_, s_len_ = attn.size()
        aeq(input_batch, n_batch_)
        aeq(contxt_len, s_len_)
        aeq(input_len, t_len_)
        # END CHECKS

        return output, attn, all_input
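`self.mask` is not defined in this snippet; it is presumably a pre-computed subsequent-position mask that hides future target positions. A minimal sketch of how such a mask is typically built:

    import numpy as np
    import torch

    def subsequent_mask(size):
        # 1 above the diagonal marks future positions that must be hidden.
        future = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
        return torch.from_numpy(future)

    # torch.gt(tgt_pad_mask + subsequent_mask(len), 0) then flags a position
    # if it is padding OR lies in the future relative to the query.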
Exemplo n.º 44
0
    def forward(self, tgt, memory_bank, state, memory_lengths=None):
        """
        See :obj:`onmt.modules.RNNDecoderBase.forward()`
        """
        # CHECKS
        assert isinstance(state, TransformerDecoderState)
        tgt_len, tgt_batch, _ = tgt.size()
        memory_len, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)

        src = state.src
        src_words = src[:, :, 0].transpose(0, 1)
        tgt_words = tgt[:, :, 0].transpose(0, 1)
        src_batch, src_len = src_words.size()
        tgt_batch, tgt_len = tgt_words.size()
        aeq(tgt_batch, memory_batch, src_batch, tgt_batch)
        aeq(memory_len, src_len)

        # END CHECKS

        if state.previous_input is not None:
            tgt = torch.cat([state.previous_input, tgt], 0)

        # Initialize return variables.
        outputs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []

        # Run the forward pass of the TransformerDecoder.
        emb = self.embeddings(tgt)
        if state.previous_input is not None:
            # Only embed the new positions; earlier steps were already run.
            emb = emb[state.previous_input.size(0):]
        assert emb.dim() == 3  # len x batch x embedding_dim

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        padding_idx = self.embeddings.word_padding_idx
        src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
            .expand(src_batch, tgt_len, src_len)
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
            .expand(tgt_batch, tgt_len, tgt_len)

        saved_inputs = []
        for i in range(self.num_layers):
            prev_layer_input = None
            if state.previous_input is not None:
                prev_layer_input = state.previous_layer_inputs[i]
            output, attn, all_input \
                = self.transformer_layers[i](output, src_memory_bank,
                                             src_pad_mask, tgt_pad_mask,
                                             previous_input=prev_layer_input)
            saved_inputs.append(all_input)

        saved_inputs = torch.stack(saved_inputs)
        output = self.layer_norm(output)

        # Process the result and update the attentions.
        outputs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns["std"] = attn
        if self._copy:
            attns["copy"] = attn

        # Update the state.
        state = state.update_state(tgt, saved_inputs)
        return outputs, state, attns
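The padding masks built above follow a pattern worth isolating: compare word ids against the padding index, then broadcast the result over the query dimension. A small self-contained sketch (the padding index and word ids are illustrative):

    import torch

    padding_idx = 1                               # illustrative pad token id
    words = torch.tensor([[4, 5, 1, 1],
                          [7, 1, 1, 1]])          # [batch x seq_len]
    pad_mask = words.eq(padding_idx) \
        .unsqueeze(1).expand(2, 4, 4)             # [batch x query_len x key_len]
    # Each row of a batch entry marks which key positions are padding.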